ultralytics export

This commit is contained in:
Siwat Sirichai 2025-11-11 01:28:19 +07:00
parent bf7b68edb1
commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions

.gitignore

@@ -5,4 +5,5 @@ __pycache__/
.claude
/models/
/tracked_objects.json
.trt_cache
+.ultralytics_cache


@@ -10,4 +10,8 @@
- Buffer should flush after TARGET_FRAME_INTERVAL_MS
- Blurry asyncio architecture; requires documentation
+- Each engine caches to its own random, non-centralized folder
+- Workspace for YOLO is fixed to 4 GB; why is that, and what is it?

convert_pt_to_tensorrt.py

@@ -21,10 +21,11 @@ Usage:
import argparse
import sys
from pathlib import Path
-from typing import Tuple, List, Optional
-import torch
-import tensorrt as trt
+from typing import List, Optional, Tuple
import numpy as np
+import tensorrt as trt
+import torch

class TensorRTConverter:

@@ -39,7 +40,7 @@ class TensorRTConverter:
verbose: Enable verbose logging
"""
self.gpu_id = gpu_id
-self.device = torch.device(f'cuda:{gpu_id}')
+self.device = torch.device(f"cuda:{gpu_id}")

# TensorRT logger
log_level = trt.Logger.VERBOSE if verbose else trt.Logger.WARNING

@@ -71,13 +72,15 @@ class TensorRTConverter:
raise FileNotFoundError(f"Model file not found: {model_path}")

# Load model (weights_only=False for models with custom classes)
-checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
+checkpoint = torch.load(
+model_path, map_location=self.device, weights_only=False
+)

# Handle different checkpoint formats
if isinstance(checkpoint, dict):
-if 'model' in checkpoint:
-model = checkpoint['model']
-elif 'state_dict' in checkpoint:
+if "model" in checkpoint:
+model = checkpoint["model"]
+elif "state_dict" in checkpoint:
# Need model architecture - this is a limitation
raise ValueError(
"Checkpoint contains only state_dict. "

@@ -95,9 +98,15 @@ class TensorRTConverter:
print(f"✓ Model loaded successfully")
return model

-def export_to_onnx(self, model: torch.nn.Module, input_shape: Tuple[int, ...],
-onnx_path: str, dynamic_batch: bool = False,
-input_names: List[str] = None, output_names: List[str] = None) -> str:
+def export_to_onnx(
+self,
+model: torch.nn.Module,
+input_shape: Tuple[int, ...],
+onnx_path: str,
+dynamic_batch: bool = False,
+input_names: List[str] = None,
+output_names: List[str] = None,
+) -> str:
"""
Export PyTorch model to ONNX format (intermediate step).

@@ -118,9 +127,9 @@ class TensorRTConverter:
# Default names
if input_names is None:
-input_names = ['input']
+input_names = ["input"]
if output_names is None:
-output_names = ['output']
+output_names = ["output"]

# Create dummy input
dummy_input = torch.randn(*input_shape, device=self.device)

@@ -128,10 +137,7 @@ class TensorRTConverter:
# Dynamic axes configuration
dynamic_axes = None
if dynamic_batch:
-dynamic_axes = {
-input_names[0]: {0: 'batch'},
-output_names[0]: {0: 'batch'}
-}
+dynamic_axes = {input_names[0]: {0: "batch"}, output_names[0]: {0: "batch"}}

# Export to ONNX
torch.onnx.export(

@@ -143,16 +149,22 @@ class TensorRTConverter:
dynamic_axes=dynamic_axes,
opset_version=17,  # Use recent ONNX opset
do_constant_folding=True,
-verbose=False
+verbose=False,
)

print(f"✓ ONNX model exported to {onnx_path}")
return onnx_path
-def build_tensorrt_engine_from_onnx(self, onnx_path: str, engine_path: str,
-fp16: bool = False, int8: bool = False,
-max_workspace_size: int = 4,
-min_batch: int = 1, opt_batch: int = 1, max_batch: int = 1) -> str:
+def build_tensorrt_engine_from_onnx(
+self,
+onnx_path: str,
+engine_path: str,
+fp16: bool = False,
+int8: bool = False,
+min_batch: int = 1,
+opt_batch: int = 1,
+max_batch: int = 1,
+) -> str:
"""
Build TensorRT engine from ONNX model.

@@ -161,7 +173,6 @@ class TensorRTConverter:
engine_path: Output path for TensorRT engine
fp16: Enable FP16 precision
int8: Enable INT8 precision (requires calibration)
-max_workspace_size: Maximum workspace size in GB
min_batch: Minimum batch size for optimization
opt_batch: Optimal batch size for optimization
max_batch: Maximum batch size for optimization

@@ -171,7 +182,6 @@ class TensorRTConverter:
"""
print(f"\nBuilding TensorRT engine from ONNX...")
print(f"Precision: FP{'16' if fp16 else '32'}{' + INT8' if int8 else ''}")
-print(f"Workspace size: {max_workspace_size} GB")

# Create builder and network
builder = trt.Builder(self.logger)

@@ -182,7 +192,7 @@ class TensorRTConverter:
# Parse ONNX model
print(f"Loading ONNX file from {onnx_path}...")
-with open(onnx_path, 'rb') as f:
+with open(onnx_path, "rb") as f:
if not parser.parse(f.read()):
print("ERROR: Failed to parse the ONNX file:")
for error in range(parser.num_errors):

@@ -206,12 +216,6 @@ class TensorRTConverter:
# Create builder config
config = builder.create_builder_config()
-# Set workspace size
-config.set_memory_pool_limit(
-trt.MemoryPoolType.WORKSPACE,
-max_workspace_size * (1 << 30)  # GB to bytes
-)

# Enable precision modes
if fp16:
if not builder.platform_has_fast_fp16:

@@ -226,7 +230,9 @@ class TensorRTConverter:
else:
config.set_flag(trt.BuilderFlag.INT8)
print("✓ INT8 mode enabled")
-print("Note: INT8 calibration not implemented. Results may be suboptimal.")
+print(
+"Note: INT8 calibration not implemented. Results may be suboptimal."
+)

# Set optimization profile for dynamic shapes
if max_batch > 1 or min_batch != max_batch:

@@ -260,7 +266,7 @@ class TensorRTConverter:
# Save engine to file
print(f"Saving engine to {engine_path}...")
-with open(engine_path, 'wb') as f:
+with open(engine_path, "wb") as f:
f.write(serialized_engine)

# Get file size

@@ -270,15 +276,19 @@ class TensorRTConverter:
return engine_path
-def convert(self, model_path: str, output_path: str,
-input_shape: Tuple[int, ...] = (1, 3, 640, 640),
-fp16: bool = False, int8: bool = False,
-dynamic_batch: bool = False,
-max_batch: int = 16,
-workspace_size: int = 4,
-input_names: List[str] = None,
-output_names: List[str] = None,
-keep_onnx: bool = False) -> str:
+def convert(
+self,
+model_path: str,
+output_path: str,
+input_shape: Tuple[int, ...] = (1, 3, 640, 640),
+fp16: bool = False,
+int8: bool = False,
+dynamic_batch: bool = False,
+max_batch: int = 16,
+input_names: List[str] = None,
+output_names: List[str] = None,
+keep_onnx: bool = False,
+) -> str:
"""
Convert PyTorch or ONNX model to TensorRT engine.

@@ -290,7 +300,6 @@ class TensorRTConverter:
int8: Enable INT8 precision
dynamic_batch: Enable dynamic batch size
max_batch: Maximum batch size (for dynamic batching)
-workspace_size: TensorRT workspace size in GB
input_names: Custom input names (for PyTorch export)
output_names: Custom output names (for PyTorch export)
keep_onnx: Keep intermediate ONNX file

@@ -304,7 +313,7 @@ class TensorRTConverter:
# Check if input is already ONNX
model_path_obj = Path(model_path)
-is_onnx = model_path_obj.suffix.lower() == '.onnx'
+is_onnx = model_path_obj.suffix.lower() == ".onnx"

if is_onnx:
# Direct ONNX to TensorRT conversion

@@ -319,10 +328,9 @@ class TensorRTConverter:
engine_path=output_path,
fp16=fp16,
int8=int8,
-max_workspace_size=workspace_size,
min_batch=min_batch,
opt_batch=opt_batch,
-max_batch=max_batch_size
+max_batch=max_batch_size,
)

print(f"\n{'=' * 80}")

@@ -350,7 +358,7 @@ class TensorRTConverter:
onnx_path=onnx_path,
dynamic_batch=dynamic_batch,
input_names=input_names,
-output_names=output_names
+output_names=output_names,
)

# Step 3: Build TensorRT engine

@@ -363,10 +371,9 @@ class TensorRTConverter:
engine_path=output_path,
fp16=fp16,
int8=int8,
-max_workspace_size=workspace_size,
min_batch=min_batch,
opt_batch=opt_batch,
-max_batch=max_batch_size
+max_batch=max_batch_size,
)

print(f"\n{'=' * 80}")

@@ -392,7 +399,7 @@ class TensorRTConverter:
def parse_shape(shape_str: str) -> Tuple[int, ...]:
"""Parse shape string like '1,3,640,640' to tuple"""
try:
-return tuple(int(x) for x in shape_str.split(','))
+return tuple(int(x) for x in shape_str.split(","))
except ValueError:
raise argparse.ArgumentTypeError(
f"Invalid shape format: {shape_str}. Expected format: 1,3,640,640"
@@ -426,97 +433,82 @@ Examples:
# Keep intermediate ONNX file for debugging
python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
--keep-onnx
-"""
+""",
)

# Required arguments
parser.add_argument(
-'--model', '-m',
+"--model",
+"-m",
type=str,
required=True,
-help='Path to PyTorch model file (.pt or .pth)'
+help="Path to PyTorch model file (.pt or .pth)",
)

parser.add_argument(
-'--output', '-o',
+"--output",
+"-o",
type=str,
required=True,
-help='Output path for TensorRT engine (.trt or .engine)'
+help="Output path for TensorRT engine (.trt or .engine)",
)

# Optional arguments
parser.add_argument(
-'--input-shape', '-s',
+"--input-shape",
+"-s",
type=parse_shape,
default=(1, 3, 640, 640),
-help='Input tensor shape as B,C,H,W (default: 1,3,640,640)'
+help="Input tensor shape as B,C,H,W (default: 1,3,640,640)",
)

parser.add_argument(
-'--fp16',
-action='store_true',
-help='Enable FP16 precision (faster inference, slightly lower accuracy)'
+"--fp16",
+action="store_true",
+help="Enable FP16 precision (faster inference, slightly lower accuracy)",
)

parser.add_argument(
-'--int8',
-action='store_true',
-help='Enable INT8 precision (fastest, requires calibration)'
+"--int8",
+action="store_true",
+help="Enable INT8 precision (fastest, requires calibration)",
)

parser.add_argument(
-'--dynamic-batch',
-action='store_true',
-help='Enable dynamic batch size support'
+"--dynamic-batch", action="store_true", help="Enable dynamic batch size support"
)

parser.add_argument(
-'--max-batch',
+"--max-batch",
type=int,
default=16,
-help='Maximum batch size for dynamic batching (default: 16)'
+help="Maximum batch size for dynamic batching (default: 16)",
)

-parser.add_argument(
-'--workspace-size',
-type=int,
-default=4,
-help='TensorRT workspace size in GB (default: 4)'
-)
+parser.add_argument("--gpu", type=int, default=0, help="GPU device ID (default: 0)")

parser.add_argument(
-'--gpu',
-type=int,
-default=0,
-help='GPU device ID (default: 0)'
-)
-parser.add_argument(
-'--input-names',
+"--input-names",
type=str,
-nargs='+',
+nargs="+",
default=None,
-help='Custom input tensor names (default: ["input"])'
+help='Custom input tensor names (default: ["input"])',
)

parser.add_argument(
-'--output-names',
+"--output-names",
type=str,
-nargs='+',
+nargs="+",
default=None,
-help='Custom output tensor names (default: ["output"])'
+help='Custom output tensor names (default: ["output"])',
)

parser.add_argument(
-'--keep-onnx',
-action='store_true',
-help='Keep intermediate ONNX file'
+"--keep-onnx", action="store_true", help="Keep intermediate ONNX file"
)

parser.add_argument(
-'--verbose', '-v',
-action='store_true',
-help='Enable verbose logging'
+"--verbose", "-v", action="store_true", help="Enable verbose logging"
)

args = parser.parse_args()

@@ -542,10 +534,9 @@ Examples:
int8=args.int8,
dynamic_batch=args.dynamic_batch,
max_batch=args.max_batch,
-workspace_size=args.workspace_size,
input_names=args.input_names,
output_names=args.output_names,
-keep_onnx=args.keep_onnx
+keep_onnx=args.keep_onnx,
)

print("\n✓ Conversion successful!")

@@ -554,6 +545,7 @@ Examples:
print(f"\n✗ Conversion failed: {e}")
if args.verbose:
import traceback
traceback.print_exc()
sys.exit(1)
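Since every CLI flag above maps directly onto TensorRTConverter.convert(), a minimal programmatic sketch of the same flow looks like this. The module name and file paths are assumptions; the argument names come from the signature in this diff, which no longer accepts workspace_size:

# Hedged sketch: assumes the script above is importable as convert_pt_to_tensorrt
# and that the input/output paths exist; they are illustrative only.
from convert_pt_to_tensorrt import TensorRTConverter

converter = TensorRTConverter(gpu_id=0, verbose=False)
engine_path = converter.convert(
    model_path="models/yolov8n.pt",    # .pt, .pth, or .onnx input
    output_path="models/yolov8n.trt",  # serialized TensorRT engine
    input_shape=(1, 3, 640, 640),      # B, C, H, W used for the ONNX export
    fp16=True,
    dynamic_batch=True,
    max_batch=16,
    keep_onnx=False,                   # workspace_size is gone after this commit
)
print(f"Engine written to {engine_path}")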

services/__init__.py

@@ -2,38 +2,67 @@
Services package for RTSP stream processing with GPU acceleration.
"""

-from .stream_decoder import StreamDecoderFactory, StreamDecoder, ConnectionStatus
+from .base_model_controller import BaseModelController, BatchFrame, BufferState
+from .inference_engine import (
+BackendType,
+EngineMetadata,
+IInferenceEngine,
+NativeTensorRTEngine,
+UltralyticsEngine,
+create_engine,
+)
from .jpeg_encoder import JPEGEncoderFactory, encode_frame_to_jpeg
-from .model_repository import TensorRTModelRepository, ModelMetadata, ExecutionContext, SharedEngine
-from .tracking_controller import ObjectTracker, TrackedObject, Detection
-from .yolo import YOLOv8Utils, COCO_CLASSES
-from .model_controller import ModelController, BatchFrame, BufferState
-from .stream_connection_manager import StreamConnectionManager, StreamConnection, TrackingResult
+from .model_repository import (
+ExecutionContext,
+ModelMetadata,
+SharedEngine,
+TensorRTModelRepository,
+)
+from .modelstorage import FileModelStorage, IModelStorage
from .pt_converter import PTConverter
-from .modelstorage import IModelStorage, FileModelStorage
+from .stream_connection_manager import (
+StreamConnection,
+StreamConnectionManager,
+TrackingResult,
+)
+from .stream_decoder import ConnectionStatus, StreamDecoder, StreamDecoderFactory
+from .tensorrt_model_controller import TensorRTModelController
+from .tracking_controller import Detection, ObjectTracker, TrackedObject
+from .ultralytics_exporter import UltralyticsExporter
+from .ultralytics_model_controller import UltralyticsModelController
+from .yolo import COCO_CLASSES, YOLOv8Utils

__all__ = [
-'StreamDecoderFactory',
-'StreamDecoder',
-'ConnectionStatus',
-'JPEGEncoderFactory',
-'encode_frame_to_jpeg',
-'TensorRTModelRepository',
-'ModelMetadata',
-'ExecutionContext',
-'SharedEngine',
-'ObjectTracker',
-'TrackedObject',
-'Detection',
-'YOLOv8Utils',
-'COCO_CLASSES',
-'ModelController',
-'BatchFrame',
-'BufferState',
-'StreamConnectionManager',
-'StreamConnection',
-'TrackingResult',
-'PTConverter',
-'IModelStorage',
-'FileModelStorage',
+"StreamDecoderFactory",
+"StreamDecoder",
+"ConnectionStatus",
+"JPEGEncoderFactory",
+"encode_frame_to_jpeg",
+"TensorRTModelRepository",
+"ModelMetadata",
+"ExecutionContext",
+"SharedEngine",
+"ObjectTracker",
+"TrackedObject",
+"Detection",
+"YOLOv8Utils",
+"COCO_CLASSES",
+"BaseModelController",
+"TensorRTModelController",
+"UltralyticsModelController",
+"BatchFrame",
+"BufferState",
+"StreamConnectionManager",
+"StreamConnection",
+"TrackingResult",
+"PTConverter",
+"IModelStorage",
+"FileModelStorage",
+"IInferenceEngine",
+"NativeTensorRTEngine",
+"UltralyticsEngine",
+"EngineMetadata",
+"BackendType",
+"create_engine",
+"UltralyticsExporter",
]

services/base_model_controller.py (new file)

@@ -0,0 +1,324 @@
"""
Base Model Controller - Abstract base class for batched inference controllers.
Provides ping-pong buffer architecture with force-switch timeout mechanism.
Implementations handle backend-specific inference (TensorRT, Ultralytics, etc.).
"""
import logging
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
import torch
logger = logging.getLogger(__name__)
@dataclass
class BatchFrame:
"""Represents a frame in the batch buffer"""
stream_id: str
frame: torch.Tensor # GPU tensor (3, H, W)
timestamp: float
metadata: Dict = field(default_factory=dict)
class BufferState(Enum):
"""State of a ping-pong buffer"""
IDLE = "idle"
FILLING = "filling"
PROCESSING = "processing"
class BaseModelController(ABC):
"""
Abstract base class for batched inference with ping-pong buffers.
This controller accumulates frames from multiple streams into batches,
processes them through an inference backend, and routes results back to
stream-specific callbacks.
Features:
- Ping-pong circular buffers (BufferA/BufferB)
- Force-switch timeout to prevent batch starvation
- Event-driven processing with callbacks
- Thread-safe frame submission
Subclasses must implement:
- _run_batch_inference(): Backend-specific inference logic
"""
def __init__(
self,
model_id: str,
batch_size: int = 16,
force_timeout: float = 0.05,
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
):
self.model_id = model_id
self.batch_size = batch_size
self.force_timeout = force_timeout
self.preprocess_fn = preprocess_fn
self.postprocess_fn = postprocess_fn
# Ping-pong buffers
self.buffer_a: List[BatchFrame] = []
self.buffer_b: List[BatchFrame] = []
# Buffer states
self.active_buffer = "A"
self.buffer_a_state = BufferState.IDLE
self.buffer_b_state = BufferState.IDLE
# Threading coordination
self.buffer_lock = threading.RLock()
self.last_submit_time = time.time()
# Threads
self.timeout_thread: Optional[threading.Thread] = None
self.processor_threads: Dict[str, threading.Thread] = {}
self.running = False
self.stop_event = threading.Event()
# Result callbacks (stream_id -> callback)
self.result_callbacks: Dict[str, Callable] = {}
# Statistics
self.total_frames_processed = 0
self.total_batches_processed = 0
def start(self):
"""Start the controller background threads"""
if self.running:
logger.warning("ModelController already running")
return
self.running = True
self.stop_event.clear()
# Start timeout monitor thread
self.timeout_thread = threading.Thread(
target=self._timeout_monitor, daemon=True
)
self.timeout_thread.start()
# Start processor threads for each buffer
self.processor_threads["A"] = threading.Thread(
target=self._batch_processor, args=("A",), daemon=True
)
self.processor_threads["B"] = threading.Thread(
target=self._batch_processor, args=("B",), daemon=True
)
self.processor_threads["A"].start()
self.processor_threads["B"].start()
logger.info(f"{self.__class__.__name__} started")
def stop(self):
"""Stop the controller and cleanup"""
if not self.running:
return
logger.info(f"Stopping {self.__class__.__name__}...")
self.running = False
self.stop_event.set()
# Wait for threads to finish
if self.timeout_thread and self.timeout_thread.is_alive():
self.timeout_thread.join(timeout=2.0)
for thread in self.processor_threads.values():
if thread and thread.is_alive():
thread.join(timeout=2.0)
# Process any remaining frames
self._process_remaining_buffers()
logger.info(f"{self.__class__.__name__} stopped")
def register_callback(self, stream_id: str, callback: Callable):
"""Register a callback for inference results from a stream"""
self.result_callbacks[stream_id] = callback
logger.debug(f"Registered callback for stream: {stream_id}")
def unregister_callback(self, stream_id: str):
"""Unregister a stream callback"""
self.result_callbacks.pop(stream_id, None)
logger.debug(f"Unregistered callback for stream: {stream_id}")
def submit_frame(
self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
):
"""Submit a frame for batched inference"""
with self.buffer_lock:
batch_frame = BatchFrame(
stream_id=stream_id,
frame=frame,
timestamp=time.time(),
metadata=metadata or {},
)
# Add to active buffer
if self.active_buffer == "A":
self.buffer_a.append(batch_frame)
self.buffer_a_state = BufferState.FILLING
buffer_size = len(self.buffer_a)
else:
self.buffer_b.append(batch_frame)
self.buffer_b_state = BufferState.FILLING
buffer_size = len(self.buffer_b)
self.last_submit_time = time.time()
# Check if we should immediately swap (batch full)
if buffer_size >= self.batch_size:
self._try_swap_buffers()
def _timeout_monitor(self):
"""Monitor force-switch timeout"""
while self.running and not self.stop_event.wait(0.01):
with self.buffer_lock:
time_since_submit = time.time() - self.last_submit_time
if time_since_submit >= self.force_timeout:
active_buffer = (
self.buffer_a if self.active_buffer == "A" else self.buffer_b
)
if len(active_buffer) > 0:
self._try_swap_buffers()
def _try_swap_buffers(self):
"""Attempt to swap ping-pong buffers (called with buffer_lock held)"""
inactive_state = (
self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
)
if inactive_state != BufferState.PROCESSING:
old_active = self.active_buffer
self.active_buffer = "B" if old_active == "A" else "A"
if old_active == "A":
self.buffer_a_state = BufferState.PROCESSING
buffer_size = len(self.buffer_a)
else:
self.buffer_b_state = BufferState.PROCESSING
buffer_size = len(self.buffer_b)
logger.debug(
f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
)
def _batch_processor(self, buffer_name: str):
"""Background thread that processes a specific buffer when available"""
while self.running and not self.stop_event.is_set():
time.sleep(0.001)
with self.buffer_lock:
if buffer_name == "A":
should_process = self.buffer_a_state == BufferState.PROCESSING
else:
should_process = self.buffer_b_state == BufferState.PROCESSING
if should_process:
self._process_buffer(buffer_name)
def _process_buffer(self, buffer_name: str):
"""Process a buffer through inference"""
# Extract buffer to process
with self.buffer_lock:
if buffer_name == "A":
batch = self.buffer_a.copy()
self.buffer_a.clear()
else:
batch = self.buffer_b.copy()
self.buffer_b.clear()
if len(batch) == 0:
with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
return
# Process batch (outside lock to allow concurrent submissions)
try:
start_time = time.time()
results = self._run_batch_inference(batch)
inference_time = time.time() - start_time
self.total_frames_processed += len(batch)
self.total_batches_processed += 1
logger.debug(
f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms "
f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
)
# Emit results to callbacks
for batch_frame, result in zip(batch, results):
callback = self.result_callbacks.get(batch_frame.stream_id)
if callback:
try:
callback(result)
except Exception as e:
logger.error(
f"Error in callback for {batch_frame.stream_id}: {e}",
exc_info=True,
)
except Exception as e:
logger.error(f"Error processing batch: {e}", exc_info=True)
finally:
with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
@abstractmethod
def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""
Run inference on a batch of frames (backend-specific).
Args:
batch: List of BatchFrame objects
Returns:
List of detection results (one per frame)
"""
pass
def _process_remaining_buffers(self):
"""Process any remaining frames in buffers during shutdown"""
if len(self.buffer_a) > 0:
logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
self._process_buffer("A")
if len(self.buffer_b) > 0:
logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
self._process_buffer("B")
def get_stats(self) -> Dict[str, Any]:
"""Get current buffer statistics"""
return {
"active_buffer": self.active_buffer,
"buffer_a_size": len(self.buffer_a),
"buffer_b_size": len(self.buffer_b),
"buffer_a_state": self.buffer_a_state.value,
"buffer_b_state": self.buffer_b_state.value,
"registered_streams": len(self.result_callbacks),
"total_frames_processed": self.total_frames_processed,
"total_batches_processed": self.total_batches_processed,
"avg_batch_size": (
self.total_frames_processed / self.total_batches_processed
if self.total_batches_processed > 0
else 0
),
}
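For reference, a minimal concrete subclass looks roughly like the sketch below. It assumes only the BaseModelController API shown above; DummyModelController and the "demo"/"cam1" names are hypothetical, and the inference body is a placeholder rather than a real backend.

# Sketch only: assumes a CUDA device is available and the services package import path.
from typing import Any, Dict, List

import torch

from services import BaseModelController, BatchFrame


class DummyModelController(BaseModelController):
    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        frames = torch.stack([bf.frame for bf in batch])  # (N, 3, H, W) GPU tensor
        # A real subclass would run TensorRT or Ultralytics inference on `frames` here.
        return [
            {"stream_id": bf.stream_id, "timestamp": bf.timestamp, "detections": []}
            for bf in batch
        ]


controller = DummyModelController(model_id="demo", batch_size=4, force_timeout=0.05)
controller.start()
controller.register_callback("cam1", lambda result: print(result["stream_id"]))
controller.submit_frame("cam1", torch.zeros(3, 640, 640, device="cuda"))
controller.stop()  # flushes any remaining frames through _run_batch_inference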

services/inference_engine.py (new file)

@@ -0,0 +1,635 @@
"""
Inference Engine Abstraction Layer
Provides a unified interface for different inference backends:
- Native TensorRT: Direct TensorRT API with zero-copy GPU tensors
- Ultralytics: YOLO models with built-in pre/postprocessing
- Future: ONNX Runtime, OpenVINO, etc.
All engines support zero-copy GPU tensor inference where possible.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch

logger = logging.getLogger(__name__)
class BackendType(Enum):
"""Supported inference backend types"""
TENSORRT = "tensorrt"
ULTRALYTICS = "ultralytics"
@classmethod
def from_string(cls, backend: str) -> "BackendType":
"""Convert string to BackendType"""
backend = backend.lower()
for member in cls:
if member.value == backend:
return member
raise ValueError(
f"Unknown backend: {backend}. Available: {[m.value for m in cls]}"
)
@dataclass
class EngineMetadata:
"""Metadata for an inference engine"""
engine_type: str # "tensorrt", "ultralytics", etc.
model_path: str
input_shapes: Dict[str, Tuple[int, ...]]
output_shapes: Dict[str, Tuple[int, ...]]
input_names: List[str]
output_names: List[str]
input_dtypes: Dict[str, torch.dtype]
output_dtypes: Dict[str, torch.dtype]
supports_batching: bool = True
supports_dynamic_shapes: bool = False
extra_info: Dict[str, Any] = None # Backend-specific info
class IInferenceEngine(ABC):
"""
Abstract interface for inference engines.
All implementations must support zero-copy GPU tensor inference:
- Inputs: CUDA tensors on GPU
- Outputs: CUDA tensors on GPU
- No CPU transfers during inference
"""
@abstractmethod
def initialize(
self, model_path: str, device: torch.device, **kwargs
) -> EngineMetadata:
"""
Initialize the inference engine.
Automatically detects model type and handles conversion if needed.
Args:
model_path: Path to model file (.pt, .engine, .trt)
device: GPU device to use
**kwargs: Optional parameters (batch_size, half, workspace, etc.)
Returns:
EngineMetadata with model information
"""
pass
@abstractmethod
def infer(
self, inputs: Dict[str, torch.Tensor], **kwargs
) -> Dict[str, torch.Tensor]:
"""
Run inference on GPU tensors (zero-copy).
Args:
inputs: Dict of input_name -> CUDA tensor
**kwargs: Backend-specific inference parameters
Returns:
Dict of output_name -> CUDA tensor
Raises:
ValueError: If inputs are not CUDA tensors or wrong shape
"""
pass
@abstractmethod
def get_metadata(self) -> EngineMetadata:
"""Get engine metadata"""
pass
@abstractmethod
def cleanup(self):
"""Cleanup resources"""
pass
@property
@abstractmethod
def is_initialized(self) -> bool:
"""Check if engine is initialized"""
pass
@property
@abstractmethod
def device(self) -> torch.device:
"""Get device the engine is running on"""
pass
class NativeTensorRTEngine(IInferenceEngine):
"""
Native TensorRT inference engine with direct API access.
Features:
- Zero-copy GPU tensor inference
- Execution context pooling for concurrent inference
- Support for .trt, .engine files
- Automatic Ultralytics .engine metadata stripping
"""
def __init__(self):
self._engine = None
self._contexts = []
self._metadata = None
self._device = None
self._trt_logger = None
def initialize(
self, model_path: str, device: torch.device, num_contexts: int = 1, **kwargs
) -> EngineMetadata:
"""
Initialize TensorRT engine.
Args:
model_path: Path to .trt or .engine file
device: GPU device
num_contexts: Number of execution contexts for pooling
Returns:
EngineMetadata
"""
import tensorrt as trt
self._device = device
self._trt_logger = trt.Logger(trt.Logger.WARNING)
# Load engine
runtime = trt.Runtime(self._trt_logger)
# Read engine file (handle Ultralytics format)
engine_data = self._load_engine_data(model_path)
self._engine = runtime.deserialize_cuda_engine(engine_data)
if self._engine is None:
raise RuntimeError(f"Failed to load TensorRT engine from {model_path}")
# Create execution contexts
for i in range(num_contexts):
ctx = self._engine.create_execution_context()
if ctx is None:
raise RuntimeError(f"Failed to create execution context {i}")
self._contexts.append(ctx)
# Extract metadata
self._metadata = self._extract_metadata(model_path)
return self._metadata
def _load_engine_data(self, file_path: str) -> bytes:
"""Load engine data, stripping Ultralytics metadata if present"""
import json
with open(file_path, "rb") as f:
# Try to read Ultralytics metadata header
meta_len_bytes = f.read(4)
if len(meta_len_bytes) == 4:
meta_len = int.from_bytes(meta_len_bytes, byteorder="little")
# Sanity check
if 0 < meta_len < 100000:
try:
metadata_bytes = f.read(meta_len)
json.loads(metadata_bytes.decode("utf-8"))
# Valid Ultralytics metadata, rest is engine
return f.read()
except (UnicodeDecodeError, json.JSONDecodeError):
pass
# Not Ultralytics format, read entire file
f.seek(0)
return f.read()
def _extract_metadata(self, model_path: str) -> EngineMetadata:
"""Extract metadata from TensorRT engine"""
import tensorrt as trt
input_shapes = {}
output_shapes = {}
input_names = []
output_names = []
input_dtypes = {}
output_dtypes = {}
trt_to_torch_dtype = {
trt.DataType.FLOAT: torch.float32,
trt.DataType.HALF: torch.float16,
trt.DataType.INT8: torch.int8,
trt.DataType.INT32: torch.int32,
trt.DataType.BOOL: torch.bool,
}
for i in range(self._engine.num_io_tensors):
name = self._engine.get_tensor_name(i)
shape = tuple(self._engine.get_tensor_shape(name))
dtype = trt_to_torch_dtype.get(
self._engine.get_tensor_dtype(name), torch.float32
)
mode = self._engine.get_tensor_mode(name)
if mode == trt.TensorIOMode.INPUT:
input_names.append(name)
input_shapes[name] = shape
input_dtypes[name] = dtype
else:
output_names.append(name)
output_shapes[name] = shape
output_dtypes[name] = dtype
return EngineMetadata(
engine_type="tensorrt",
model_path=model_path,
input_shapes=input_shapes,
output_shapes=output_shapes,
input_names=input_names,
output_names=output_names,
input_dtypes=input_dtypes,
output_dtypes=output_dtypes,
supports_batching=True,
supports_dynamic_shapes=False,
)
def infer(
self,
inputs: Dict[str, torch.Tensor],
context_id: int = 0,
stream: Optional[torch.cuda.Stream] = None,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""
Run TensorRT inference with zero-copy GPU tensors.
Args:
inputs: Dict of input_name -> CUDA tensor
context_id: Which execution context to use
stream: CUDA stream for async execution
Returns:
Dict of output_name -> CUDA tensor
"""
if not self.is_initialized:
raise RuntimeError("Engine not initialized")
# Validate inputs
for name in self._metadata.input_names:
if name not in inputs:
raise ValueError(f"Missing required input: {name}")
if not inputs[name].is_cuda:
raise ValueError(f"Input '{name}' must be a CUDA tensor")
# Get execution context
if context_id >= len(self._contexts):
raise ValueError(
f"Invalid context_id {context_id}, only {len(self._contexts)} contexts available"
)
context = self._contexts[context_id]
# Prepare outputs
outputs = {}
# Set input tensor addresses
for name in self._metadata.input_names:
input_tensor = inputs[name].contiguous()
context.set_tensor_address(name, input_tensor.data_ptr())
# Allocate and set output tensors
for name in self._metadata.output_names:
output_tensor = torch.empty(
self._metadata.output_shapes[name],
dtype=self._metadata.output_dtypes[name],
device=self._device,
)
outputs[name] = output_tensor
context.set_tensor_address(name, output_tensor.data_ptr())
# Execute
if stream is None:
stream = torch.cuda.Stream(device=self._device)
with torch.cuda.stream(stream):
success = context.execute_async_v3(stream_handle=stream.cuda_stream)
if not success:
raise RuntimeError("TensorRT inference failed")
stream.synchronize()
return outputs
def get_metadata(self) -> EngineMetadata:
"""Get engine metadata"""
if self._metadata is None:
raise RuntimeError("Engine not initialized")
return self._metadata
def cleanup(self):
"""Cleanup TensorRT resources"""
for ctx in self._contexts:
del ctx
self._contexts.clear()
if self._engine is not None:
del self._engine
self._engine = None
self._metadata = None
@property
def is_initialized(self) -> bool:
return self._engine is not None
@property
def device(self) -> torch.device:
return self._device
class UltralyticsEngine(IInferenceEngine):
"""
Ultralytics YOLO inference engine.
Features:
- Zero-copy GPU tensor inference
- Built-in preprocessing/postprocessing for YOLO models
- Supports .pt, .engine formats
- Automatic model export to TensorRT with caching
"""
def __init__(self):
self._model = None
self._metadata = None
self._device = None
self._model_path = None
self._exporter = None
def initialize(
self,
model_path: str,
device: torch.device,
batch: int = 1,
half: bool = False,
imgsz: int = 640,
cache_dir: str = ".ultralytics_cache",
**kwargs,
) -> EngineMetadata:
"""
Initialize Ultralytics YOLO model.
Automatically exports .pt models to .engine format with caching.
Args:
model_path: Path to .pt or .engine file
device: GPU device
batch: Maximum batch size for inference
half: Use FP16 precision
imgsz: Input image size
cache_dir: Directory for caching exported engines
**kwargs: Additional export parameters
Returns:
EngineMetadata
"""
from ultralytics import YOLO
from .ultralytics_exporter import UltralyticsExporter
self._device = device
self._model_path = model_path
# Check if we need to export
model_file = Path(model_path)
final_model_path = model_path
if model_file.suffix == ".pt":
# Use exporter with caching
print(f"Checking for cached TensorRT engine...")
self._exporter = UltralyticsExporter(cache_dir=cache_dir)
_, engine_path = self._exporter.export(
model_path=str(model_path),
device=device.index if device.type == "cuda" else 0,
half=half,
imgsz=imgsz,
batch=batch,
**kwargs,
)
final_model_path = engine_path
print(f"Using TensorRT engine: {engine_path}")
# Load model (Ultralytics handles .engine files natively)
self._model = YOLO(final_model_path)
# Move to device if needed (only for .pt models, .engine already on specific device)
if hasattr(self._model, "model") and self._model.model is not None:
# Check if it's actually a torch model (not a string path for .engine files)
if hasattr(self._model.model, "to"):
self._model.model = self._model.model.to(device)
# Extract metadata
self._metadata = self._extract_metadata()
return self._metadata
def _extract_metadata(self) -> EngineMetadata:
"""Extract metadata from Ultralytics model"""
# Ultralytics models typically expect (B, 3, H, W) input
# and return Results objects, not raw tensors
# Default values
batch_size = -1 # Dynamic batching by default
imgsz = 640
input_shape = (batch_size, 3, imgsz, imgsz)
if hasattr(self._model, "model") and self._model.model is not None:
# Try to get actual input shape from model
try:
# For .engine files, check predictor model
if (
hasattr(self._model, "predictor")
and self._model.predictor is not None
):
predictor = self._model.predictor
# Get image size
if hasattr(predictor, "args") and hasattr(predictor.args, "imgsz"):
imgsz_val = predictor.args.imgsz
if isinstance(imgsz_val, (list, tuple)):
h, w = (
imgsz_val[0],
imgsz_val[1] if len(imgsz_val) > 1 else imgsz_val[0],
)
else:
h = w = imgsz_val
imgsz = h # Use height as reference
# Get batch size from model
if hasattr(predictor, "model"):
pred_model = predictor.model
# For TensorRT engines, check input bindings
if hasattr(pred_model, "bindings"):
# This is a TensorRT AutoBackend
try:
# Get first input binding shape
if hasattr(pred_model, "input_shape"):
shape = pred_model.input_shape
if shape and len(shape) >= 4:
batch_size = shape[0] if shape[0] > 0 else -1
except:
pass
# Try batch attribute
if batch_size == -1 and hasattr(pred_model, "batch"):
batch_size = (
pred_model.batch if pred_model.batch > 0 else -1
)
# Fallback: check model args
if hasattr(self._model.model, "args"):
imgsz_val = getattr(self._model.model.args, "imgsz", 640)
if isinstance(imgsz_val, (list, tuple)):
h, w = (
imgsz_val[0],
imgsz_val[1] if len(imgsz_val) > 1 else imgsz_val[0],
)
else:
h = w = imgsz_val
imgsz = h
input_shape = (batch_size, 3, imgsz, imgsz)
except Exception as e:
logger.warning(f"Could not extract full metadata: {e}")
pass
return EngineMetadata(
engine_type="ultralytics",
model_path=self._model_path,
input_shapes={"images": input_shape},
output_shapes={"results": (-1,)}, # Dynamic, depends on detections
input_names=["images"],
output_names=["results"],
input_dtypes={"images": torch.float32},
output_dtypes={"results": torch.float32},
supports_batching=True,
supports_dynamic_shapes=(batch_size == -1),
extra_info={
"is_yolo": True,
"has_builtin_postprocess": True,
"batch_size": batch_size,
"imgsz": imgsz,
},
)
def infer(
self,
inputs: Dict[str, torch.Tensor],
return_raw: bool = False,
conf: float = 0.25,
iou: float = 0.45,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""
Run Ultralytics inference with zero-copy GPU tensors.
Args:
inputs: Dict with "images" key -> CUDA tensor (B, 3, H, W), normalized [0, 1]
return_raw: If True, return raw tensor output. If False, return Results objects
conf: Confidence threshold
iou: IoU threshold for NMS
Returns:
Dict with inference results
Note:
Input tensor should be normalized to [0, 1] range.
Format: (B, 3, H, W) in RGB color space.
"""
if not self.is_initialized:
raise RuntimeError("Engine not initialized")
# Get input tensor
if "images" not in inputs:
raise ValueError("Input must contain 'images' key")
images = inputs["images"]
if not images.is_cuda:
raise ValueError("Input must be a CUDA tensor")
# Ensure tensor is on correct device
if images.device != self._device:
images = images.to(self._device)
# Run inference
results = self._model(images, conf=conf, iou=iou, verbose=False, **kwargs)
# Return results
# Note: Ultralytics returns Results objects, not raw tensors
# For compatibility, we wrap them in a dict
return {
"results": results,
"raw_predictions": results[0].boxes.data
if len(results) > 0 and hasattr(results[0], "boxes")
else None,
}
def get_metadata(self) -> EngineMetadata:
"""Get engine metadata"""
if self._metadata is None:
raise RuntimeError("Engine not initialized")
return self._metadata
def cleanup(self):
"""Cleanup Ultralytics model"""
if self._model is not None:
del self._model
self._model = None
self._metadata = None
@property
def is_initialized(self) -> bool:
return self._model is not None
@property
def device(self) -> torch.device:
return self._device
def create_engine(backend: str | BackendType, **kwargs) -> IInferenceEngine:
"""
Factory function to create inference engine.
Args:
backend: Backend type (BackendType enum or string: "tensorrt", "ultralytics")
**kwargs: Engine-specific arguments
Returns:
IInferenceEngine instance
Example:
>>> from services import create_engine, BackendType
>>> engine = create_engine(BackendType.TENSORRT)
>>> engine = create_engine("ultralytics")
"""
# Convert string to BackendType if needed
if isinstance(backend, str):
backend = BackendType.from_string(backend)
engines = {
BackendType.TENSORRT: NativeTensorRTEngine,
BackendType.ULTRALYTICS: UltralyticsEngine,
}
if backend not in engines:
raise ValueError(
f"Unknown backend: {backend}. Available: {[b.value for b in BackendType]}"
)
return engines[backend]()
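Putting the factory and the Ultralytics backend together, a rough usage sketch follows. The model path and thresholds are illustrative assumptions; a .pt input would be exported to a cached .engine on first use, and a CUDA device is assumed.

import torch

from services import create_engine

device = torch.device("cuda:0")
engine = create_engine("ultralytics")
metadata = engine.initialize("models/yolov8n.pt", device=device, half=True, imgsz=640)
print(metadata.input_shapes)       # e.g. {"images": (-1, 3, 640, 640)}

# Zero-copy input: (B, 3, H, W), RGB, normalized to [0, 1], already on the GPU.
images = torch.rand(1, 3, 640, 640, device=device)
outputs = engine.infer({"images": images}, conf=0.25, iou=0.45)
print(outputs["raw_predictions"])  # boxes from the first Results object, still on the GPU

engine.cleanup()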

services/model_controller.py

@@ -5,21 +5,22 @@ This module provides batched inference coordination using ping-pong circular buffers
with force-switch timeout mechanism using threading and callbacks.
"""

-import threading
-import torch
-from typing import Dict, List, Optional, Callable, Any
-from dataclasses import dataclass, field
-from enum import Enum
-import time
import logging
import queue
+import threading
+import time
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional
+import torch

logger = logging.getLogger(__name__)

@dataclass
class BatchFrame:
"""Represents a frame in the batch buffer"""
stream_id: str
frame: torch.Tensor  # GPU tensor (3, H, W)
timestamp: float

@@ -28,6 +29,7 @@ class BatchFrame:
class BufferState(Enum):
"""State of a ping-pong buffer"""
IDLE = "idle"
FILLING = "filling"
PROCESSING = "processing"

@@ -80,7 +82,9 @@ class ModelController:
f"Will process frames sequentially. Consider rebuilding model with dynamic batching."
)
else:
-logger.info(f"Model '{model_id}' supports batch_size={self.model_batch_size}")
+logger.info(
+f"Model '{model_id}' supports batch_size={self.model_batch_size}"
+)

# Ping-pong buffers
self.buffer_a: List[BatchFrame] = []

@@ -130,7 +134,9 @@ class ModelController:
# Fixed batch size
return batch_dim
except Exception as e:
-logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
+logger.warning(
+f"Could not detect model batch size: {e}. Assuming batch_size=1"
+)
return 1

def start(self):

@@ -143,14 +149,20 @@ class ModelController:
self.stop_event.clear()

# Start timeout monitor thread
-self.timeout_thread = threading.Thread(target=self._timeout_monitor, daemon=True)
+self.timeout_thread = threading.Thread(
+target=self._timeout_monitor, daemon=True
+)
self.timeout_thread.start()

# Start processor threads for each buffer
-self.processor_threads['A'] = threading.Thread(target=self._batch_processor, args=('A',), daemon=True)
-self.processor_threads['B'] = threading.Thread(target=self._batch_processor, args=('B',), daemon=True)
-self.processor_threads['A'].start()
-self.processor_threads['B'].start()
+self.processor_threads["A"] = threading.Thread(
+target=self._batch_processor, args=("A",), daemon=True
+)
+self.processor_threads["B"] = threading.Thread(
+target=self._batch_processor, args=("B",), daemon=True
+)
+self.processor_threads["A"].start()
+self.processor_threads["B"].start()

logger.info("ModelController started")

@@ -197,10 +209,7 @@ class ModelController:
logger.debug(f"Unregistered callback for stream: {stream_id}")

def submit_frame(
-self,
-stream_id: str,
-frame: torch.Tensor,
-metadata: Optional[Dict] = None
+self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
):
"""
Submit a frame for batched inference.

@@ -215,7 +224,7 @@ class ModelController:
stream_id=stream_id,
frame=frame,
timestamp=time.time(),
-metadata=metadata or {}
+metadata=metadata or {},
)

# Add to active buffer

@@ -242,7 +251,9 @@ class ModelController:
# Check if timeout expired and we have frames waiting
if time_since_submit >= self.force_timeout:
-active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
+active_buffer = (
+self.buffer_a if self.active_buffer == "A" else self.buffer_b
+)
if len(active_buffer) > 0:
self._try_swap_buffers()

@@ -254,7 +265,9 @@ class ModelController:
This method should be called with buffer_lock held.
"""
# Check if inactive buffer is available
-inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
+inactive_state = (
+self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
+)

if inactive_state != BufferState.PROCESSING:
# Swap active buffer

@@ -269,7 +282,9 @@ class ModelController:
self.buffer_b_state = BufferState.PROCESSING
buffer_size = len(self.buffer_b)

-logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")
+logger.debug(
+f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
+)

def _batch_processor(self, buffer_name: str):
"""Background thread that processes a specific buffer when available"""

@@ -322,8 +337,8 @@ class ModelController:
self.total_batches_processed += 1

logger.debug(
-f"Processed batch of {len(batch)} frames in {inference_time*1000:.2f}ms "
-f"({inference_time*1000/len(batch):.2f}ms per frame)"
+f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms "
+f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
)

# Emit results to callbacks

@@ -334,7 +349,10 @@ class ModelController:
try:
callback(result)
except Exception as e:
-logger.error(f"Error in callback for {batch_frame.stream_id}: {e}", exc_info=True)
+logger.error(
+f"Error in callback for {batch_frame.stream_id}: {e}",
+exc_info=True,
+)

except Exception as e:
logger.error(f"Error processing batch: {e}", exc_info=True)

@@ -365,7 +383,9 @@ class ModelController:
# Use true batching for models that support it
return self._run_batched_inference(batch)

-def _run_sequential_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
+def _run_sequential_inference(
+self, batch: List[BatchFrame]
+) -> List[Dict[str, Any]]:
"""Run inference sequentially for batch_size=1 models"""
results = []

@@ -375,13 +395,15 @@ class ModelController:
processed = self.preprocess_fn(batch_frame.frame)
else:
# Ensure we have batch dimension
-processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame
+processed = (
+batch_frame.frame.unsqueeze(0)
+if batch_frame.frame.dim() == 3
+else batch_frame.frame
+)

# Run inference for this frame
outputs = self.model_repository.infer(
-self.model_id,
-{"images": processed},
-synchronize=True
+self.model_id, {"images": processed}, synchronize=True
)

# Postprocess

@@ -389,9 +411,13 @@ class ModelController:
try:
detections = self.postprocess_fn(outputs)
except Exception as e:
-logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
+logger.error(
+f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
+)
# Return empty detections on error
-detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
+detections = torch.zeros(
+(0, 6), device=list(outputs.values())[0].device
+)
else:
detections = outputs

@@ -429,32 +455,37 @@ class ModelController:
f"will split into sub-batches"
)
# TODO: Handle splitting into sub-batches
-batch_tensor = batch_tensor[:self.model_batch_size]
-batch = batch[:self.model_batch_size]
+batch_tensor = batch_tensor[: self.model_batch_size]
+batch = batch[: self.model_batch_size]

# Run inference
outputs = self.model_repository.infer(
-self.model_id,
-{"images": batch_tensor},
-synchronize=True
+self.model_id, {"images": batch_tensor}, synchronize=True
)

# Postprocess results (split batch back to individual results)
results = []
for i, batch_frame in enumerate(batch):
-# Extract single frame output from batch
+# Extract single frame output from batch and clone to ensure memory safety
+# This prevents potential race conditions if the output tensors are still
+# in use when the next inference batch is processed
frame_output = {}
for k, v in outputs.items():
# v has shape (N, ...), extract index i and keep batch dimension
-frame_output[k] = v[i:i+1]  # Shape: (1, ...)
+# Clone to decouple from shared batch output tensor
+frame_output[k] = v[i : i + 1].clone()  # Shape: (1, ...)

if self.postprocess_fn:
try:
detections = self.postprocess_fn(frame_output)
except Exception as e:
-logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
+logger.error(
+f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
+)
# Return empty detections on error
-detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
+detections = torch.zeros(
+(0, 6), device=list(outputs.values())[0].device
+)
else:
detections = frame_output

@@ -490,6 +521,8 @@ class ModelController:
"total_batches_processed": self.total_batches_processed,
"avg_batch_size": (
self.total_frames_processed / self.total_batches_processed
-if self.total_batches_processed > 0 else 0
+if self.total_batches_processed > 0
+else 0
),
}
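The .clone() added in this hunk matters because a plain slice of the batched output is a view into storage that the next batch may overwrite. A standalone illustration (not project code) of the difference:

import torch

batch_output = torch.arange(12, dtype=torch.float32).reshape(4, 3)  # stand-in for a (N, ...) output

view = batch_output[1:2]          # shares storage with batch_output
copy = batch_output[1:2].clone()  # independent copy, safe to hand to a callback

batch_output.zero_()              # simulate the buffer being overwritten by the next batch

print(view)  # tensor([[0., 0., 0.]])  -> silently changed under the consumer
print(copy)  # tensor([[3., 4., 5.]])  -> unaffected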

services/model_repository.py

@@ -1,13 +1,14 @@
-import threading
import hashlib
import json
-from typing import Optional, Dict, Any, List, Tuple
+import logging
+import threading
+from dataclasses import dataclass
from pathlib import Path
from queue import Queue
-import torch
+from typing import Any, Dict, List, Optional, Tuple
import tensorrt as trt
-from dataclasses import dataclass
-import logging
+import torch

logger = logging.getLogger(__name__)

@@ -15,6 +16,7 @@ logger = logging.getLogger(__name__)
@dataclass
class ModelMetadata:
"""Metadata for a loaded TensorRT model"""
file_path: str
file_hash: str
input_shapes: Dict[str, Tuple[int, ...]]

@@ -30,8 +32,14 @@ class ExecutionContext:
Wrapper for TensorRT execution context with CUDA stream.
Used in context pool for load balancing.
"""

-def __init__(self, context: trt.IExecutionContext, stream: torch.cuda.Stream,
-context_id: int, device: torch.device):
+def __init__(
+self,
+context: trt.IExecutionContext,
+stream: torch.cuda.Stream,
+context_id: int,
+device: torch.device,
+):
self.context = context
self.stream = stream
self.context_id = context_id

@@ -53,8 +61,16 @@ class SharedEngine:
- Contexts are borrowed/returned using mutex locks
- Load balancing: contexts distributed across requests
"""

-def __init__(self, engine: trt.ICudaEngine, file_hash: str, file_path: str,
-num_contexts: int, device: torch.device, metadata: ModelMetadata):
+def __init__(
+self,
+engine: trt.ICudaEngine,
+file_hash: str,
+file_path: str,
+num_contexts: int,
+device: torch.device,
+metadata: ModelMetadata,
+):
self.engine = engine
self.file_hash = file_hash
self.file_path = file_path

@@ -80,9 +96,13 @@ class SharedEngine:
self.model_ids: set = set()
self.lock = threading.Lock()

-print(f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}...")
+print(
+f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}..."
+)

-def acquire_context(self, timeout: Optional[float] = None) -> Optional[ExecutionContext]:
+def acquire_context(
+self, timeout: Optional[float] = None
+) -> Optional[ExecutionContext]:
"""
Acquire an available execution context from the pool.
Blocks if all contexts are in use.

@@ -162,7 +182,13 @@ class TensorRTModelRepository:
# Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
"""

-def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True, cache_dir: str = ".trt_cache"):
+def __init__(
+self,
+gpu_id: int = 0,
+default_num_contexts: int = 4,
+enable_pt_conversion: bool = True,
+cache_dir: str = ".trt_cache",
+):
"""
Initialize the model repository.

@@ -173,7 +199,7 @@ class TensorRTModelRepository:
cache_dir: Directory for caching stripped TensorRT engines and metadata
"""
self.gpu_id = gpu_id
-self.device = torch.device(f'cuda:{gpu_id}')
+self.device = torch.device(f"cuda:{gpu_id}")
self.default_num_contexts = default_num_contexts
self.enable_pt_conversion = enable_pt_conversion
self.cache_dir = Path(cache_dir)

@@ -195,7 +221,9 @@ class TensorRTModelRepository:
self._pt_converter = None

print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
-print(f"Default context pool size: {default_num_contexts} contexts per unique model")
+print(
+f"Default context pool size: {default_num_contexts} contexts per unique model"
+)
print(f"Cache directory: {self.cache_dir}")
if enable_pt_conversion:
print(f"PyTorch to TensorRT conversion: enabled")

@@ -205,6 +233,7 @@ class TensorRTModelRepository:
"""Lazy initialization of PT converter"""
if self._pt_converter is None and self.enable_pt_conversion:
from .pt_converter import PTConverter
self._pt_converter = PTConverter(gpu_id=self.gpu_id)
logger.info("PT converter initialized")
return self._pt_converter

@@ -255,11 +284,11 @@ class TensorRTModelRepository:
# Check if stripped engine already cached
if cache_engine_path.exists():
logger.info(f"Loading cached stripped engine from {cache_engine_path}")
-with open(cache_engine_path, 'rb') as f:
+with open(cache_engine_path, "rb") as f:
engine_data = f.read()
else:
# Read and process original file
-with open(file_path, 'rb') as f:
+with open(file_path, "rb") as f:
# Try to read Ultralytics metadata header (first 4 bytes = metadata length)
try:
meta_len_bytes = f.read(4)

@@ -278,13 +307,15 @@ class TensorRTModelRepository:
# Save stripped engine to cache
logger.info(f"Detected Ultralytics engine format")
logger.info(f"Ultralytics metadata: {metadata}")
-logger.info(f"Caching stripped engine to {cache_engine_path}")
+logger.info(
+f"Caching stripped engine to {cache_engine_path}"
+)

-with open(cache_engine_path, 'wb') as cache_f:
+with open(cache_engine_path, "wb") as cache_f:
cache_f.write(engine_data)

# Save metadata separately
-with open(cache_metadata_path, 'w') as meta_f:
+with open(cache_metadata_path, "w") as meta_f:
json.dump(metadata, meta_f, indent=2)

except (UnicodeDecodeError, json.JSONDecodeError):

@@ -301,13 +332,15 @@ class TensorRTModelRepository:
except Exception as e:
# Any error, rewind and read entire file
-logger.warning(f"Error reading engine metadata: {e}, treating as raw TRT engine")
+logger.warning(
+f"Error reading engine metadata: {e}, treating as raw TRT engine"
+)
f.seek(0)
engine_data = f.read()

# Cache the engine data (even if it was already raw TRT)
if not cache_engine_path.exists():
-with open(cache_engine_path, 'wb') as cache_f:
+with open(cache_engine_path, "wb") as cache_f:
cache_f.write(engine_data) cache_f.write(engine_data)
engine = runtime.deserialize_cuda_engine(engine_data) engine = runtime.deserialize_cuda_engine(engine_data)
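For reference, a minimal standalone sketch of the header handling above, assuming the Ultralytics layout of a 4-byte little-endian length prefix, then JSON metadata, then the raw serialized TensorRT engine (the file path is illustrative):

import json
from pathlib import Path

def strip_ultralytics_header(engine_path: str):
    """Return (metadata, raw_engine_bytes); metadata is {} for plain TRT engines."""
    data = Path(engine_path).read_bytes()
    try:
        meta_len = int.from_bytes(data[:4], byteorder="little", signed=False)
        metadata = json.loads(data[4 : 4 + meta_len].decode("utf-8"))
        return metadata, data[4 + meta_len :]
    except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
        # Not an Ultralytics-wrapped engine; treat the whole file as a raw TRT engine
        return {}, data

# metadata, engine_bytes = strip_ultralytics_header("yolov8n.engine")  # hypothetical file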
@ -316,8 +349,9 @@ class TensorRTModelRepository:
return engine return engine
def _extract_metadata(self, engine: trt.ICudaEngine, def _extract_metadata(
file_path: str, file_hash: str) -> ModelMetadata: self, engine: trt.ICudaEngine, file_path: str, file_hash: str
) -> ModelMetadata:
""" """
Extract metadata from TensorRT engine. Extract metadata from TensorRT engine.
@ -369,15 +403,19 @@ class TensorRTModelRepository:
input_names=input_names, input_names=input_names,
output_names=output_names, output_names=output_names,
input_dtypes=input_dtypes, input_dtypes=input_dtypes,
output_dtypes=output_dtypes output_dtypes=output_dtypes,
) )
def load_model(self, model_id: str, file_path: str, def load_model(
num_contexts: Optional[int] = None, self,
force_reload: bool = False, model_id: str,
pt_input_shapes: Optional[Dict[str, Tuple]] = None, file_path: str,
pt_precision: Optional[torch.dtype] = None, num_contexts: Optional[int] = None,
**pt_conversion_kwargs) -> ModelMetadata: force_reload: bool = False,
pt_input_shapes: Optional[Dict[str, Tuple]] = None,
pt_precision: Optional[torch.dtype] = None,
**pt_conversion_kwargs,
) -> ModelMetadata:
""" """
Load a TensorRT model with the given ID. Load a TensorRT model with the given ID.
@ -410,7 +448,7 @@ class TensorRTModelRepository:
# Check if file is PyTorch model # Check if file is PyTorch model
file_ext = Path(file_path).suffix.lower() file_ext = Path(file_path).suffix.lower()
if file_ext in ['.pt', '.pth']: if file_ext in [".pt", ".pth"]:
if not self.enable_pt_conversion: if not self.enable_pt_conversion:
raise ValueError( raise ValueError(
f"PT file provided but PT conversion is disabled. " f"PT file provided but PT conversion is disabled. "
@ -425,7 +463,7 @@ class TensorRTModelRepository:
file_path, file_path,
input_shapes=pt_input_shapes, input_shapes=pt_input_shapes,
precision=pt_precision, precision=pt_precision,
**pt_conversion_kwargs **pt_conversion_kwargs,
) )
# Update file_path to use converted TRT file # Update file_path to use converted TRT file
@ -455,8 +493,12 @@ class TensorRTModelRepository:
# Check if this file is already loaded (deduplication) # Check if this file is already loaded (deduplication)
if file_hash in self._shared_engines: if file_hash in self._shared_engines:
shared_engine = self._shared_engines[file_hash] shared_engine = self._shared_engines[file_hash]
print(f"Engine already loaded (hash match), reusing engine and context pool...") print(
print(f" Existing model_ids using this engine: {shared_engine.model_ids}") f"Engine already loaded (hash match), reusing engine and context pool..."
)
print(
f" Existing model_ids using this engine: {shared_engine.model_ids}"
)
else: else:
# Load new engine # Load new engine
print(f"Loading TensorRT engine from {file_path}...") print(f"Loading TensorRT engine from {file_path}...")
@ -472,7 +514,7 @@ class TensorRTModelRepository:
file_path=file_path, file_path=file_path,
num_contexts=num_contexts, num_contexts=num_contexts,
device=self.device, device=self.device,
metadata=metadata metadata=metadata,
) )
self._shared_engines[file_hash] = shared_engine self._shared_engines[file_hash] = shared_engine
@ -485,18 +527,29 @@ class TensorRTModelRepository:
print(f"Model '{model_id}' loaded successfully") print(f"Model '{model_id}' loaded successfully")
print(f" Inputs: {shared_engine.metadata.input_names}") print(f" Inputs: {shared_engine.metadata.input_names}")
for name in shared_engine.metadata.input_names: for name in shared_engine.metadata.input_names:
print(f" {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})") print(
f" {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})"
)
print(f" Outputs: {shared_engine.metadata.output_names}") print(f" Outputs: {shared_engine.metadata.output_names}")
for name in shared_engine.metadata.output_names: for name in shared_engine.metadata.output_names:
print(f" {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})") print(
f" {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})"
)
print(f" Context pool size: {num_contexts}") print(f" Context pool size: {num_contexts}")
print(f" Model IDs sharing this engine: {shared_engine.get_reference_count()}") print(
f" Model IDs sharing this engine: {shared_engine.get_reference_count()}"
)
print(f" Unique engines in VRAM: {len(self._shared_engines)}") print(f" Unique engines in VRAM: {len(self._shared_engines)}")
return shared_engine.metadata return shared_engine.metadata
def infer(self, model_id: str, inputs: Dict[str, torch.Tensor], def infer(
synchronize: bool = True, timeout: Optional[float] = 5.0) -> Dict[str, torch.Tensor]: self,
model_id: str,
inputs: Dict[str, torch.Tensor],
synchronize: bool = True,
timeout: Optional[float] = 5.0,
) -> Dict[str, torch.Tensor]:
""" """
Run GPU-to-GPU inference with the specified model using context pooling. Run GPU-to-GPU inference with the specified model using context pooling.
@ -519,7 +572,9 @@ class TensorRTModelRepository:
""" """
# Get shared engine # Get shared engine
if model_id not in self._model_to_hash: if model_id not in self._model_to_hash:
raise KeyError(f"Model '{model_id}' not found. Available: {list(self._model_to_hash.keys())}") raise KeyError(
f"Model '{model_id}' not found. Available: {list(self._model_to_hash.keys())}"
)
file_hash = self._model_to_hash[model_id] file_hash = self._model_to_hash[model_id]
shared_engine = self._shared_engines[file_hash] shared_engine = self._shared_engines[file_hash]
@ -536,7 +591,9 @@ class TensorRTModelRepository:
# Check device # Check device
if tensor.device != self.device: if tensor.device != self.device:
print(f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}") print(
f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}"
)
inputs[name] = tensor.to(self.device) inputs[name] = tensor.to(self.device)
# Acquire context from pool (mutex-based) # Acquire context from pool (mutex-based)
@ -562,9 +619,7 @@ class TensorRTModelRepository:
output_dtype = metadata.output_dtypes[name] output_dtype = metadata.output_dtypes[name]
output_tensor = torch.empty( output_tensor = torch.empty(
output_shape, output_shape, dtype=output_dtype, device=self.device
dtype=output_dtype,
device=self.device
) )
# NOTE: Don't track these tensors - they're returned to caller and consumed # NOTE: Don't track these tensors - they're returned to caller and consumed
@ -584,9 +639,23 @@ class TensorRTModelRepository:
if not success: if not success:
raise RuntimeError(f"Inference failed for model '{model_id}'") raise RuntimeError(f"Inference failed for model '{model_id}'")
# Synchronize if requested # CRITICAL: Always synchronize before releasing context
if synchronize: # Even if caller requested async execution, we MUST sync before
exec_ctx.stream.synchronize() # releasing the context to prevent race conditions where the next
# inference using this context overwrites tensor addresses while
# the current batch is still being processed.
exec_ctx.stream.synchronize()
# Clone outputs to new tensors to ensure memory safety
# This prevents race conditions where the next batch using this context
# could overwrite the output tensor addresses before the caller
# finishes processing these results.
if not synchronize:
# For async mode, clone to decouple from context
cloned_outputs = {}
for name, tensor in outputs.items():
cloned_outputs[name] = tensor.clone()
outputs = cloned_outputs
return outputs return outputs
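A hedged usage sketch of this inference path, assuming the import path and engine file below (both illustrative):

import torch
from services.model_repository import TensorRTModelRepository  # assumed import path

repo = TensorRTModelRepository(gpu_id=0, default_num_contexts=4)
repo.load_model("detector", "models/yolov8n.engine")  # placeholder engine path

frame = torch.rand(1, 3, 640, 640, device="cuda:0")

# synchronize=True: the stream is synced and results are ready on return
out = repo.infer("detector", {"images": frame}, synchronize=True)

# synchronize=False: the stream is still synced before the context is released,
# and outputs are cloned so they stay valid while the next batch reuses the context
out_async = repo.infer("detector", {"images": frame}, synchronize=False)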
@ -594,8 +663,12 @@ class TensorRTModelRepository:
# Always release context back to pool # Always release context back to pool
shared_engine.release_context(exec_ctx) shared_engine.release_context(exec_ctx)
def infer_batch(self, model_id: str, batch_inputs: List[Dict[str, torch.Tensor]], def infer_batch(
synchronize: bool = True) -> List[Dict[str, torch.Tensor]]: self,
model_id: str,
batch_inputs: List[Dict[str, torch.Tensor]],
synchronize: bool = True,
) -> List[Dict[str, torch.Tensor]]:
""" """
Run inference on multiple inputs. Run inference on multiple inputs.
Contexts are borrowed/returned for each input, enabling parallel processing. Contexts are borrowed/returned for each input, enabling parallel processing.
@ -641,9 +714,13 @@ class TensorRTModelRepository:
if remaining_refs == 0: if remaining_refs == 0:
shared_engine.cleanup() shared_engine.cleanup()
del self._shared_engines[file_hash] del self._shared_engines[file_hash]
print(f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)") print(
f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)"
)
else: else:
print(f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)") print(
f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)"
)
# Remove from model_id mapping # Remove from model_id mapping
del self._model_to_hash[model_id] del self._model_to_hash[model_id]
@ -702,26 +779,26 @@ class TensorRTModelRepository:
metadata = shared_engine.metadata metadata = shared_engine.metadata
return { return {
'model_id': model_id, "model_id": model_id,
'file_path': metadata.file_path, "file_path": metadata.file_path,
'file_hash': metadata.file_hash[:16] + '...', "file_hash": metadata.file_hash[:16] + "...",
'engine_references': shared_engine.get_reference_count(), "engine_references": shared_engine.get_reference_count(),
'context_pool_size': shared_engine.num_contexts, "context_pool_size": shared_engine.num_contexts,
'shared_with_model_ids': list(shared_engine.model_ids), "shared_with_model_ids": list(shared_engine.model_ids),
'inputs': { "inputs": {
name: { name: {
'shape': metadata.input_shapes[name], "shape": metadata.input_shapes[name],
'dtype': str(metadata.input_dtypes[name]) "dtype": str(metadata.input_dtypes[name]),
} }
for name in metadata.input_names for name in metadata.input_names
}, },
'outputs': { "outputs": {
name: { name: {
'shape': metadata.output_shapes[name], "shape": metadata.output_shapes[name],
'dtype': str(metadata.output_dtypes[name]) "dtype": str(metadata.output_dtypes[name]),
} }
for name in metadata.output_names for name in metadata.output_names
} },
} }
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
@ -733,24 +810,25 @@ class TensorRTModelRepository:
""" """
with self._repo_lock: with self._repo_lock:
total_contexts = sum( total_contexts = sum(
engine.num_contexts engine.num_contexts for engine in self._shared_engines.values()
for engine in self._shared_engines.values()
) )
return { return {
'total_model_ids': len(self._model_to_hash), "total_model_ids": len(self._model_to_hash),
'unique_engines': len(self._shared_engines), "unique_engines": len(self._shared_engines),
'total_contexts': total_contexts, "total_contexts": total_contexts,
'memory_efficiency': f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines", "memory_efficiency": f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines",
'gpu_id': self.gpu_id, "gpu_id": self.gpu_id,
'models': list(self._model_to_hash.keys()) "models": list(self._model_to_hash.keys()),
} }
def __repr__(self): def __repr__(self):
with self._repo_lock: with self._repo_lock:
return (f"TensorRTModelRepository(gpu={self.gpu_id}, " return (
f"model_ids={len(self._model_to_hash)}, " f"TensorRTModelRepository(gpu={self.gpu_id}, "
f"unique_engines={len(self._shared_engines)})") f"model_ids={len(self._model_to_hash)}, "
f"unique_engines={len(self._shared_engines)})"
)
def __del__(self): def __del__(self):
"""Cleanup all models on deletion""" """Cleanup all models on deletion"""
View file

@ -5,25 +5,28 @@ This module provides high-level connection management for multiple RTSP streams,
coordinating decoders, batched inference, and tracking with callbacks and threading. coordinating decoders, batched inference, and tracking with callbacks and threading.
""" """
import threading
import time
from typing import Dict, Optional, Callable, Tuple, Any, List
from dataclasses import dataclass
from enum import Enum
import logging import logging
import queue import queue
import threading
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch import torch
from .model_controller import ModelController from .base_model_controller import BaseModelController
from .stream_decoder import StreamDecoderFactory
from .model_repository import TensorRTModelRepository from .model_repository import TensorRTModelRepository
from .stream_decoder import StreamDecoderFactory
from .tensorrt_model_controller import TensorRTModelController
from .ultralytics_model_controller import UltralyticsModelController
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ConnectionStatus(Enum): class ConnectionStatus(Enum):
"""Status of a stream connection""" """Status of a stream connection"""
CONNECTING = "connecting" CONNECTING = "connecting"
CONNECTED = "connected" CONNECTED = "connected"
DISCONNECTED = "disconnected" DISCONNECTED = "disconnected"
@ -33,6 +36,7 @@ class ConnectionStatus(Enum):
@dataclass @dataclass
class TrackingResult: class TrackingResult:
"""Result emitted to user callbacks""" """Result emitted to user callbacks"""
stream_id: str stream_id: str
timestamp: float timestamp: float
tracked_objects: List # List of TrackedObject from TrackingController tracked_objects: List # List of TrackedObject from TrackingController
@ -61,7 +65,7 @@ class StreamConnection:
self, self,
stream_id: str, stream_id: str,
decoder, decoder,
model_controller: ModelController, model_controller: BaseModelController,
tracking_controller, tracking_controller,
poll_interval: float = 0.01, poll_interval: float = 0.01,
): ):
@ -107,7 +111,9 @@ class StreamConnection:
break break
else: else:
# Timeout - but don't fail hard, let it try to connect in background # Timeout - but don't fail hard, let it try to connect in background
logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...") logger.warning(
f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying..."
)
self.status = ConnectionStatus.CONNECTING self.status = ConnectionStatus.CONNECTING
def stop(self): def stop(self):
@ -144,28 +150,42 @@ class StreamConnection:
self.last_frame_time = time.time() self.last_frame_time = time.time()
self.frame_count += 1 self.frame_count += 1
# CRITICAL: Clone the GPU tensor to decouple from decoder's frame buffer
# The decoder reuses frame buffer memory, so we must copy the tensor
# before submitting to async batched inference to prevent race conditions
# where the decoder overwrites memory while inference is still reading it.
cloned_tensor = frame_ref.rgb_tensor.clone()
# Submit to model controller for batched inference # Submit to model controller for batched inference
# Pass the FrameReference in metadata so we can free it later # Pass the FrameReference in metadata so we can free it later
self.model_controller.submit_frame( self.model_controller.submit_frame(
stream_id=self.stream_id, stream_id=self.stream_id,
frame=frame_ref.rgb_tensor, frame=cloned_tensor, # Use cloned tensor, not original
metadata={ metadata={
"frame_number": self.frame_count, "frame_number": self.frame_count,
"shape": tuple(frame_ref.rgb_tensor.shape), "shape": tuple(cloned_tensor.shape),
"frame_ref": frame_ref, # Store reference for later cleanup "frame_ref": frame_ref, # Store reference for later cleanup
} },
) )
# Update connection status based on decoder status # Update connection status based on decoder status
if self.decoder.is_connected() and self.status != ConnectionStatus.CONNECTED: if (
self.decoder.is_connected()
and self.status != ConnectionStatus.CONNECTED
):
logger.info(f"Stream {self.stream_id} reconnected") logger.info(f"Stream {self.stream_id} reconnected")
self.status = ConnectionStatus.CONNECTED self.status = ConnectionStatus.CONNECTED
elif not self.decoder.is_connected() and self.status == ConnectionStatus.CONNECTED: elif (
not self.decoder.is_connected()
and self.status == ConnectionStatus.CONNECTED
):
logger.warning(f"Stream {self.stream_id} disconnected") logger.warning(f"Stream {self.stream_id} disconnected")
self.status = ConnectionStatus.DISCONNECTED self.status = ConnectionStatus.DISCONNECTED
except Exception as e: except Exception as e:
logger.error(f"Error processing frame for {self.stream_id}: {e}", exc_info=True) logger.error(
f"Error processing frame for {self.stream_id}: {e}", exc_info=True
)
self.error_queue.put(e) self.error_queue.put(e)
self.status = ConnectionStatus.ERROR self.status = ConnectionStatus.ERROR
# Free the frame on error # Free the frame on error
@ -205,7 +225,10 @@ class StreamConnection:
self.result_queue.put(tracking_result) self.result_queue.put(tracking_result)
except Exception as e: except Exception as e:
logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True) logger.error(
f"Error handling inference result for {self.stream_id}: {e}",
exc_info=True,
)
self.error_queue.put(e) self.error_queue.put(e)
finally: finally:
# Free the frame reference - this is the last point in the pipeline # Free the frame reference - this is the last point in the pipeline
@ -235,12 +258,16 @@ class StreamConnection:
if confidence < min_confidence: if confidence < min_confidence:
continue continue
detection_list.append(Detection( detection_list.append(
bbox=det[:4].cpu().tolist(), Detection(
confidence=confidence, bbox=det[:4].cpu().tolist(),
class_id=int(det[5]) if det.shape[0] > 5 else 0, confidence=confidence,
class_name=f"class_{int(det[5])}" if det.shape[0] > 5 else "unknown" class_id=int(det[5]) if det.shape[0] > 5 else 0,
)) class_name=f"class_{int(det[5])}"
if det.shape[0] > 5
else "unknown",
)
)
# Update tracker with detections (will scale bboxes to frame space) # Update tracker with detections (will scale bboxes to frame space)
return self.tracking_controller.update(detection_list, frame_shape=frame_shape) return self.tracking_controller.update(detection_list, frame_shape=frame_shape)
@ -319,21 +346,38 @@ class StreamConnectionManager:
force_timeout: float = 0.05, force_timeout: float = 0.05,
poll_interval: float = 0.01, poll_interval: float = 0.01,
enable_pt_conversion: bool = True, enable_pt_conversion: bool = True,
backend: str = "tensorrt", # "tensorrt" or "ultralytics"
): ):
self.gpu_id = gpu_id self.gpu_id = gpu_id
self.batch_size = batch_size self.batch_size = batch_size
self.force_timeout = force_timeout self.force_timeout = force_timeout
self.poll_interval = poll_interval self.poll_interval = poll_interval
self.backend = backend.lower()
# Factories # Factories
self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id) self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
self.model_repository = TensorRTModelRepository(
gpu_id=gpu_id, # Initialize inference engine based on backend
enable_pt_conversion=enable_pt_conversion self.inference_engine = None
) self.model_repository = None # Legacy - will be removed
if self.backend == "ultralytics":
# Use Ultralytics native YOLO inference
from .inference_engine import UltralyticsEngine
self.inference_engine = UltralyticsEngine()
logger.info("Using Ultralytics inference engine")
else:
# Use native TensorRT inference
self.model_repository = TensorRTModelRepository(
gpu_id=gpu_id, enable_pt_conversion=enable_pt_conversion
)
logger.info("Using native TensorRT inference engine")
# Controllers # Controllers
self.model_controller: Optional[ModelController] = None self.model_controller = (
None # Will be TensorRTModelController or UltralyticsModelController
)
# Connections # Connections
self.connections: Dict[str, StreamConnection] = {} self.connections: Dict[str, StreamConnection] = {}
@ -350,7 +394,7 @@ class StreamConnectionManager:
num_contexts: int = 4, num_contexts: int = 4,
pt_input_shapes: Optional[Dict] = None, pt_input_shapes: Optional[Dict] = None,
pt_precision: Optional[Any] = None, pt_precision: Optional[Any] = None,
**pt_conversion_kwargs **pt_conversion_kwargs,
): ):
""" """
Initialize the manager with a model. Initialize the manager with a model.
@ -382,28 +426,58 @@ class StreamConnectionManager:
) )
""" """
logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}") logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
logger.info(f"Backend: {self.backend}")
# Load model (synchronous) # Initialize engine based on backend
self.model_repository.load_model( if self.backend == "ultralytics":
model_id, # Use Ultralytics native inference
model_path, logger.info("Initializing Ultralytics YOLO engine...")
num_contexts=num_contexts, device = torch.device(f"cuda:{self.gpu_id}")
pt_input_shapes=pt_input_shapes,
pt_precision=pt_precision,
**pt_conversion_kwargs
)
logger.info(f"Loaded model {model_id} from {model_path}")
# Create model controller metadata = self.inference_engine.initialize(
self.model_controller = ModelController( model_path=model_path,
model_repository=self.model_repository, device=device,
model_id=model_id, batch=self.batch_size,
batch_size=self.batch_size, half=False, # Use FP32 for now
force_timeout=self.force_timeout, imgsz=640,
preprocess_fn=preprocess_fn, **pt_conversion_kwargs,
postprocess_fn=postprocess_fn, )
) logger.info(f"Ultralytics engine initialized: {metadata}")
self.model_controller.start()
# Create Ultralytics model controller
self.model_controller = UltralyticsModelController(
inference_engine=self.inference_engine,
model_id=model_id,
batch_size=self.batch_size,
force_timeout=self.force_timeout,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
self.model_controller.start()
else:
# Use native TensorRT with model repository
logger.info("Initializing TensorRT engine...")
self.model_repository.load_model(
model_id,
model_path,
num_contexts=num_contexts,
pt_input_shapes=pt_input_shapes,
pt_precision=pt_precision,
**pt_conversion_kwargs,
)
logger.info(f"Loaded model {model_id} from {model_path}")
# Create TensorRT model controller
self.model_controller = TensorRTModelController(
model_repository=self.model_repository,
model_id=model_id,
batch_size=self.batch_size,
force_timeout=self.force_timeout,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
self.model_controller.start()
# Don't create a shared tracking controller here # Don't create a shared tracking controller here
# Each stream will get its own tracking controller to avoid track accumulation # Each stream will get its own tracking controller to avoid track accumulation
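A minimal sketch of driving this initialization from user code; the model path is a placeholder:

from services import StreamConnectionManager, YOLOv8Utils

manager = StreamConnectionManager(
    gpu_id=0,
    batch_size=4,
    force_timeout=0.05,
    backend="ultralytics",  # or "tensorrt" for the native repository path
)
manager.initialize(
    model_path="models/yolov8n.pt",  # placeholder model path
    model_id="detector",
    preprocess_fn=YOLOv8Utils.preprocess,
    postprocess_fn=YOLOv8Utils.postprocess,
    num_contexts=4,  # only used by the TensorRT backend
)

connection = manager.connect_stream(rtsp_url="rtsp://localhost:8554/test", stream_id="camera_1")
for result in connection.tracking_results():
    print(result.stream_id, len(result.tracked_objects))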
@ -452,12 +526,13 @@ class StreamConnectionManager:
# Create lightweight tracker (NO model_repository dependency!) # Create lightweight tracker (NO model_repository dependency!)
from .tracking_controller import ObjectTracker from .tracking_controller import ObjectTracker
tracking_controller = ObjectTracker( tracking_controller = ObjectTracker(
gpu_id=self.gpu_id, gpu_id=self.gpu_id,
tracker_type="iou", tracker_type="iou",
max_age=30, max_age=30,
iou_threshold=0.3, iou_threshold=0.3,
class_names=None # TODO: pass class names if available class_names=None, # TODO: pass class names if available
) )
logger.info(f"Created lightweight ObjectTracker for stream {stream_id}") logger.info(f"Created lightweight ObjectTracker for stream {stream_id}")
@ -472,8 +547,7 @@ class StreamConnectionManager:
# Register callback with model controller # Register callback with model controller
self.model_controller.register_callback( self.model_controller.register_callback(
stream_id, stream_id, connection._handle_inference_result
connection._handle_inference_result
) )
# Start connection # Start connection
@ -487,14 +561,12 @@ class StreamConnectionManager:
threading.Thread( threading.Thread(
target=self._forward_results, target=self._forward_results,
args=(connection, on_tracking_result), args=(connection, on_tracking_result),
daemon=True daemon=True,
).start() ).start()
if on_error: if on_error:
threading.Thread( threading.Thread(
target=self._forward_errors, target=self._forward_errors, args=(connection, on_error), daemon=True
args=(connection, on_error),
daemon=True
).start() ).start()
logger.info(f"Stream {stream_id} connected successfully") logger.info(f"Stream {stream_id} connected successfully")
@ -549,7 +621,10 @@ class StreamConnectionManager:
for result in connection.tracking_results(): for result in connection.tracking_results():
callback(result) callback(result)
except Exception as e: except Exception as e:
logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True) logger.error(
f"Error in result forwarding for {connection.stream_id}: {e}",
exc_info=True,
)
def _forward_errors(self, connection: StreamConnection, callback: Callable): def _forward_errors(self, connection: StreamConnection, callback: Callable):
""" """
@ -563,7 +638,10 @@ class StreamConnectionManager:
for error in connection.errors(): for error in connection.errors():
callback(error) callback(error)
except Exception as e: except Exception as e:
logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True) logger.error(
f"Error in error forwarding for {connection.stream_id}: {e}",
exc_info=True,
)
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
""" """
@ -581,7 +659,9 @@ class StreamConnectionManager:
"force_timeout": self.force_timeout, "force_timeout": self.force_timeout,
"poll_interval": self.poll_interval, "poll_interval": self.poll_interval,
}, },
"model_controller": self.model_controller.get_stats() if self.model_controller else {}, "model_controller": self.model_controller.get_stats()
if self.model_controller
else {},
"connections": { "connections": {
stream_id: conn.get_stats() stream_id: conn.get_stats()
for stream_id, conn in self.connections.items() for stream_id, conn in self.connections.items()
View file
@ -0,0 +1,182 @@
"""
TensorRT Model Controller - Native TensorRT inference with batched processing.
"""
import logging
from typing import Any, Callable, Dict, List, Optional
import torch
from .base_model_controller import BaseModelController, BatchFrame
logger = logging.getLogger(__name__)
class TensorRTModelController(BaseModelController):
"""
Model controller for native TensorRT inference.
Uses TensorRTModelRepository for GPU-accelerated inference with
context pooling and deduplication.
"""
def __init__(
self,
model_repository,
model_id: str,
batch_size: int = 16,
force_timeout: float = 0.05,
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
):
super().__init__(
model_id=model_id,
batch_size=batch_size,
force_timeout=force_timeout,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
self.model_repository = model_repository
# Detect model's actual batch size from input shape
self.model_batch_size = self._detect_model_batch_size()
if self.model_batch_size == 1:
logger.warning(
f"Model '{model_id}' has fixed batch_size=1. "
f"Will process frames sequentially."
)
else:
logger.info(
f"Model '{model_id}' supports batch_size={self.model_batch_size}"
)
def _detect_model_batch_size(self) -> int:
"""Detect the model's batch size from its input shape"""
try:
metadata = self.model_repository.get_metadata(self.model_id)
first_input_name = metadata.input_names[0]
input_shape = metadata.input_shapes[first_input_name]
batch_dim = input_shape[0]
if batch_dim == -1:
return self.batch_size # Dynamic batch size
else:
return batch_dim # Fixed batch size
except Exception as e:
logger.warning(
f"Could not detect model batch size: {e}. Assuming batch_size=1"
)
return 1
def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""Run TensorRT inference on a batch of frames"""
if self.model_batch_size == 1:
return self._run_sequential_inference(batch)
else:
return self._run_batched_inference(batch)
def _run_sequential_inference(
self, batch: List[BatchFrame]
) -> List[Dict[str, Any]]:
"""Run inference sequentially for batch_size=1 models"""
results = []
for batch_frame in batch:
# Preprocess frame
if self.preprocess_fn:
processed = self.preprocess_fn(batch_frame.frame)
else:
processed = (
batch_frame.frame.unsqueeze(0)
if batch_frame.frame.dim() == 3
else batch_frame.frame
)
# Run inference
outputs = self.model_repository.infer(
self.model_id, {"images": processed}, synchronize=True
)
# Postprocess
if self.postprocess_fn:
try:
detections = self.postprocess_fn(outputs)
except Exception as e:
logger.error(
f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
)
detections = torch.zeros(
(0, 6), device=list(outputs.values())[0].device
)
else:
detections = outputs
result = {
"stream_id": batch_frame.stream_id,
"timestamp": batch_frame.timestamp,
"detections": detections,
"metadata": batch_frame.metadata,
}
results.append(result)
return results
def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""Run true batched inference for models that support it"""
# Preprocess frames
preprocessed = []
for batch_frame in batch:
if self.preprocess_fn:
processed = self.preprocess_fn(batch_frame.frame)
if processed.dim() == 4 and processed.shape[0] == 1:
processed = processed.squeeze(0)
else:
processed = batch_frame.frame
preprocessed.append(processed)
# Stack into batch tensor
batch_tensor = torch.stack(preprocessed, dim=0)
# Limit to model's max batch size
if batch_tensor.shape[0] > self.model_batch_size:
logger.warning(
f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}"
)
batch_tensor = batch_tensor[: self.model_batch_size]
batch = batch[: self.model_batch_size]
# Run inference
outputs = self.model_repository.infer(
self.model_id, {"images": batch_tensor}, synchronize=True
)
# Postprocess results (split batch back to individual results)
results = []
for i, batch_frame in enumerate(batch):
# Extract single frame output and clone for memory safety
frame_output = {}
for k, v in outputs.items():
frame_output[k] = v[i : i + 1].clone()
if self.postprocess_fn:
try:
detections = self.postprocess_fn(frame_output)
except Exception as e:
logger.error(
f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
)
detections = torch.zeros(
(0, 6), device=list(outputs.values())[0].device
)
else:
detections = frame_output
result = {
"stream_id": batch_frame.stream_id,
"timestamp": batch_frame.timestamp,
"detections": detections,
"metadata": batch_frame.metadata,
}
results.append(result)
return results
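A small self-contained sketch of the batch-size rule shared by _detect_model_batch_size and _run_batch_inference (shapes are illustrative):

def resolve_batch_size(input_shape, requested_batch_size: int) -> int:
    """-1 in the batch dimension means dynamic; any positive value is fixed."""
    batch_dim = input_shape[0]
    return requested_batch_size if batch_dim == -1 else batch_dim

print(resolve_batch_size((-1, 3, 640, 640), 16))  # dynamic engine -> 16 (true batching)
print(resolve_batch_size((1, 3, 640, 640), 16))   # fixed engine   -> 1 (sequential path)
print(resolve_batch_size((8, 3, 640, 640), 16))   # fixed engine   -> 8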
View file
@ -0,0 +1,222 @@
"""
Ultralytics YOLO Model Exporter with Caching
Exports YOLO .pt models to TensorRT .engine format using Ultralytics library.
Provides proper NMS and postprocessing built into the engine.
Caches exported engines to avoid redundant exports.
"""
import hashlib
import json
import logging
from pathlib import Path
from typing import Optional, Tuple
logger = logging.getLogger(__name__)
class UltralyticsExporter:
"""
Export YOLO models using Ultralytics with caching.
Features:
- Exports .pt models to TensorRT .engine format
- Caches exported engines by source file hash
- Saves metadata about exported models
- Reuses cached engines when available
"""
def __init__(self, cache_dir: str = ".ultralytics_cache"):
"""
Initialize exporter.
Args:
cache_dir: Directory for caching exported engines
"""
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Ultralytics exporter cache directory: {self.cache_dir}")
@staticmethod
def compute_file_hash(file_path: str) -> str:
"""
Compute SHA256 hash of a file.
Args:
file_path: Path to file
Returns:
Hexadecimal hash string
"""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(65536), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
def export(
self,
model_path: str,
device: int = 0,
half: bool = False,
imgsz: int = 640,
batch: int = 1,
**export_kwargs,
) -> Tuple[str, str]:
"""
Export YOLO model to TensorRT engine with caching.
Args:
model_path: Path to .pt model file
device: GPU device ID
half: Use FP16 precision
imgsz: Input image size (default: 640)
batch: Maximum batch size for inference
**export_kwargs: Additional arguments for Ultralytics export
Returns:
Tuple of (engine_hash, engine_path)
Raises:
FileNotFoundError: If model file doesn't exist
RuntimeError: If export fails
"""
model_path = Path(model_path).resolve()
if not model_path.exists():
raise FileNotFoundError(f"Model file not found: {model_path}")
# Compute hash of source model
logger.info(f"Computing hash for {model_path}...")
model_hash = self.compute_file_hash(str(model_path))
logger.info(f"Model hash: {model_hash[:16]}...")
# Create export config hash (includes export parameters)
export_config = {
"model_hash": model_hash,
"device": device,
"half": half,
"imgsz": imgsz,
"batch": batch,
**export_kwargs,
}
config_str = json.dumps(export_config, sort_keys=True)
config_hash = hashlib.sha256(config_str.encode()).hexdigest()
# Check cache
cache_engine_path = self.cache_dir / f"{config_hash}.engine"
cache_metadata_path = self.cache_dir / f"{config_hash}_metadata.json"
if cache_engine_path.exists():
logger.info(f"Found cached engine: {cache_engine_path}")
logger.info(f"Reusing cached export (config hash: {config_hash[:16]}...)")
# Load and return metadata
if cache_metadata_path.exists():
with open(cache_metadata_path, "r") as f:
metadata = json.load(f)
logger.info(f"Cached engine metadata: {metadata}")
return config_hash, str(cache_engine_path)
# Export using Ultralytics
logger.info(f"Exporting YOLO model to TensorRT engine...")
logger.info(f" Source: {model_path}")
logger.info(f" Device: GPU {device}")
logger.info(f" Precision: {'FP16' if half else 'FP32'}")
logger.info(f" Image size: {imgsz}")
logger.info(f" Batch size: {batch}")
try:
from ultralytics import YOLO
# Load model
model = YOLO(str(model_path))
# Export to TensorRT
exported_path = model.export(
format="engine",
device=device,
half=half,
imgsz=imgsz,
batch=batch,
verbose=True,
**export_kwargs,
)
logger.info(f"Export complete: {exported_path}")
# Copy to cache
import shutil
shutil.copy(exported_path, cache_engine_path)
logger.info(f"Cached engine: {cache_engine_path}")
# Save metadata
metadata = {
"source_model": str(model_path),
"model_hash": model_hash,
"config_hash": config_hash,
"device": device,
"half": half,
"imgsz": imgsz,
"batch": batch,
"export_kwargs": export_kwargs,
"exported_path": str(exported_path),
"cached_path": str(cache_engine_path),
}
with open(cache_metadata_path, "w") as f:
json.dump(metadata, f, indent=2)
logger.info(f"Saved metadata: {cache_metadata_path}")
return config_hash, str(cache_engine_path)
except Exception as e:
logger.error(f"Export failed: {e}")
raise RuntimeError(f"Failed to export YOLO model: {e}")
def get_cached_engine(self, model_path: str, **export_kwargs) -> Optional[str]:
"""
Get cached engine path if it exists.
Args:
model_path: Path to .pt model
**export_kwargs: Export parameters (must match cached export)
Returns:
Path to cached engine or None if not cached
"""
try:
model_path = Path(model_path).resolve()
if not model_path.exists():
return None
# Compute hashes (mirror the defaults used in export() so the config hash matches)
model_hash = self.compute_file_hash(str(model_path))
export_config = {
"model_hash": model_hash,
"device": export_kwargs.pop("device", 0),
"half": export_kwargs.pop("half", False),
"imgsz": export_kwargs.pop("imgsz", 640),
"batch": export_kwargs.pop("batch", 1),
**export_kwargs,
}
config_str = json.dumps(export_config, sort_keys=True)
config_hash = hashlib.sha256(config_str.encode()).hexdigest()
cache_engine_path = self.cache_dir / f"{config_hash}.engine"
if cache_engine_path.exists():
return str(cache_engine_path)
return None
except Exception as e:
logger.warning(f"Failed to check cache: {e}")
return None
def clear_cache(self):
"""Clear all cached engines"""
import shutil
if self.cache_dir.exists():
shutil.rmtree(self.cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
logger.info("Cache cleared")
View file
@ -0,0 +1,217 @@
"""
Ultralytics Model Controller - YOLO inference with batched processing.
"""
import logging
from typing import Any, Callable, Dict, List, Optional
import torch
from .base_model_controller import BaseModelController, BatchFrame
logger = logging.getLogger(__name__)
class UltralyticsModelController(BaseModelController):
"""
Model controller for Ultralytics YOLO inference.
Uses UltralyticsEngine which wraps the Ultralytics YOLO model with
native TensorRT backend for GPU-accelerated inference.
"""
def __init__(
self,
inference_engine,
model_id: str,
batch_size: int = 16,
force_timeout: float = 0.05,
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
):
# Auto-detect actual batch size from the YOLO engine
engine_batch_size = self._detect_engine_batch_size(inference_engine)
# If engine has fixed batch size, use it. Otherwise use user's batch_size
actual_batch_size = engine_batch_size if engine_batch_size > 0 else batch_size
super().__init__(
model_id=model_id,
batch_size=actual_batch_size,
force_timeout=force_timeout,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
self.inference_engine = inference_engine
self.engine_batch_size = engine_batch_size # Store for padding logic
if engine_batch_size > 0:
logger.info(
f"Ultralytics engine has fixed batch_size={engine_batch_size}, "
f"will pad batches to match"
)
else:
logger.info(
f"Ultralytics engine supports dynamic batching, "
f"using max batch_size={actual_batch_size}"
)
def _detect_engine_batch_size(self, inference_engine) -> int:
"""
Detect the batch size from the Ultralytics engine.
Returns:
Fixed batch size (e.g., 2, 4, 8) or -1 for dynamic batching
"""
try:
# Get engine metadata
metadata = inference_engine.get_metadata()
# Check input shape for batch dimension
if "images" in metadata.input_shapes:
input_shape = metadata.input_shapes["images"]
batch_dim = input_shape[0]
if batch_dim > 0:
# Fixed batch size
return batch_dim
else:
# Dynamic batch size (-1)
return -1
# Fallback: try to get from model directly
if (
hasattr(inference_engine, "_model")
and inference_engine._model is not None
):
model = inference_engine._model
# Try to get batch info from Ultralytics model
if hasattr(model, "predictor") and model.predictor is not None:
predictor = model.predictor
if hasattr(predictor, "model") and hasattr(
predictor.model, "batch"
):
return predictor.model.batch
# Try to get from model.model (for .engine files)
if hasattr(model, "model"):
# For TensorRT engines, check input shape
if hasattr(model.model, "get_input_details"):
details = model.model.get_input_details()
if details and len(details) > 0:
shape = details[0].get("shape")
if shape and len(shape) > 0:
return shape[0] if shape[0] > 0 else -1
except Exception as e:
logger.warning(f"Could not detect engine batch size: {e}")
# Default: assume dynamic batching
return -1
def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""
Run Ultralytics YOLO inference on a batch of frames.
Ultralytics handles batching natively and returns Results objects.
"""
# Preprocess frames
preprocessed = []
for batch_frame in batch:
if self.preprocess_fn:
processed = self.preprocess_fn(batch_frame.frame)
# Ensure shape is (C, H, W) not (1, C, H, W)
if processed.dim() == 4 and processed.shape[0] == 1:
processed = processed.squeeze(0)
else:
processed = batch_frame.frame
preprocessed.append(processed)
# Stack into batch tensor: (B, C, H, W)
batch_tensor = torch.stack(preprocessed, dim=0)
actual_batch_size = len(batch)
# Handle fixed batch size engines (pad if needed)
if self.engine_batch_size > 0:
# Engine has fixed batch size
if batch_tensor.shape[0] > self.engine_batch_size:
# Truncate to engine's max batch size
logger.warning(
f"Batch size {batch_tensor.shape[0]} exceeds engine max {self.engine_batch_size}, truncating"
)
batch_tensor = batch_tensor[: self.engine_batch_size]
batch = batch[: self.engine_batch_size]
actual_batch_size = self.engine_batch_size
elif batch_tensor.shape[0] < self.engine_batch_size:
# Pad to match engine's fixed batch size
padding_size = self.engine_batch_size - batch_tensor.shape[0]
# Replicate last frame to pad (cheaper than zeros)
padding = batch_tensor[-1:].repeat(padding_size, 1, 1, 1)
batch_tensor = torch.cat([batch_tensor, padding], dim=0)
logger.debug(
f"Padded batch from {actual_batch_size} to {self.engine_batch_size} frames"
)
else:
# Dynamic batching - just limit to max
if batch_tensor.shape[0] > self.batch_size:
logger.warning(
f"Batch size {batch_tensor.shape[0]} exceeds configured max {self.batch_size}"
)
batch_tensor = batch_tensor[: self.batch_size]
batch = batch[: self.batch_size]
actual_batch_size = self.batch_size
# Run Ultralytics inference
# Input should be (B, 3, H, W) in range [0, 1], RGB format
outputs = self.inference_engine.infer(
inputs={"images": batch_tensor},
conf=0.25, # Confidence threshold
iou=0.45, # NMS IoU threshold
)
# Ultralytics returns Results objects in outputs["results"]
yolo_results = outputs["results"]
# Convert Results objects to our standard format
# Only process actual batch size (ignore padded results if any)
results = []
for i in range(actual_batch_size):
batch_frame = batch[i]
yolo_result = yolo_results[i]
# Extract detections from YOLO Results object
# yolo_result.boxes.data has format: [x1, y1, x2, y2, conf, cls]
if hasattr(yolo_result, "boxes") and yolo_result.boxes is not None:
detections = yolo_result.boxes.data # Already a tensor on GPU
else:
# No detections
detections = torch.zeros((0, 6), device=batch_tensor.device)
# Apply custom postprocessing if provided
if self.postprocess_fn:
try:
# For Ultralytics, postprocess_fn might do additional filtering
# Pass the raw boxes tensor in the same format as TensorRT output
detections = self.postprocess_fn(
{
"output0": detections.unsqueeze(
0
) # Add batch dim for compatibility
}
)
except Exception as e:
logger.error(
f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
)
detections = torch.zeros((0, 6), device=batch_tensor.device)
result = {
"stream_id": batch_frame.stream_id,
"timestamp": batch_frame.timestamp,
"detections": detections,
"metadata": batch_frame.metadata,
"yolo_result": yolo_result, # Keep original Results object for debugging
}
results.append(result)
return results
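A self-contained sketch of the pad-or-truncate rule applied above for fixed-batch engines:

import torch

def pad_or_truncate(batch_tensor: torch.Tensor, engine_batch_size: int) -> torch.Tensor:
    """Pad by repeating the last frame, or truncate, to match a fixed engine batch size."""
    b = batch_tensor.shape[0]
    if b > engine_batch_size:
        return batch_tensor[:engine_batch_size]
    if b < engine_batch_size:
        padding = batch_tensor[-1:].repeat(engine_batch_size - b, 1, 1, 1)
        return torch.cat([batch_tensor, padding], dim=0)
    return batch_tensor

frames = torch.rand(3, 3, 640, 640)      # 3 frames arrived this cycle
print(pad_or_truncate(frames, 4).shape)  # torch.Size([4, 3, 640, 640]) - last frame repeated
print(pad_or_truncate(frames, 2).shape)  # torch.Size([2, 3, 640, 640]) - truncated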
View file
@ -101,14 +101,15 @@ class YOLOv8Utils:
# Get output tensor (first and only output) # Get output tensor (first and only output)
output_name = list(outputs.keys())[0] output_name = list(outputs.keys())[0]
output = outputs[output_name] # (1, 84, 8400) output = outputs[output_name] # (1, 4+num_classes, 8400)
# Transpose to (1, 8400, 84) for easier processing # Transpose to (1, 8400, 4+num_classes) for easier processing
output = output.transpose(1, 2).squeeze(0) # (8400, 84) output = output.transpose(1, 2).squeeze(0) # (8400, 4+num_classes)
# Split bbox coordinates and class scores (vectorized) # Split bbox coordinates and class scores (vectorized)
# Format: [cx, cy, w, h, class_scores...]
bboxes = output[:, :4] # (8400, 4) - (cx, cy, w, h) bboxes = output[:, :4] # (8400, 4) - (cx, cy, w, h)
class_scores = output[:, 4:] # (8400, 80) class_scores = output[:, 4:] # (8400, num_classes) - dynamically sized
# Get max class score and corresponding class ID for all anchors (vectorized) # Get max class score and corresponding class ID for all anchors (vectorized)
max_scores, class_ids = torch.max(class_scores, dim=1) # (8400,), (8400,) max_scores, class_ids = torch.max(class_scores, dim=1) # (8400,), (8400,)
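A minimal sketch of this decoding step on a dummy tensor, assuming the standard YOLOv8 head layout of (1, 4 + num_classes, num_anchors):

import torch

num_classes, num_anchors = 80, 8400
output = torch.rand(1, 4 + num_classes, num_anchors)    # dummy raw model output

out = output.transpose(1, 2).squeeze(0)                 # (8400, 4 + num_classes)
bboxes = out[:, :4]                                     # (cx, cy, w, h) per anchor
class_scores = out[:, 4:]                               # per-class scores
max_scores, class_ids = torch.max(class_scores, dim=1)  # best class per anchor

keep = max_scores > 0.25                                # confidence threshold before NMS
print(bboxes[keep].shape, class_ids[keep].shape)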
View file
@ -9,193 +9,25 @@ This script demonstrates:
- Automatic PT to TensorRT conversion - Automatic PT to TensorRT conversion
""" """
import time
import os import os
import torch import threading
import time
import cv2 import cv2
import numpy as np import numpy as np
import torch
from dotenv import load_dotenv from dotenv import load_dotenv
from services import ( from services import (
StreamConnectionManager,
YOLOv8Utils,
COCO_CLASSES, COCO_CLASSES,
StreamConnectionManager,
UltralyticsExporter,
YOLOv8Utils,
) )
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
def main_single_stream():
"""Single stream example with event-driven architecture."""
print("=" * 80)
print("Event-Driven GPU-Accelerated Object Tracking - Single Stream")
print("=" * 80)
# Configuration
GPU_ID = 0
MODEL_PATH = "bangchak/models/frontal_detection_v5.pt" # Transparent loading: .pt, .engine, or .trt
STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
BATCH_SIZE = 4
FORCE_TIMEOUT = 0.05
ENABLE_DISPLAY = os.getenv('ENABLE_DISPLAY', 'false').lower() == 'true' # Set to 'true' to enable OpenCV display
MAX_FRAMES = int(os.getenv('MAX_FRAMES', '300')) # Stop after N frames (0 = unlimited)
print(f"\nConfiguration:")
print(f" GPU: {GPU_ID}")
print(f" Model: {MODEL_PATH}")
print(f" Stream: {STREAM_URL}")
print(f" Batch size: {BATCH_SIZE}")
print(f" Force timeout: {FORCE_TIMEOUT}s")
print(f" Display: {'Enabled' if ENABLE_DISPLAY else 'Disabled (inference only)'}")
print(f" Max frames: {MAX_FRAMES if MAX_FRAMES > 0 else 'Unlimited'}\n")
# Create StreamConnectionManager with PT conversion enabled
print("[1/3] Creating StreamConnectionManager...")
manager = StreamConnectionManager(
gpu_id=GPU_ID,
batch_size=BATCH_SIZE,
force_timeout=FORCE_TIMEOUT,
enable_pt_conversion=True # Enable PT conversion
)
print("✓ Manager created")
# Initialize with model (transparent loading - no manual parameters needed)
print("\n[2/3] Initializing model...")
print("Note: YOLO models auto-convert to native TensorRT .engine (first time only)")
print("Metadata is auto-detected from model - no manual input_shapes needed!\n")
try:
manager.initialize(
model_path=MODEL_PATH,
model_id="detector",
preprocess_fn=YOLOv8Utils.preprocess,
postprocess_fn=YOLOv8Utils.postprocess,
num_contexts=4
# Note: No pt_input_shapes or pt_precision needed for YOLO models!
)
print("✓ Manager initialized")
except Exception as e:
print(f"✗ Failed to initialize: {e}")
import traceback
traceback.print_exc()
return
# Connect stream
print("\n[3/3] Connecting to stream...")
try:
connection = manager.connect_stream(
rtsp_url=STREAM_URL,
stream_id="camera_1",
buffer_size=30
)
print(f"✓ Stream connected: camera_1")
except Exception as e:
print(f"✗ Failed to connect stream: {e}")
return
print(f"\n{'=' * 80}")
print("Event-driven tracking is running!")
print("Press Ctrl+C to stop")
print(f"{'=' * 80}\n")
# Stream results with optional OpenCV visualization
result_count = 0
start_time = time.time()
# Create window only if display is enabled
if ENABLE_DISPLAY:
cv2.namedWindow("Object Tracking", cv2.WINDOW_NORMAL)
cv2.resizeWindow("Object Tracking", 1280, 720)
try:
for result in connection.tracking_results():
result_count += 1
# Check if we've reached max frames
if MAX_FRAMES > 0 and result_count >= MAX_FRAMES:
print(f"\n✓ Reached max frames limit ({MAX_FRAMES})")
break
# OpenCV visualization (only if enabled)
if ENABLE_DISPLAY:
# Get latest frame from decoder (as CPU numpy array)
frame = connection.decoder.get_latest_frame_cpu(rgb=True)
if frame is not None:
# Convert to BGR for OpenCV
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
# Draw tracked objects
for obj in result.tracked_objects:
# Get bbox coordinates
x1, y1, x2, y2 = map(int, obj.bbox)
# Draw bounding box
cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)
# Draw track ID and class name
label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
# Draw label background
cv2.rectangle(frame_bgr, (x1, y1 - label_size[1] - 10),
(x1 + label_size[0], y1), (0, 255, 0), -1)
# Draw label text
cv2.putText(frame_bgr, label, (x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
# Draw FPS and object count
elapsed = time.time() - start_time
fps = result_count / elapsed if elapsed > 0 else 0
info_text = f"FPS: {fps:.1f} | Objects: {len(result.tracked_objects)} | Frame: {result_count}"
cv2.putText(frame_bgr, info_text, (10, 30),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
# Display frame
cv2.imshow("Object Tracking", frame_bgr)
# Check for 'q' key to quit
if cv2.waitKey(1) & 0xFF == ord('q'):
print(f"\n✓ Quit by user (pressed 'q')")
break
# Print stats every 30 results
if result_count % 30 == 0:
elapsed = time.time() - start_time
fps = result_count / elapsed if elapsed > 0 else 0
print(f"\nResults: {result_count} | FPS: {fps:.1f}")
print(f" Stream: {result.stream_id}")
print(f" Objects: {len(result.tracked_objects)}")
if result.tracked_objects:
class_counts = {}
for obj in result.tracked_objects:
class_counts[obj.class_name] = class_counts.get(obj.class_name, 0) + 1
print(f" Classes: {class_counts}")
except KeyboardInterrupt:
print(f"\n✓ Interrupted by user")
# Cleanup
print(f"\n{'=' * 80}")
print("Cleanup")
print(f"{'=' * 80}")
# Close OpenCV window if it was opened
if ENABLE_DISPLAY:
cv2.destroyAllWindows()
connection.stop()
manager.shutdown()
print("✓ Stopped")
# Final stats
elapsed = time.time() - start_time
avg_fps = result_count / elapsed if elapsed > 0 else 0
print(f"\nFinal: {result_count} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")
def main_multi_stream(): def main_multi_stream():
"""Multi-stream example with batched inference.""" """Multi-stream example with batched inference."""
@ -206,14 +38,18 @@ def main_multi_stream():
# Configuration # Configuration
GPU_ID = 0 GPU_ID = 0
MODEL_PATH = "bangchak/models/frontal_detection_v5.pt" # Transparent loading: .pt, .engine, or .trt MODEL_PATH = "bangchak/models/frontal_detection_v5.pt" # Transparent loading: .pt, .engine, or .trt
BATCH_SIZE = 16 USE_ULTRALYTICS = (
os.getenv("USE_ULTRALYTICS", "true").lower() == "true"
) # Use Ultralytics engine for YOLO
BATCH_SIZE = 2 # Reduced to 2 to avoid GPU memory issues
FORCE_TIMEOUT = 0.05 FORCE_TIMEOUT = 0.05
ENABLE_DISPLAY = os.getenv("ENABLE_DISPLAY", "true").lower() == "true"
# Load camera URLs # Load camera URLs
camera_urls = [] camera_urls = []
i = 1 i = 1
while True: while True:
url = os.getenv(f'CAMERA_URL_{i}') url = os.getenv(f"CAMERA_URL_{i}")
if url: if url:
camera_urls.append((f"camera_{i}", url)) camera_urls.append((f"camera_{i}", url))
i += 1 i += 1
@ -230,13 +66,16 @@ def main_multi_stream():
print(f" Streams: {len(camera_urls)}") print(f" Streams: {len(camera_urls)}")
print(f" Batch size: {BATCH_SIZE}\n") print(f" Batch size: {BATCH_SIZE}\n")
# Create manager with PT conversion # Create manager with backend selection
print("[1/3] Creating StreamConnectionManager...") print("[1/3] Creating StreamConnectionManager...")
backend = "ultralytics"
print(f" Backend: {backend}")
manager = StreamConnectionManager( manager = StreamConnectionManager(
gpu_id=GPU_ID, gpu_id=GPU_ID,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
force_timeout=FORCE_TIMEOUT, force_timeout=FORCE_TIMEOUT,
enable_pt_conversion=True enable_pt_conversion=True,
backend=backend,
) )
print("✓ Manager created") print("✓ Manager created")
@ -248,30 +87,52 @@ def main_multi_stream():
model_id="detector", model_id="detector",
preprocess_fn=YOLOv8Utils.preprocess, preprocess_fn=YOLOv8Utils.preprocess,
postprocess_fn=YOLOv8Utils.postprocess, postprocess_fn=YOLOv8Utils.postprocess,
num_contexts=8 num_contexts=1, # Single context to minimize GPU memory usage
# Note: No pt_input_shapes or pt_precision needed for YOLO models! # Note: No pt_input_shapes or pt_precision needed for YOLO models!
) )
print("✓ Manager initialized") print("✓ Manager initialized")
except Exception as e: except Exception as e:
print(f"✗ Failed to initialize: {e}") print(f"✗ Failed to initialize: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return return
# Connect all streams # Connect all streams in parallel using threads
print(f"\n[3/3] Connecting {len(camera_urls)} streams...") print(f"\n[3/3] Connecting {len(camera_urls)} streams in parallel...")
connections = {} connections = {}
for stream_id, rtsp_url in camera_urls: connection_threads = []
connection_results = {}
def connect_stream(stream_id, rtsp_url):
"""Thread worker to connect a single stream"""
try: try:
conn = manager.connect_stream( conn = manager.connect_stream(
rtsp_url=rtsp_url, rtsp_url=rtsp_url, stream_id=stream_id, buffer_size=3
stream_id=stream_id,
buffer_size=5
) )
connections[stream_id] = conn connection_results[stream_id] = ("success", conn)
print(f"✓ Connected: {stream_id}")
except Exception as e: except Exception as e:
print(f"✗ Failed {stream_id}: {e}") connection_results[stream_id] = ("error", str(e))
# Start all connection threads
for stream_id, rtsp_url in camera_urls:
thread = threading.Thread(
target=connect_stream, args=(stream_id, rtsp_url), daemon=True
)
thread.start()
connection_threads.append(thread)
# Wait for all connections to complete
for thread in connection_threads:
thread.join()
# Collect results
for stream_id, (status, result) in connection_results.items():
if status == "success":
connections[stream_id] = result
print(f"✓ Connected: {stream_id}")
else:
print(f"✗ Failed {stream_id}: {result}")
if not connections: if not connections:
print("No streams connected") print("No streams connected")
@@ -284,10 +145,20 @@ def main_multi_stream():
     print(f"{'=' * 80}\n")
 
     # Track stats
-    stream_stats = {sid: {'count': 0, 'start': time.time()} for sid in connections.keys()}
+    stream_stats = {
+        sid: {"count": 0, "start": time.time()} for sid in connections.keys()
+    }
     total_results = 0
     start_time = time.time()
 
+    # Create windows for each stream if display enabled
+    if ENABLE_DISPLAY:
+        for stream_id in connections.keys():
+            cv2.namedWindow(stream_id, cv2.WINDOW_NORMAL)
+            cv2.resizeWindow(
+                stream_id, 640, 360
+            )  # Smaller windows for multiple streams
+
     try:
         # Merge all result queues from all connections
        import queue as queue_module
@@ -306,27 +177,92 @@ def main_multi_stream():
                     stream_id = result.stream_id
                     if stream_id in stream_stats:
-                        stream_stats[stream_id]['count'] += 1
+                        stream_stats[stream_id]["count"] += 1
 
+                    # Display visualization if enabled
+                    if ENABLE_DISPLAY:
+                        # Get latest frame from decoder (already in CPU memory as numpy RGB)
+                        frame_rgb = conn.decoder.get_latest_frame_cpu(rgb=True)
+                        if frame_rgb is not None:
+                            # Convert RGB to BGR for OpenCV
+                            frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
+
+                            # Draw bounding boxes
+                            for obj in result.tracked_objects:
+                                x1, y1, x2, y2 = map(int, obj.bbox)
+
+                                # Draw box
+                                cv2.rectangle(
+                                    frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2
+                                )
+
+                                # Draw label with ID and class
+                                label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
+                                (label_w, label_h), _ = cv2.getTextSize(
+                                    label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
+                                )
+                                cv2.rectangle(
+                                    frame_bgr,
+                                    (x1, y1 - label_h - 10),
+                                    (x1 + label_w, y1),
+                                    (0, 255, 0),
+                                    -1,
+                                )
+                                cv2.putText(
+                                    frame_bgr,
+                                    label,
+                                    (x1, y1 - 5),
+                                    cv2.FONT_HERSHEY_SIMPLEX,
+                                    0.5,
+                                    (0, 0, 0),
+                                    1,
+                                )
+
+                            # Show FPS on frame
+                            s_elapsed = time.time() - stream_stats[stream_id]["start"]
+                            s_fps = (
+                                stream_stats[stream_id]["count"] / s_elapsed
+                                if s_elapsed > 0
+                                else 0
+                            )
+                            fps_text = f"{stream_id}: {s_fps:.1f} FPS | {len(result.tracked_objects)} objects"
+                            cv2.putText(
+                                frame_bgr,
+                                fps_text,
+                                (10, 30),
+                                cv2.FONT_HERSHEY_SIMPLEX,
+                                0.7,
+                                (0, 255, 0),
+                                2,
+                            )
+
+                            # Display
+                            cv2.imshow(stream_id, frame_bgr)
+
                     # Print stats every 100 results
                     if total_results % 100 == 0:
                         elapsed = time.time() - start_time
                         total_fps = total_results / elapsed if elapsed > 0 else 0
-                        print(f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS")
+                        print(
+                            f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS"
+                        )
                        for sid, stats in stream_stats.items():
-                            s_elapsed = time.time() - stats['start']
-                            s_fps = stats['count'] / s_elapsed if s_elapsed > 0 else 0
+                            s_elapsed = time.time() - stats["start"]
+                            s_fps = stats["count"] / s_elapsed if s_elapsed > 0 else 0
                             print(f"  {sid}: {stats['count']} ({s_fps:.1f} FPS)")
 
                 except queue_module.Empty:
                     continue
 
+            # Process OpenCV events to keep windows responsive
+            if ENABLE_DISPLAY:
+                cv2.waitKey(1)
+
             # Small sleep if no results to avoid busy loop
             if not got_result:
                 time.sleep(0.01)
 
     except KeyboardInterrupt:
         print(f"\n✓ Interrupted")
@@ -335,6 +271,10 @@ def main_multi_stream():
     print("Cleanup")
     print(f"{'=' * 80}")
 
+    # Close OpenCV windows if they were opened
+    if ENABLE_DISPLAY:
+        cv2.destroyAllWindows()
+
     for conn in connections.values():
         conn.stop()
     manager.shutdown()
@@ -347,8 +287,4 @@ def main_multi_stream():
 
 if __name__ == "__main__":
-    import sys
-    if len(sys.argv) > 1 and sys.argv[1] == "single":
-        main_single_stream()
-    else:
-        main_multi_stream()
+    main_multi_stream()
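# --- Illustrative sketch only, not part of the commit above ---
# The diff implements the parallel-connect step with manual threading.Thread
# bookkeeping; an equivalent formulation could use the standard-library
# concurrent.futures module. This assumes `manager.connect_stream` and
# `camera_urls` behave exactly as in the script above; the helper name
# `connect_all_streams` is hypothetical.
from concurrent.futures import ThreadPoolExecutor, as_completed


def connect_all_streams(manager, camera_urls, buffer_size=3):
    """Connect every (stream_id, rtsp_url) pair in parallel; return successes."""
    connections = {}
    with ThreadPoolExecutor(max_workers=max(1, len(camera_urls))) as pool:
        # Submit one connection attempt per camera and remember which is which
        futures = {
            pool.submit(
                manager.connect_stream,
                rtsp_url=url,
                stream_id=sid,
                buffer_size=buffer_size,
            ): sid
            for sid, url in camera_urls
        }
        # Report each result as it finishes, mirroring the per-stream prints above
        for future in as_completed(futures):
            sid = futures[future]
            try:
                connections[sid] = future.result()
                print(f"✓ Connected: {sid}")
            except Exception as e:
                print(f"✗ Failed {sid}: {e}")
    return connections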