Commit fdaeb9981c ("ultralytic export")
Parent: bf7b68edb1
14 changed files with 2241 additions and 507 deletions
.gitignore (vendored)

@@ -5,4 +5,5 @@ __pycache__/
 .claude
 /models/
 /tracked_objects.json
 .trt_cache
+.ultralytics_cache
@@ -10,4 +10,8 @@
 - Buffer should flush after TARGET_FRAME_INTERVAL_MS
 
 - Blurry asyncio archtecture, require documentations
 
+- Each engine cache to its own random unconcentrated folder
+
+- Workspace for YOLO is fixed to 4GB, why is that?, what is it?
convert_pt_to_tensorrt.py

@@ -21,10 +21,11 @@ Usage:
 import argparse
 import sys
 from pathlib import Path
-from typing import Tuple, List, Optional
-import torch
-import tensorrt as trt
+from typing import List, Optional, Tuple
+
 import numpy as np
+import tensorrt as trt
+import torch
 
 
 class TensorRTConverter:
@@ -39,7 +40,7 @@ class TensorRTConverter:
             verbose: Enable verbose logging
         """
         self.gpu_id = gpu_id
-        self.device = torch.device(f'cuda:{gpu_id}')
+        self.device = torch.device(f"cuda:{gpu_id}")
 
         # TensorRT logger
         log_level = trt.Logger.VERBOSE if verbose else trt.Logger.WARNING
@@ -71,13 +72,15 @@ class TensorRTConverter:
             raise FileNotFoundError(f"Model file not found: {model_path}")
 
         # Load model (weights_only=False for models with custom classes)
-        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
+        checkpoint = torch.load(
+            model_path, map_location=self.device, weights_only=False
+        )
 
         # Handle different checkpoint formats
         if isinstance(checkpoint, dict):
-            if 'model' in checkpoint:
-                model = checkpoint['model']
-            elif 'state_dict' in checkpoint:
+            if "model" in checkpoint:
+                model = checkpoint["model"]
+            elif "state_dict" in checkpoint:
                 # Need model architecture - this is a limitation
                 raise ValueError(
                     "Checkpoint contains only state_dict. "
@@ -95,9 +98,15 @@ class TensorRTConverter:
         print(f"✓ Model loaded successfully")
         return model
 
-    def export_to_onnx(self, model: torch.nn.Module, input_shape: Tuple[int, ...],
-                       onnx_path: str, dynamic_batch: bool = False,
-                       input_names: List[str] = None, output_names: List[str] = None) -> str:
+    def export_to_onnx(
+        self,
+        model: torch.nn.Module,
+        input_shape: Tuple[int, ...],
+        onnx_path: str,
+        dynamic_batch: bool = False,
+        input_names: List[str] = None,
+        output_names: List[str] = None,
+    ) -> str:
         """
         Export PyTorch model to ONNX format (intermediate step).
 
@@ -118,9 +127,9 @@ class TensorRTConverter:
 
         # Default names
         if input_names is None:
-            input_names = ['input']
+            input_names = ["input"]
         if output_names is None:
-            output_names = ['output']
+            output_names = ["output"]
 
         # Create dummy input
         dummy_input = torch.randn(*input_shape, device=self.device)
@@ -128,10 +137,7 @@ class TensorRTConverter:
         # Dynamic axes configuration
         dynamic_axes = None
         if dynamic_batch:
-            dynamic_axes = {
-                input_names[0]: {0: 'batch'},
-                output_names[0]: {0: 'batch'}
-            }
+            dynamic_axes = {input_names[0]: {0: "batch"}, output_names[0]: {0: "batch"}}
 
         # Export to ONNX
         torch.onnx.export(
@@ -143,16 +149,22 @@ class TensorRTConverter:
             dynamic_axes=dynamic_axes,
             opset_version=17,  # Use recent ONNX opset
             do_constant_folding=True,
-            verbose=False
+            verbose=False,
         )
 
         print(f"✓ ONNX model exported to {onnx_path}")
         return onnx_path
 
-    def build_tensorrt_engine_from_onnx(self, onnx_path: str, engine_path: str,
-                                        fp16: bool = False, int8: bool = False,
-                                        max_workspace_size: int = 4,
-                                        min_batch: int = 1, opt_batch: int = 1, max_batch: int = 1) -> str:
+    def build_tensorrt_engine_from_onnx(
+        self,
+        onnx_path: str,
+        engine_path: str,
+        fp16: bool = False,
+        int8: bool = False,
+        min_batch: int = 1,
+        opt_batch: int = 1,
+        max_batch: int = 1,
+    ) -> str:
         """
         Build TensorRT engine from ONNX model.
 
@@ -161,7 +173,6 @@ class TensorRTConverter:
             engine_path: Output path for TensorRT engine
             fp16: Enable FP16 precision
             int8: Enable INT8 precision (requires calibration)
-            max_workspace_size: Maximum workspace size in GB
             min_batch: Minimum batch size for optimization
             opt_batch: Optimal batch size for optimization
             max_batch: Maximum batch size for optimization
@@ -171,7 +182,6 @@ class TensorRTConverter:
         """
         print(f"\nBuilding TensorRT engine from ONNX...")
         print(f"Precision: FP{'16' if fp16 else '32'}{' + INT8' if int8 else ''}")
-        print(f"Workspace size: {max_workspace_size} GB")
 
         # Create builder and network
         builder = trt.Builder(self.logger)
@@ -182,7 +192,7 @@ class TensorRTConverter:
 
         # Parse ONNX model
         print(f"Loading ONNX file from {onnx_path}...")
-        with open(onnx_path, 'rb') as f:
+        with open(onnx_path, "rb") as f:
             if not parser.parse(f.read()):
                 print("ERROR: Failed to parse the ONNX file:")
                 for error in range(parser.num_errors):
@@ -206,12 +216,6 @@ class TensorRTConverter:
         # Create builder config
         config = builder.create_builder_config()
 
-        # Set workspace size
-        config.set_memory_pool_limit(
-            trt.MemoryPoolType.WORKSPACE,
-            max_workspace_size * (1 << 30)  # GB to bytes
-        )
-
         # Enable precision modes
         if fp16:
             if not builder.platform_has_fast_fp16:
@@ -226,7 +230,9 @@ class TensorRTConverter:
             else:
                 config.set_flag(trt.BuilderFlag.INT8)
                 print("✓ INT8 mode enabled")
-                print("Note: INT8 calibration not implemented. Results may be suboptimal.")
+                print(
+                    "Note: INT8 calibration not implemented. Results may be suboptimal."
+                )
 
         # Set optimization profile for dynamic shapes
         if max_batch > 1 or min_batch != max_batch:
@@ -260,7 +266,7 @@ class TensorRTConverter:
 
         # Save engine to file
         print(f"Saving engine to {engine_path}...")
-        with open(engine_path, 'wb') as f:
+        with open(engine_path, "wb") as f:
             f.write(serialized_engine)
 
         # Get file size
@@ -270,15 +276,19 @@ class TensorRTConverter:
 
         return engine_path
 
-    def convert(self, model_path: str, output_path: str,
-                input_shape: Tuple[int, ...] = (1, 3, 640, 640),
-                fp16: bool = False, int8: bool = False,
-                dynamic_batch: bool = False,
-                max_batch: int = 16,
-                workspace_size: int = 4,
-                input_names: List[str] = None,
-                output_names: List[str] = None,
-                keep_onnx: bool = False) -> str:
+    def convert(
+        self,
+        model_path: str,
+        output_path: str,
+        input_shape: Tuple[int, ...] = (1, 3, 640, 640),
+        fp16: bool = False,
+        int8: bool = False,
+        dynamic_batch: bool = False,
+        max_batch: int = 16,
+        input_names: List[str] = None,
+        output_names: List[str] = None,
+        keep_onnx: bool = False,
+    ) -> str:
         """
         Convert PyTorch or ONNX model to TensorRT engine.
 
@@ -290,7 +300,6 @@ class TensorRTConverter:
             int8: Enable INT8 precision
             dynamic_batch: Enable dynamic batch size
             max_batch: Maximum batch size (for dynamic batching)
-            workspace_size: TensorRT workspace size in GB
             input_names: Custom input names (for PyTorch export)
             output_names: Custom output names (for PyTorch export)
             keep_onnx: Keep intermediate ONNX file
@@ -304,7 +313,7 @@ class TensorRTConverter:
 
         # Check if input is already ONNX
         model_path_obj = Path(model_path)
-        is_onnx = model_path_obj.suffix.lower() == '.onnx'
+        is_onnx = model_path_obj.suffix.lower() == ".onnx"
 
         if is_onnx:
             # Direct ONNX to TensorRT conversion
@@ -319,10 +328,9 @@ class TensorRTConverter:
                 engine_path=output_path,
                 fp16=fp16,
                 int8=int8,
-                max_workspace_size=workspace_size,
                 min_batch=min_batch,
                 opt_batch=opt_batch,
-                max_batch=max_batch_size
+                max_batch=max_batch_size,
             )
 
             print(f"\n{'=' * 80}")
@@ -350,7 +358,7 @@ class TensorRTConverter:
                 onnx_path=onnx_path,
                 dynamic_batch=dynamic_batch,
                 input_names=input_names,
-                output_names=output_names
+                output_names=output_names,
             )
 
             # Step 3: Build TensorRT engine
@@ -363,10 +371,9 @@ class TensorRTConverter:
                 engine_path=output_path,
                 fp16=fp16,
                 int8=int8,
-                max_workspace_size=workspace_size,
                 min_batch=min_batch,
                 opt_batch=opt_batch,
-                max_batch=max_batch_size
+                max_batch=max_batch_size,
            )
 
             print(f"\n{'=' * 80}")
@@ -392,7 +399,7 @@ class TensorRTConverter:
 def parse_shape(shape_str: str) -> Tuple[int, ...]:
     """Parse shape string like '1,3,640,640' to tuple"""
     try:
-        return tuple(int(x) for x in shape_str.split(','))
+        return tuple(int(x) for x in shape_str.split(","))
     except ValueError:
         raise argparse.ArgumentTypeError(
             f"Invalid shape format: {shape_str}. Expected format: 1,3,640,640"
@@ -426,97 +433,82 @@ Examples:
   # Keep intermediate ONNX file for debugging
   python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
       --keep-onnx
-    """
+    """,
     )
 
     # Required arguments
     parser.add_argument(
-        '--model', '-m',
+        "--model",
+        "-m",
         type=str,
         required=True,
-        help='Path to PyTorch model file (.pt or .pth)'
+        help="Path to PyTorch model file (.pt or .pth)",
     )
 
     parser.add_argument(
-        '--output', '-o',
+        "--output",
+        "-o",
         type=str,
         required=True,
-        help='Output path for TensorRT engine (.trt or .engine)'
+        help="Output path for TensorRT engine (.trt or .engine)",
    )
 
     # Optional arguments
     parser.add_argument(
-        '--input-shape', '-s',
+        "--input-shape",
+        "-s",
         type=parse_shape,
         default=(1, 3, 640, 640),
-        help='Input tensor shape as B,C,H,W (default: 1,3,640,640)'
+        help="Input tensor shape as B,C,H,W (default: 1,3,640,640)",
     )
 
     parser.add_argument(
-        '--fp16',
-        action='store_true',
-        help='Enable FP16 precision (faster inference, slightly lower accuracy)'
+        "--fp16",
+        action="store_true",
+        help="Enable FP16 precision (faster inference, slightly lower accuracy)",
     )
 
     parser.add_argument(
-        '--int8',
-        action='store_true',
-        help='Enable INT8 precision (fastest, requires calibration)'
+        "--int8",
+        action="store_true",
+        help="Enable INT8 precision (fastest, requires calibration)",
     )
 
     parser.add_argument(
-        '--dynamic-batch',
-        action='store_true',
-        help='Enable dynamic batch size support'
+        "--dynamic-batch", action="store_true", help="Enable dynamic batch size support"
     )
 
     parser.add_argument(
-        '--max-batch',
+        "--max-batch",
         type=int,
         default=16,
-        help='Maximum batch size for dynamic batching (default: 16)'
+        help="Maximum batch size for dynamic batching (default: 16)",
     )
 
-    parser.add_argument(
-        '--workspace-size',
-        type=int,
-        default=4,
-        help='TensorRT workspace size in GB (default: 4)'
-    )
-
-    parser.add_argument(
-        '--gpu',
-        type=int,
-        default=0,
-        help='GPU device ID (default: 0)'
-    )
+    parser.add_argument("--gpu", type=int, default=0, help="GPU device ID (default: 0)")
 
     parser.add_argument(
-        '--input-names',
+        "--input-names",
         type=str,
-        nargs='+',
+        nargs="+",
         default=None,
-        help='Custom input tensor names (default: ["input"])'
+        help='Custom input tensor names (default: ["input"])',
     )
 
     parser.add_argument(
-        '--output-names',
+        "--output-names",
         type=str,
-        nargs='+',
+        nargs="+",
         default=None,
-        help='Custom output tensor names (default: ["output"])'
+        help='Custom output tensor names (default: ["output"])',
     )
 
     parser.add_argument(
-        '--keep-onnx',
-        action='store_true',
-        help='Keep intermediate ONNX file'
+        "--keep-onnx", action="store_true", help="Keep intermediate ONNX file"
     )
 
     parser.add_argument(
-        '--verbose', '-v',
-        action='store_true',
-        help='Enable verbose logging'
+        "--verbose", "-v", action="store_true", help="Enable verbose logging"
     )
 
     args = parser.parse_args()
@@ -542,10 +534,9 @@ Examples:
         int8=args.int8,
         dynamic_batch=args.dynamic_batch,
         max_batch=args.max_batch,
-        workspace_size=args.workspace_size,
         input_names=args.input_names,
         output_names=args.output_names,
-        keep_onnx=args.keep_onnx
+        keep_onnx=args.keep_onnx,
    )
 
     print("\n✓ Conversion successful!")
@@ -554,6 +545,7 @@ Examples:
        print(f"\n✗ Conversion failed: {e}")
        if args.verbose:
            import traceback
+
            traceback.print_exc()
        sys.exit(1)
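Taken together, the converter changes are Black/isort-style reformatting plus removal of the workspace-size knob (the TODO above asks why it was fixed at 4 GB). A minimal usage sketch of the class after this commit follows; the module name comes from the script's own examples, while the model paths are placeholders, not files from this repository.

# Hypothetical sketch: converting a .pt checkpoint with the post-commit API.
from convert_pt_to_tensorrt import TensorRTConverter

converter = TensorRTConverter(gpu_id=0, verbose=False)
engine_path = converter.convert(
    model_path="models/yolov8n.pt",    # placeholder checkpoint
    output_path="models/yolov8n.trt",  # where the serialized engine is written
    input_shape=(1, 3, 640, 640),
    fp16=True,
    dynamic_batch=True,
    max_batch=16,                      # note: workspace_size no longer exists
)
print(engine_path)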
services/__init__.py

@@ -2,38 +2,67 @@
 Services package for RTSP stream processing with GPU acceleration.
 """
 
-from .stream_decoder import StreamDecoderFactory, StreamDecoder, ConnectionStatus
+from .base_model_controller import BaseModelController, BatchFrame, BufferState
+from .inference_engine import (
+    BackendType,
+    EngineMetadata,
+    IInferenceEngine,
+    NativeTensorRTEngine,
+    UltralyticsEngine,
+    create_engine,
+)
 from .jpeg_encoder import JPEGEncoderFactory, encode_frame_to_jpeg
-from .model_repository import TensorRTModelRepository, ModelMetadata, ExecutionContext, SharedEngine
-from .tracking_controller import ObjectTracker, TrackedObject, Detection
-from .yolo import YOLOv8Utils, COCO_CLASSES
-from .model_controller import ModelController, BatchFrame, BufferState
-from .stream_connection_manager import StreamConnectionManager, StreamConnection, TrackingResult
+from .model_repository import (
+    ExecutionContext,
+    ModelMetadata,
+    SharedEngine,
+    TensorRTModelRepository,
+)
+from .modelstorage import FileModelStorage, IModelStorage
 from .pt_converter import PTConverter
-from .modelstorage import IModelStorage, FileModelStorage
+from .stream_connection_manager import (
+    StreamConnection,
+    StreamConnectionManager,
+    TrackingResult,
+)
+from .stream_decoder import ConnectionStatus, StreamDecoder, StreamDecoderFactory
+from .tensorrt_model_controller import TensorRTModelController
+from .tracking_controller import Detection, ObjectTracker, TrackedObject
+from .ultralytics_exporter import UltralyticsExporter
+from .ultralytics_model_controller import UltralyticsModelController
+from .yolo import COCO_CLASSES, YOLOv8Utils
 
 __all__ = [
-    'StreamDecoderFactory',
-    'StreamDecoder',
-    'ConnectionStatus',
-    'JPEGEncoderFactory',
-    'encode_frame_to_jpeg',
-    'TensorRTModelRepository',
-    'ModelMetadata',
-    'ExecutionContext',
-    'SharedEngine',
-    'ObjectTracker',
-    'TrackedObject',
-    'Detection',
-    'YOLOv8Utils',
-    'COCO_CLASSES',
-    'ModelController',
-    'BatchFrame',
-    'BufferState',
-    'StreamConnectionManager',
-    'StreamConnection',
-    'TrackingResult',
-    'PTConverter',
-    'IModelStorage',
-    'FileModelStorage',
+    "StreamDecoderFactory",
+    "StreamDecoder",
+    "ConnectionStatus",
+    "JPEGEncoderFactory",
+    "encode_frame_to_jpeg",
+    "TensorRTModelRepository",
+    "ModelMetadata",
+    "ExecutionContext",
+    "SharedEngine",
+    "ObjectTracker",
+    "TrackedObject",
+    "Detection",
+    "YOLOv8Utils",
+    "COCO_CLASSES",
+    "BaseModelController",
+    "TensorRTModelController",
+    "UltralyticsModelController",
+    "BatchFrame",
+    "BufferState",
+    "StreamConnectionManager",
+    "StreamConnection",
+    "TrackingResult",
+    "PTConverter",
+    "IModelStorage",
+    "FileModelStorage",
+    "IInferenceEngine",
+    "NativeTensorRTEngine",
+    "UltralyticsEngine",
+    "EngineMetadata",
+    "BackendType",
+    "create_engine",
+    "UltralyticsExporter",
 ]
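In practical terms, the export list above drops the single ModelController entry in favour of a controller hierarchy plus the new engine layer. A sketch of how downstream imports migrate (not an exhaustive mapping):

# Before this commit:
#   from services import ModelController
# After this commit, the equivalent surface looks like:
from services import (
    BaseModelController,        # shared ping-pong buffer logic
    TensorRTModelController,    # native TensorRT backend
    UltralyticsModelController, # Ultralytics backend
    UltralyticsExporter,        # .pt -> .engine export with caching
    create_engine,              # factory for IInferenceEngine backends
)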
services/base_model_controller.py (new file, 324 lines)

@@ -0,0 +1,324 @@
"""
Base Model Controller - Abstract base class for batched inference controllers.

Provides ping-pong buffer architecture with force-switch timeout mechanism.
Implementations handle backend-specific inference (TensorRT, Ultralytics, etc.).
"""

import logging
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional

import torch

logger = logging.getLogger(__name__)


@dataclass
class BatchFrame:
    """Represents a frame in the batch buffer"""

    stream_id: str
    frame: torch.Tensor  # GPU tensor (3, H, W)
    timestamp: float
    metadata: Dict = field(default_factory=dict)


class BufferState(Enum):
    """State of a ping-pong buffer"""

    IDLE = "idle"
    FILLING = "filling"
    PROCESSING = "processing"


class BaseModelController(ABC):
    """
    Abstract base class for batched inference with ping-pong buffers.

    This controller accumulates frames from multiple streams into batches,
    processes them through an inference backend, and routes results back to
    stream-specific callbacks.

    Features:
    - Ping-pong circular buffers (BufferA/BufferB)
    - Force-switch timeout to prevent batch starvation
    - Event-driven processing with callbacks
    - Thread-safe frame submission

    Subclasses must implement:
    - _run_batch_inference(): Backend-specific inference logic
    """

    def __init__(
        self,
        model_id: str,
        batch_size: int = 16,
        force_timeout: float = 0.05,
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
    ):
        self.model_id = model_id
        self.batch_size = batch_size
        self.force_timeout = force_timeout
        self.preprocess_fn = preprocess_fn
        self.postprocess_fn = postprocess_fn

        # Ping-pong buffers
        self.buffer_a: List[BatchFrame] = []
        self.buffer_b: List[BatchFrame] = []

        # Buffer states
        self.active_buffer = "A"
        self.buffer_a_state = BufferState.IDLE
        self.buffer_b_state = BufferState.IDLE

        # Threading coordination
        self.buffer_lock = threading.RLock()
        self.last_submit_time = time.time()

        # Threads
        self.timeout_thread: Optional[threading.Thread] = None
        self.processor_threads: Dict[str, threading.Thread] = {}
        self.running = False
        self.stop_event = threading.Event()

        # Result callbacks (stream_id -> callback)
        self.result_callbacks: Dict[str, Callable] = {}

        # Statistics
        self.total_frames_processed = 0
        self.total_batches_processed = 0

    def start(self):
        """Start the controller background threads"""
        if self.running:
            logger.warning("ModelController already running")
            return

        self.running = True
        self.stop_event.clear()

        # Start timeout monitor thread
        self.timeout_thread = threading.Thread(
            target=self._timeout_monitor, daemon=True
        )
        self.timeout_thread.start()

        # Start processor threads for each buffer
        self.processor_threads["A"] = threading.Thread(
            target=self._batch_processor, args=("A",), daemon=True
        )
        self.processor_threads["B"] = threading.Thread(
            target=self._batch_processor, args=("B",), daemon=True
        )
        self.processor_threads["A"].start()
        self.processor_threads["B"].start()

        logger.info(f"{self.__class__.__name__} started")

    def stop(self):
        """Stop the controller and cleanup"""
        if not self.running:
            return

        logger.info(f"Stopping {self.__class__.__name__}...")
        self.running = False
        self.stop_event.set()

        # Wait for threads to finish
        if self.timeout_thread and self.timeout_thread.is_alive():
            self.timeout_thread.join(timeout=2.0)

        for thread in self.processor_threads.values():
            if thread and thread.is_alive():
                thread.join(timeout=2.0)

        # Process any remaining frames
        self._process_remaining_buffers()
        logger.info(f"{self.__class__.__name__} stopped")

    def register_callback(self, stream_id: str, callback: Callable):
        """Register a callback for inference results from a stream"""
        self.result_callbacks[stream_id] = callback
        logger.debug(f"Registered callback for stream: {stream_id}")

    def unregister_callback(self, stream_id: str):
        """Unregister a stream callback"""
        self.result_callbacks.pop(stream_id, None)
        logger.debug(f"Unregistered callback for stream: {stream_id}")

    def submit_frame(
        self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
    ):
        """Submit a frame for batched inference"""
        with self.buffer_lock:
            batch_frame = BatchFrame(
                stream_id=stream_id,
                frame=frame,
                timestamp=time.time(),
                metadata=metadata or {},
            )

            # Add to active buffer
            if self.active_buffer == "A":
                self.buffer_a.append(batch_frame)
                self.buffer_a_state = BufferState.FILLING
                buffer_size = len(self.buffer_a)
            else:
                self.buffer_b.append(batch_frame)
                self.buffer_b_state = BufferState.FILLING
                buffer_size = len(self.buffer_b)

            self.last_submit_time = time.time()

            # Check if we should immediately swap (batch full)
            if buffer_size >= self.batch_size:
                self._try_swap_buffers()

    def _timeout_monitor(self):
        """Monitor force-switch timeout"""
        while self.running and not self.stop_event.wait(0.01):
            with self.buffer_lock:
                time_since_submit = time.time() - self.last_submit_time

                if time_since_submit >= self.force_timeout:
                    active_buffer = (
                        self.buffer_a if self.active_buffer == "A" else self.buffer_b
                    )
                    if len(active_buffer) > 0:
                        self._try_swap_buffers()

    def _try_swap_buffers(self):
        """Attempt to swap ping-pong buffers (called with buffer_lock held)"""
        inactive_state = (
            self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
        )

        if inactive_state != BufferState.PROCESSING:
            old_active = self.active_buffer
            self.active_buffer = "B" if old_active == "A" else "A"

            if old_active == "A":
                self.buffer_a_state = BufferState.PROCESSING
                buffer_size = len(self.buffer_a)
            else:
                self.buffer_b_state = BufferState.PROCESSING
                buffer_size = len(self.buffer_b)

            logger.debug(
                f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
            )

    def _batch_processor(self, buffer_name: str):
        """Background thread that processes a specific buffer when available"""
        while self.running and not self.stop_event.is_set():
            time.sleep(0.001)

            with self.buffer_lock:
                if buffer_name == "A":
                    should_process = self.buffer_a_state == BufferState.PROCESSING
                else:
                    should_process = self.buffer_b_state == BufferState.PROCESSING

            if should_process:
                self._process_buffer(buffer_name)

    def _process_buffer(self, buffer_name: str):
        """Process a buffer through inference"""
        # Extract buffer to process
        with self.buffer_lock:
            if buffer_name == "A":
                batch = self.buffer_a.copy()
                self.buffer_a.clear()
            else:
                batch = self.buffer_b.copy()
                self.buffer_b.clear()

        if len(batch) == 0:
            with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE
            return

        # Process batch (outside lock to allow concurrent submissions)
        try:
            start_time = time.time()
            results = self._run_batch_inference(batch)
            inference_time = time.time() - start_time

            self.total_frames_processed += len(batch)
            self.total_batches_processed += 1

            logger.debug(
                f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms "
                f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
            )

            # Emit results to callbacks
            for batch_frame, result in zip(batch, results):
                callback = self.result_callbacks.get(batch_frame.stream_id)
                if callback:
                    try:
                        callback(result)
                    except Exception as e:
                        logger.error(
                            f"Error in callback for {batch_frame.stream_id}: {e}",
                            exc_info=True,
                        )

        except Exception as e:
            logger.error(f"Error processing batch: {e}", exc_info=True)

        finally:
            with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE

    @abstractmethod
    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """
        Run inference on a batch of frames (backend-specific).

        Args:
            batch: List of BatchFrame objects

        Returns:
            List of detection results (one per frame)
        """
        pass

    def _process_remaining_buffers(self):
        """Process any remaining frames in buffers during shutdown"""
        if len(self.buffer_a) > 0:
            logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
            self._process_buffer("A")
        if len(self.buffer_b) > 0:
            logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
            self._process_buffer("B")

    def get_stats(self) -> Dict[str, Any]:
        """Get current buffer statistics"""
        return {
            "active_buffer": self.active_buffer,
            "buffer_a_size": len(self.buffer_a),
            "buffer_b_size": len(self.buffer_b),
            "buffer_a_state": self.buffer_a_state.value,
            "buffer_b_state": self.buffer_b_state.value,
            "registered_streams": len(self.result_callbacks),
            "total_frames_processed": self.total_frames_processed,
            "total_batches_processed": self.total_batches_processed,
            "avg_batch_size": (
                self.total_frames_processed / self.total_batches_processed
                if self.total_batches_processed > 0
                else 0
            ),
        }
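The class docstring above spells out the subclass contract: everything except _run_batch_inference() is inherited. A minimal sketch of a custom backend built on this base follows; the "echo" inference is a stand-in for a real model call, and the stream id is illustrative.

# Minimal subclass sketch: implements the single abstract hook and returns a
# dummy result per frame; buffering, swapping and callback routing come from
# BaseModelController as added in this commit.
from typing import Any, Dict, List

from services.base_model_controller import BaseModelController, BatchFrame


class EchoModelController(BaseModelController):
    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        # A real backend would stack batch[i].frame into an (N, 3, H, W) tensor
        # and run the model here; this sketch just echoes frame shapes back.
        return [
            {"stream_id": bf.stream_id, "shape": tuple(bf.frame.shape)}
            for bf in batch
        ]


controller = EchoModelController(model_id="echo", batch_size=4, force_timeout=0.05)
controller.register_callback("cam-01", lambda result: print(result))
controller.start()
# GPU frames would then be fed via controller.submit_frame("cam-01", frame)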
635
services/inference_engine.py
Normal file
635
services/inference_engine.py
Normal file
|
|
@ -0,0 +1,635 @@
|
||||||
|
"""
|
||||||
|
Inference Engine Abstraction Layer
|
||||||
|
|
||||||
|
Provides a unified interface for different inference backends:
|
||||||
|
- Native TensorRT: Direct TensorRT API with zero-copy GPU tensors
|
||||||
|
- Ultralytics: YOLO models with built-in pre/postprocessing
|
||||||
|
- Future: ONNX Runtime, OpenVINO, etc.
|
||||||
|
|
||||||
|
All engines support zero-copy GPU tensor inference where possible.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class BackendType(Enum):
|
||||||
|
"""Supported inference backend types"""
|
||||||
|
|
||||||
|
TENSORRT = "tensorrt"
|
||||||
|
ULTRALYTICS = "ultralytics"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, backend: str) -> "BackendType":
|
||||||
|
"""Convert string to BackendType"""
|
||||||
|
backend = backend.lower()
|
||||||
|
for member in cls:
|
||||||
|
if member.value == backend:
|
||||||
|
return member
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown backend: {backend}. Available: {[m.value for m in cls]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EngineMetadata:
|
||||||
|
"""Metadata for an inference engine"""
|
||||||
|
|
||||||
|
engine_type: str # "tensorrt", "ultralytics", etc.
|
||||||
|
model_path: str
|
||||||
|
input_shapes: Dict[str, Tuple[int, ...]]
|
||||||
|
output_shapes: Dict[str, Tuple[int, ...]]
|
||||||
|
input_names: List[str]
|
||||||
|
output_names: List[str]
|
||||||
|
input_dtypes: Dict[str, torch.dtype]
|
||||||
|
output_dtypes: Dict[str, torch.dtype]
|
||||||
|
supports_batching: bool = True
|
||||||
|
supports_dynamic_shapes: bool = False
|
||||||
|
extra_info: Dict[str, Any] = None # Backend-specific info
|
||||||
|
|
||||||
|
|
||||||
|
class IInferenceEngine(ABC):
|
||||||
|
"""
|
||||||
|
Abstract interface for inference engines.
|
||||||
|
|
||||||
|
All implementations must support zero-copy GPU tensor inference:
|
||||||
|
- Inputs: CUDA tensors on GPU
|
||||||
|
- Outputs: CUDA tensors on GPU
|
||||||
|
- No CPU transfers during inference
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def initialize(
|
||||||
|
self, model_path: str, device: torch.device, **kwargs
|
||||||
|
) -> EngineMetadata:
|
||||||
|
"""
|
||||||
|
Initialize the inference engine.
|
||||||
|
|
||||||
|
Automatically detects model type and handles conversion if needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to model file (.pt, .engine, .trt)
|
||||||
|
device: GPU device to use
|
||||||
|
**kwargs: Optional parameters (batch_size, half, workspace, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EngineMetadata with model information
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def infer(
|
||||||
|
self, inputs: Dict[str, torch.Tensor], **kwargs
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Run inference on GPU tensors (zero-copy).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Dict of input_name -> CUDA tensor
|
||||||
|
**kwargs: Backend-specific inference parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict of output_name -> CUDA tensor
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If inputs are not CUDA tensors or wrong shape
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_metadata(self) -> EngineMetadata:
|
||||||
|
"""Get engine metadata"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def is_initialized(self) -> bool:
|
||||||
|
"""Check if engine is initialized"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
"""Get device the engine is running on"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NativeTensorRTEngine(IInferenceEngine):
|
||||||
|
"""
|
||||||
|
Native TensorRT inference engine with direct API access.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Zero-copy GPU tensor inference
|
||||||
|
- Execution context pooling for concurrent inference
|
||||||
|
- Support for .trt, .engine files
|
||||||
|
- Automatic Ultralytics .engine metadata stripping
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._engine = None
|
||||||
|
self._contexts = []
|
||||||
|
self._metadata = None
|
||||||
|
self._device = None
|
||||||
|
self._trt_logger = None
|
||||||
|
|
||||||
|
def initialize(
|
||||||
|
self, model_path: str, device: torch.device, num_contexts: int = 1, **kwargs
|
||||||
|
) -> EngineMetadata:
|
||||||
|
"""
|
||||||
|
Initialize TensorRT engine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to .trt or .engine file
|
||||||
|
device: GPU device
|
||||||
|
num_contexts: Number of execution contexts for pooling
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EngineMetadata
|
||||||
|
"""
|
||||||
|
import tensorrt as trt
|
||||||
|
|
||||||
|
self._device = device
|
||||||
|
self._trt_logger = trt.Logger(trt.Logger.WARNING)
|
||||||
|
|
||||||
|
# Load engine
|
||||||
|
runtime = trt.Runtime(self._trt_logger)
|
||||||
|
|
||||||
|
# Read engine file (handle Ultralytics format)
|
||||||
|
engine_data = self._load_engine_data(model_path)
|
||||||
|
|
||||||
|
self._engine = runtime.deserialize_cuda_engine(engine_data)
|
||||||
|
if self._engine is None:
|
||||||
|
raise RuntimeError(f"Failed to load TensorRT engine from {model_path}")
|
||||||
|
|
||||||
|
# Create execution contexts
|
||||||
|
for i in range(num_contexts):
|
||||||
|
ctx = self._engine.create_execution_context()
|
||||||
|
if ctx is None:
|
||||||
|
raise RuntimeError(f"Failed to create execution context {i}")
|
||||||
|
self._contexts.append(ctx)
|
||||||
|
|
||||||
|
# Extract metadata
|
||||||
|
self._metadata = self._extract_metadata(model_path)
|
||||||
|
|
||||||
|
return self._metadata
|
||||||
|
|
||||||
|
def _load_engine_data(self, file_path: str) -> bytes:
|
||||||
|
"""Load engine data, stripping Ultralytics metadata if present"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
# Try to read Ultralytics metadata header
|
||||||
|
meta_len_bytes = f.read(4)
|
||||||
|
if len(meta_len_bytes) == 4:
|
||||||
|
meta_len = int.from_bytes(meta_len_bytes, byteorder="little")
|
||||||
|
|
||||||
|
# Sanity check
|
||||||
|
if 0 < meta_len < 100000:
|
||||||
|
try:
|
||||||
|
metadata_bytes = f.read(meta_len)
|
||||||
|
json.loads(metadata_bytes.decode("utf-8"))
|
||||||
|
# Valid Ultralytics metadata, rest is engine
|
||||||
|
return f.read()
|
||||||
|
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Not Ultralytics format, read entire file
|
||||||
|
f.seek(0)
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
def _extract_metadata(self, model_path: str) -> EngineMetadata:
|
||||||
|
"""Extract metadata from TensorRT engine"""
|
||||||
|
import tensorrt as trt
|
||||||
|
|
||||||
|
input_shapes = {}
|
||||||
|
output_shapes = {}
|
||||||
|
input_names = []
|
||||||
|
output_names = []
|
||||||
|
input_dtypes = {}
|
||||||
|
output_dtypes = {}
|
||||||
|
|
||||||
|
trt_to_torch_dtype = {
|
||||||
|
trt.DataType.FLOAT: torch.float32,
|
||||||
|
trt.DataType.HALF: torch.float16,
|
||||||
|
trt.DataType.INT8: torch.int8,
|
||||||
|
trt.DataType.INT32: torch.int32,
|
||||||
|
trt.DataType.BOOL: torch.bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in range(self._engine.num_io_tensors):
|
||||||
|
name = self._engine.get_tensor_name(i)
|
||||||
|
shape = tuple(self._engine.get_tensor_shape(name))
|
||||||
|
dtype = trt_to_torch_dtype.get(
|
||||||
|
self._engine.get_tensor_dtype(name), torch.float32
|
||||||
|
)
|
||||||
|
mode = self._engine.get_tensor_mode(name)
|
||||||
|
|
||||||
|
if mode == trt.TensorIOMode.INPUT:
|
||||||
|
input_names.append(name)
|
||||||
|
input_shapes[name] = shape
|
||||||
|
input_dtypes[name] = dtype
|
||||||
|
else:
|
||||||
|
output_names.append(name)
|
||||||
|
output_shapes[name] = shape
|
||||||
|
output_dtypes[name] = dtype
|
||||||
|
|
||||||
|
return EngineMetadata(
|
||||||
|
engine_type="tensorrt",
|
||||||
|
model_path=model_path,
|
||||||
|
input_shapes=input_shapes,
|
||||||
|
output_shapes=output_shapes,
|
||||||
|
input_names=input_names,
|
||||||
|
output_names=output_names,
|
||||||
|
input_dtypes=input_dtypes,
|
||||||
|
output_dtypes=output_dtypes,
|
||||||
|
supports_batching=True,
|
||||||
|
supports_dynamic_shapes=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def infer(
|
||||||
|
self,
|
||||||
|
inputs: Dict[str, torch.Tensor],
|
||||||
|
context_id: int = 0,
|
||||||
|
stream: Optional[torch.cuda.Stream] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Run TensorRT inference with zero-copy GPU tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Dict of input_name -> CUDA tensor
|
||||||
|
context_id: Which execution context to use
|
||||||
|
stream: CUDA stream for async execution
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict of output_name -> CUDA tensor
|
||||||
|
"""
|
||||||
|
if not self.is_initialized:
|
||||||
|
raise RuntimeError("Engine not initialized")
|
||||||
|
|
||||||
|
# Validate inputs
|
||||||
|
for name in self._metadata.input_names:
|
||||||
|
if name not in inputs:
|
||||||
|
raise ValueError(f"Missing required input: {name}")
|
||||||
|
if not inputs[name].is_cuda:
|
||||||
|
raise ValueError(f"Input '{name}' must be a CUDA tensor")
|
||||||
|
|
||||||
|
# Get execution context
|
||||||
|
if context_id >= len(self._contexts):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid context_id {context_id}, only {len(self._contexts)} contexts available"
|
||||||
|
)
|
||||||
|
|
||||||
|
context = self._contexts[context_id]
|
||||||
|
|
||||||
|
# Prepare outputs
|
||||||
|
outputs = {}
|
||||||
|
|
||||||
|
# Set input tensor addresses
|
||||||
|
for name in self._metadata.input_names:
|
||||||
|
input_tensor = inputs[name].contiguous()
|
||||||
|
context.set_tensor_address(name, input_tensor.data_ptr())
|
||||||
|
|
||||||
|
# Allocate and set output tensors
|
||||||
|
for name in self._metadata.output_names:
|
||||||
|
output_tensor = torch.empty(
|
||||||
|
self._metadata.output_shapes[name],
|
||||||
|
dtype=self._metadata.output_dtypes[name],
|
||||||
|
device=self._device,
|
||||||
|
)
|
||||||
|
outputs[name] = output_tensor
|
||||||
|
context.set_tensor_address(name, output_tensor.data_ptr())
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
if stream is None:
|
||||||
|
stream = torch.cuda.Stream(device=self._device)
|
||||||
|
|
||||||
|
with torch.cuda.stream(stream):
|
||||||
|
success = context.execute_async_v3(stream_handle=stream.cuda_stream)
|
||||||
|
if not success:
|
||||||
|
raise RuntimeError("TensorRT inference failed")
|
||||||
|
|
||||||
|
stream.synchronize()
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def get_metadata(self) -> EngineMetadata:
|
||||||
|
"""Get engine metadata"""
|
||||||
|
if self._metadata is None:
|
||||||
|
raise RuntimeError("Engine not initialized")
|
||||||
|
return self._metadata
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup TensorRT resources"""
|
||||||
|
for ctx in self._contexts:
|
||||||
|
del ctx
|
||||||
|
self._contexts.clear()
|
||||||
|
|
||||||
|
if self._engine is not None:
|
||||||
|
del self._engine
|
||||||
|
self._engine = None
|
||||||
|
|
||||||
|
self._metadata = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_initialized(self) -> bool:
|
||||||
|
return self._engine is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return self._device
|
||||||
|
|
||||||
|
|
||||||
|
class UltralyticsEngine(IInferenceEngine):
|
||||||
|
"""
|
||||||
|
Ultralytics YOLO inference engine.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Zero-copy GPU tensor inference
|
||||||
|
- Built-in preprocessing/postprocessing for YOLO models
|
||||||
|
- Supports .pt, .engine formats
|
||||||
|
- Automatic model export to TensorRT with caching
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._model = None
|
||||||
|
self._metadata = None
|
||||||
|
self._device = None
|
||||||
|
self._model_path = None
|
||||||
|
self._exporter = None
|
||||||
|
|
||||||
|
def initialize(
|
||||||
|
self,
|
||||||
|
model_path: str,
|
||||||
|
device: torch.device,
|
||||||
|
batch: int = 1,
|
||||||
|
half: bool = False,
|
||||||
|
imgsz: int = 640,
|
||||||
|
cache_dir: str = ".ultralytics_cache",
|
||||||
|
**kwargs,
|
||||||
|
) -> EngineMetadata:
|
||||||
|
"""
|
||||||
|
Initialize Ultralytics YOLO model.
|
||||||
|
|
||||||
|
Automatically exports .pt models to .engine format with caching.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to .pt or .engine file
|
||||||
|
device: GPU device
|
||||||
|
batch: Maximum batch size for inference
|
||||||
|
half: Use FP16 precision
|
||||||
|
imgsz: Input image size
|
||||||
|
cache_dir: Directory for caching exported engines
|
||||||
|
**kwargs: Additional export parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EngineMetadata
|
||||||
|
"""
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
from .ultralytics_exporter import UltralyticsExporter
|
||||||
|
|
||||||
|
self._device = device
|
||||||
|
self._model_path = model_path
|
||||||
|
|
||||||
|
# Check if we need to export
|
||||||
|
model_file = Path(model_path)
|
||||||
|
final_model_path = model_path
|
||||||
|
|
||||||
|
if model_file.suffix == ".pt":
|
||||||
|
# Use exporter with caching
|
||||||
|
print(f"Checking for cached TensorRT engine...")
|
||||||
|
self._exporter = UltralyticsExporter(cache_dir=cache_dir)
|
||||||
|
|
||||||
|
_, engine_path = self._exporter.export(
|
||||||
|
model_path=str(model_path),
|
||||||
|
device=device.index if device.type == "cuda" else 0,
|
||||||
|
half=half,
|
||||||
|
imgsz=imgsz,
|
||||||
|
batch=batch,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
final_model_path = engine_path
|
||||||
|
print(f"Using TensorRT engine: {engine_path}")
|
||||||
|
|
||||||
|
# Load model (Ultralytics handles .engine files natively)
|
||||||
|
self._model = YOLO(final_model_path)
|
||||||
|
|
||||||
|
# Move to device if needed (only for .pt models, .engine already on specific device)
|
||||||
|
if hasattr(self._model, "model") and self._model.model is not None:
|
||||||
|
# Check if it's actually a torch model (not a string path for .engine files)
|
||||||
|
if hasattr(self._model.model, "to"):
|
||||||
|
self._model.model = self._model.model.to(device)
|
||||||
|
|
||||||
|
# Extract metadata
|
||||||
|
self._metadata = self._extract_metadata()
|
||||||
|
|
||||||
|
return self._metadata
|
||||||
|
|
||||||
|
def _extract_metadata(self) -> EngineMetadata:
|
||||||
|
"""Extract metadata from Ultralytics model"""
|
||||||
|
# Ultralytics models typically expect (B, 3, H, W) input
|
||||||
|
# and return Results objects, not raw tensors
|
||||||
|
|
||||||
|
# Default values
|
||||||
|
batch_size = -1 # Dynamic batching by default
|
||||||
|
imgsz = 640
|
||||||
|
input_shape = (batch_size, 3, imgsz, imgsz)
|
||||||
|
|
||||||
|
if hasattr(self._model, "model") and self._model.model is not None:
|
||||||
|
# Try to get actual input shape from model
|
||||||
|
try:
|
||||||
|
# For .engine files, check predictor model
|
||||||
|
if (
|
||||||
|
hasattr(self._model, "predictor")
|
||||||
|
and self._model.predictor is not None
|
||||||
|
):
|
||||||
|
predictor = self._model.predictor
|
||||||
|
|
||||||
|
# Get image size
|
||||||
|
if hasattr(predictor, "args") and hasattr(predictor.args, "imgsz"):
|
||||||
|
imgsz_val = predictor.args.imgsz
|
||||||
|
if isinstance(imgsz_val, (list, tuple)):
|
||||||
|
h, w = (
|
||||||
|
imgsz_val[0],
|
||||||
|
imgsz_val[1] if len(imgsz_val) > 1 else imgsz_val[0],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
h = w = imgsz_val
|
||||||
|
imgsz = h # Use height as reference
|
||||||
|
|
||||||
|
# Get batch size from model
|
||||||
|
if hasattr(predictor, "model"):
|
||||||
|
pred_model = predictor.model
|
||||||
|
|
||||||
|
# For TensorRT engines, check input bindings
|
||||||
|
if hasattr(pred_model, "bindings"):
|
||||||
|
# This is a TensorRT AutoBackend
|
||||||
|
try:
|
||||||
|
# Get first input binding shape
|
||||||
|
if hasattr(pred_model, "input_shape"):
|
||||||
|
shape = pred_model.input_shape
|
||||||
|
if shape and len(shape) >= 4:
|
||||||
|
batch_size = shape[0] if shape[0] > 0 else -1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Try batch attribute
|
||||||
|
if batch_size == -1 and hasattr(pred_model, "batch"):
|
||||||
|
batch_size = (
|
||||||
|
pred_model.batch if pred_model.batch > 0 else -1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fallback: check model args
|
||||||
|
if hasattr(self._model.model, "args"):
|
||||||
|
imgsz_val = getattr(self._model.model.args, "imgsz", 640)
|
||||||
|
if isinstance(imgsz_val, (list, tuple)):
|
||||||
|
h, w = (
|
||||||
|
imgsz_val[0],
|
||||||
|
imgsz_val[1] if len(imgsz_val) > 1 else imgsz_val[0],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
h = w = imgsz_val
|
||||||
|
imgsz = h
|
||||||
|
|
||||||
|
input_shape = (batch_size, 3, imgsz, imgsz)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not extract full metadata: {e}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
return EngineMetadata(
|
||||||
|
engine_type="ultralytics",
|
||||||
|
model_path=self._model_path,
|
||||||
|
input_shapes={"images": input_shape},
|
||||||
|
output_shapes={"results": (-1,)}, # Dynamic, depends on detections
|
||||||
|
input_names=["images"],
|
||||||
|
output_names=["results"],
|
||||||
|
input_dtypes={"images": torch.float32},
|
            output_dtypes={"results": torch.float32},
            supports_batching=True,
            supports_dynamic_shapes=(batch_size == -1),
            extra_info={
                "is_yolo": True,
                "has_builtin_postprocess": True,
                "batch_size": batch_size,
                "imgsz": imgsz,
            },
        )

    def infer(
        self,
        inputs: Dict[str, torch.Tensor],
        return_raw: bool = False,
        conf: float = 0.25,
        iou: float = 0.45,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Run Ultralytics inference with zero-copy GPU tensors.

        Args:
            inputs: Dict with "images" key -> CUDA tensor (B, 3, H, W), normalized [0, 1]
            return_raw: If True, return raw tensor output. If False, return Results objects
            conf: Confidence threshold
            iou: IoU threshold for NMS

        Returns:
            Dict with inference results

        Note:
            Input tensor should be normalized to [0, 1] range.
            Format: (B, 3, H, W) in RGB color space.
        """
        if not self.is_initialized:
            raise RuntimeError("Engine not initialized")

        # Get input tensor
        if "images" not in inputs:
            raise ValueError("Input must contain 'images' key")

        images = inputs["images"]

        if not images.is_cuda:
            raise ValueError("Input must be a CUDA tensor")

        # Ensure tensor is on correct device
        if images.device != self._device:
            images = images.to(self._device)

        # Run inference
        results = self._model(images, conf=conf, iou=iou, verbose=False, **kwargs)

        # Return results
        # Note: Ultralytics returns Results objects, not raw tensors
        # For compatibility, we wrap them in a dict
        return {
            "results": results,
            "raw_predictions": results[0].boxes.data
            if len(results) > 0 and hasattr(results[0], "boxes")
            else None,
        }

    def get_metadata(self) -> EngineMetadata:
        """Get engine metadata"""
        if self._metadata is None:
            raise RuntimeError("Engine not initialized")
        return self._metadata

    def cleanup(self):
        """Cleanup Ultralytics model"""
        if self._model is not None:
            del self._model
            self._model = None
        self._metadata = None

    @property
    def is_initialized(self) -> bool:
        return self._model is not None

    @property
    def device(self) -> torch.device:
        return self._device


def create_engine(backend: str | BackendType, **kwargs) -> IInferenceEngine:
    """
    Factory function to create inference engine.

    Args:
        backend: Backend type (BackendType enum or string: "tensorrt", "ultralytics")
        **kwargs: Engine-specific arguments

    Returns:
        IInferenceEngine instance

    Example:
        >>> from services import create_engine, BackendType
        >>> engine = create_engine(BackendType.TENSORRT)
        >>> engine = create_engine("ultralytics")
    """
    # Convert string to BackendType if needed
    if isinstance(backend, str):
        backend = BackendType.from_string(backend)

    engines = {
        BackendType.TENSORRT: NativeTensorRTEngine,
        BackendType.ULTRALYTICS: UltralyticsEngine,
    }

    if backend not in engines:
        raise ValueError(
            f"Unknown backend: {backend}. Available: {[b.value for b in BackendType]}"
        )

    return engines[backend]()
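For orientation, a minimal, hypothetical usage sketch of the factory and the zero-copy infer() path defined above; the model path, input shape, and device index are placeholders rather than values from this commit.

# Hypothetical usage sketch; model path and input shape are placeholders
import torch
from services import create_engine

engine = create_engine("ultralytics")
engine.initialize(model_path="models/yolov8n.pt", device=torch.device("cuda:0"))

# (B, 3, H, W) RGB tensor on the GPU, normalized to [0, 1] as the docstring requires
images = torch.rand(1, 3, 640, 640, device="cuda:0")
out = engine.infer({"images": images}, conf=0.25, iou=0.45)
print(out["results"])
engine.cleanup()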
@@ -5,21 +5,22 @@ This module provides batched inference coordination using ping-pong circular buffers
with force-switch timeout mechanism using threading and callbacks.
"""

-import threading
-import torch
-from typing import Dict, List, Optional, Callable, Any
-from dataclasses import dataclass, field
-from enum import Enum
-import time
import logging
import queue
+import threading
+import time
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional
+
+import torch

logger = logging.getLogger(__name__)


@dataclass
class BatchFrame:
    """Represents a frame in the batch buffer"""
+
    stream_id: str
    frame: torch.Tensor  # GPU tensor (3, H, W)
    timestamp: float

@@ -28,6 +29,7 @@ class BatchFrame:

class BufferState(Enum):
    """State of a ping-pong buffer"""
+
    IDLE = "idle"
    FILLING = "filling"
    PROCESSING = "processing"

@@ -80,7 +82,9 @@ class ModelController:
                f"Will process frames sequentially. Consider rebuilding model with dynamic batching."
            )
        else:
-            logger.info(f"Model '{model_id}' supports batch_size={self.model_batch_size}")
+            logger.info(
+                f"Model '{model_id}' supports batch_size={self.model_batch_size}"
+            )

        # Ping-pong buffers
        self.buffer_a: List[BatchFrame] = []

@@ -130,7 +134,9 @@ class ModelController:
                # Fixed batch size
                return batch_dim
        except Exception as e:
-            logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
+            logger.warning(
+                f"Could not detect model batch size: {e}. Assuming batch_size=1"
+            )
            return 1

    def start(self):

@@ -143,14 +149,20 @@ class ModelController:
        self.stop_event.clear()

        # Start timeout monitor thread
-        self.timeout_thread = threading.Thread(target=self._timeout_monitor, daemon=True)
+        self.timeout_thread = threading.Thread(
+            target=self._timeout_monitor, daemon=True
+        )
        self.timeout_thread.start()

        # Start processor threads for each buffer
-        self.processor_threads['A'] = threading.Thread(target=self._batch_processor, args=('A',), daemon=True)
-        self.processor_threads['B'] = threading.Thread(target=self._batch_processor, args=('B',), daemon=True)
-        self.processor_threads['A'].start()
-        self.processor_threads['B'].start()
+        self.processor_threads["A"] = threading.Thread(
+            target=self._batch_processor, args=("A",), daemon=True
+        )
+        self.processor_threads["B"] = threading.Thread(
+            target=self._batch_processor, args=("B",), daemon=True
+        )
+        self.processor_threads["A"].start()
+        self.processor_threads["B"].start()

        logger.info("ModelController started")

@@ -197,10 +209,7 @@ class ModelController:
        logger.debug(f"Unregistered callback for stream: {stream_id}")

    def submit_frame(
-        self,
-        stream_id: str,
-        frame: torch.Tensor,
-        metadata: Optional[Dict] = None
+        self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
    ):
        """
        Submit a frame for batched inference.

@@ -215,7 +224,7 @@ class ModelController:
            stream_id=stream_id,
            frame=frame,
            timestamp=time.time(),
-            metadata=metadata or {}
+            metadata=metadata or {},
        )

        # Add to active buffer

@@ -242,7 +251,9 @@ class ModelController:

            # Check if timeout expired and we have frames waiting
            if time_since_submit >= self.force_timeout:
-                active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
+                active_buffer = (
+                    self.buffer_a if self.active_buffer == "A" else self.buffer_b
+                )
                if len(active_buffer) > 0:
                    self._try_swap_buffers()

@@ -254,7 +265,9 @@ class ModelController:
        This method should be called with buffer_lock held.
        """
        # Check if inactive buffer is available
-        inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
+        inactive_state = (
+            self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
+        )

        if inactive_state != BufferState.PROCESSING:
            # Swap active buffer

@@ -269,7 +282,9 @@ class ModelController:
                self.buffer_b_state = BufferState.PROCESSING
                buffer_size = len(self.buffer_b)

-            logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")
+            logger.debug(
+                f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
+            )

    def _batch_processor(self, buffer_name: str):
        """Background thread that processes a specific buffer when available"""

@@ -322,8 +337,8 @@ class ModelController:
            self.total_batches_processed += 1

            logger.debug(
-                f"Processed batch of {len(batch)} frames in {inference_time*1000:.2f}ms "
-                f"({inference_time*1000/len(batch):.2f}ms per frame)"
+                f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms "
+                f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
            )

            # Emit results to callbacks

@@ -334,7 +349,10 @@ class ModelController:
                    try:
                        callback(result)
                    except Exception as e:
-                        logger.error(f"Error in callback for {batch_frame.stream_id}: {e}", exc_info=True)
+                        logger.error(
+                            f"Error in callback for {batch_frame.stream_id}: {e}",
+                            exc_info=True,
+                        )

        except Exception as e:
            logger.error(f"Error processing batch: {e}", exc_info=True)

@@ -365,7 +383,9 @@ class ModelController:
            # Use true batching for models that support it
            return self._run_batched_inference(batch)

-    def _run_sequential_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
+    def _run_sequential_inference(
+        self, batch: List[BatchFrame]
+    ) -> List[Dict[str, Any]]:
        """Run inference sequentially for batch_size=1 models"""
        results = []

@@ -375,13 +395,15 @@ class ModelController:
                processed = self.preprocess_fn(batch_frame.frame)
            else:
                # Ensure we have batch dimension
-                processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame
+                processed = (
+                    batch_frame.frame.unsqueeze(0)
+                    if batch_frame.frame.dim() == 3
+                    else batch_frame.frame
+                )

            # Run inference for this frame
            outputs = self.model_repository.infer(
-                self.model_id,
-                {"images": processed},
-                synchronize=True
+                self.model_id, {"images": processed}, synchronize=True
            )

            # Postprocess

@@ -389,9 +411,13 @@ class ModelController:
                try:
                    detections = self.postprocess_fn(outputs)
                except Exception as e:
-                    logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
+                    logger.error(
+                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
+                    )
                    # Return empty detections on error
-                    detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
+                    detections = torch.zeros(
+                        (0, 6), device=list(outputs.values())[0].device
+                    )
            else:
                detections = outputs

@@ -429,32 +455,37 @@ class ModelController:
                f"will split into sub-batches"
            )
            # TODO: Handle splitting into sub-batches
-            batch_tensor = batch_tensor[:self.model_batch_size]
-            batch = batch[:self.model_batch_size]
+            batch_tensor = batch_tensor[: self.model_batch_size]
+            batch = batch[: self.model_batch_size]

        # Run inference
        outputs = self.model_repository.infer(
-            self.model_id,
-            {"images": batch_tensor},
-            synchronize=True
+            self.model_id, {"images": batch_tensor}, synchronize=True
        )

        # Postprocess results (split batch back to individual results)
        results = []
        for i, batch_frame in enumerate(batch):
-            # Extract single frame output from batch
+            # Extract single frame output from batch and clone to ensure memory safety
+            # This prevents potential race conditions if the output tensors are still
+            # in use when the next inference batch is processed
            frame_output = {}
            for k, v in outputs.items():
                # v has shape (N, ...), extract index i and keep batch dimension
-                frame_output[k] = v[i:i+1]  # Shape: (1, ...)
+                # Clone to decouple from shared batch output tensor
+                frame_output[k] = v[i : i + 1].clone()  # Shape: (1, ...)

            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(frame_output)
                except Exception as e:
-                    logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
+                    logger.error(
+                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
+                    )
                    # Return empty detections on error
-                    detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
+                    detections = torch.zeros(
+                        (0, 6), device=list(outputs.values())[0].device
+                    )
            else:
                detections = frame_output

@@ -490,6 +521,8 @@ class ModelController:
            "total_batches_processed": self.total_batches_processed,
            "avg_batch_size": (
                self.total_frames_processed / self.total_batches_processed
-                if self.total_batches_processed > 0 else 0
+                if self.total_batches_processed > 0
+                else 0
            ),
        }
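As a reminder of the ping-pong pattern this controller's docstring describes (fill one buffer while the other is being processed, then swap under a lock), here is a stripped-down sketch with illustrative names; it is independent of the ModelController API above.

# Illustrative ping-pong buffer sketch; not the ModelController implementation itself
import threading

class PingPongBuffers:
    def __init__(self):
        self.buffers = {"A": [], "B": []}
        self.active = "A"
        self.lock = threading.Lock()

    def submit(self, item):
        with self.lock:
            self.buffers[self.active].append(item)

    def swap_and_take(self):
        """Swap the active buffer and return the filled one for processing."""
        with self.lock:
            filled = self.active
            self.active = "B" if filled == "A" else "A"
            batch, self.buffers[filled] = self.buffers[filled], []
            return batch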
@@ -1,13 +1,14 @@
-import threading
import hashlib
import json
-from typing import Optional, Dict, Any, List, Tuple
+import logging
+import threading
+from dataclasses import dataclass
from pathlib import Path
from queue import Queue
-import torch
+from typing import Any, Dict, List, Optional, Tuple
+
import tensorrt as trt
-from dataclasses import dataclass
-import logging
+import torch

logger = logging.getLogger(__name__)

@@ -15,6 +16,7 @@ logger = logging.getLogger(__name__)
@dataclass
class ModelMetadata:
    """Metadata for a loaded TensorRT model"""
+
    file_path: str
    file_hash: str
    input_shapes: Dict[str, Tuple[int, ...]]

@@ -30,8 +32,14 @@ class ExecutionContext:
    Wrapper for TensorRT execution context with CUDA stream.
    Used in context pool for load balancing.
    """
-    def __init__(self, context: trt.IExecutionContext, stream: torch.cuda.Stream,
-                 context_id: int, device: torch.device):
+
+    def __init__(
+        self,
+        context: trt.IExecutionContext,
+        stream: torch.cuda.Stream,
+        context_id: int,
+        device: torch.device,
+    ):
        self.context = context
        self.stream = stream
        self.context_id = context_id

@@ -53,8 +61,16 @@ class SharedEngine:
    - Contexts are borrowed/returned using mutex locks
    - Load balancing: contexts distributed across requests
    """
-    def __init__(self, engine: trt.ICudaEngine, file_hash: str, file_path: str,
-                 num_contexts: int, device: torch.device, metadata: ModelMetadata):
+
+    def __init__(
+        self,
+        engine: trt.ICudaEngine,
+        file_hash: str,
+        file_path: str,
+        num_contexts: int,
+        device: torch.device,
+        metadata: ModelMetadata,
+    ):
        self.engine = engine
        self.file_hash = file_hash
        self.file_path = file_path

@@ -80,9 +96,13 @@ class SharedEngine:
        self.model_ids: set = set()
        self.lock = threading.Lock()

-        print(f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}...")
+        print(
+            f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}..."
+        )

-    def acquire_context(self, timeout: Optional[float] = None) -> Optional[ExecutionContext]:
+    def acquire_context(
+        self, timeout: Optional[float] = None
+    ) -> Optional[ExecutionContext]:
        """
        Acquire an available execution context from the pool.
        Blocks if all contexts are in use.

@@ -162,7 +182,13 @@ class TensorRTModelRepository:
    # Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
    """

-    def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True, cache_dir: str = ".trt_cache"):
+    def __init__(
+        self,
+        gpu_id: int = 0,
+        default_num_contexts: int = 4,
+        enable_pt_conversion: bool = True,
+        cache_dir: str = ".trt_cache",
+    ):
        """
        Initialize the model repository.

@@ -173,7 +199,7 @@ class TensorRTModelRepository:
            cache_dir: Directory for caching stripped TensorRT engines and metadata
        """
        self.gpu_id = gpu_id
-        self.device = torch.device(f'cuda:{gpu_id}')
+        self.device = torch.device(f"cuda:{gpu_id}")
        self.default_num_contexts = default_num_contexts
        self.enable_pt_conversion = enable_pt_conversion
        self.cache_dir = Path(cache_dir)

@@ -195,7 +221,9 @@ class TensorRTModelRepository:
        self._pt_converter = None

        print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
-        print(f"Default context pool size: {default_num_contexts} contexts per unique model")
+        print(
+            f"Default context pool size: {default_num_contexts} contexts per unique model"
+        )
        print(f"Cache directory: {self.cache_dir}")
        if enable_pt_conversion:
            print(f"PyTorch to TensorRT conversion: enabled")

@@ -205,6 +233,7 @@ class TensorRTModelRepository:
        """Lazy initialization of PT converter"""
        if self._pt_converter is None and self.enable_pt_conversion:
            from .pt_converter import PTConverter
+
            self._pt_converter = PTConverter(gpu_id=self.gpu_id)
            logger.info("PT converter initialized")
        return self._pt_converter

@@ -255,11 +284,11 @@ class TensorRTModelRepository:
        # Check if stripped engine already cached
        if cache_engine_path.exists():
            logger.info(f"Loading cached stripped engine from {cache_engine_path}")
-            with open(cache_engine_path, 'rb') as f:
+            with open(cache_engine_path, "rb") as f:
                engine_data = f.read()
        else:
            # Read and process original file
-            with open(file_path, 'rb') as f:
+            with open(file_path, "rb") as f:
                # Try to read Ultralytics metadata header (first 4 bytes = metadata length)
                try:
                    meta_len_bytes = f.read(4)

@@ -278,13 +307,15 @@ class TensorRTModelRepository:
                        # Save stripped engine to cache
                        logger.info(f"Detected Ultralytics engine format")
                        logger.info(f"Ultralytics metadata: {metadata}")
-                        logger.info(f"Caching stripped engine to {cache_engine_path}")
+                        logger.info(
+                            f"Caching stripped engine to {cache_engine_path}"
+                        )

-                        with open(cache_engine_path, 'wb') as cache_f:
+                        with open(cache_engine_path, "wb") as cache_f:
                            cache_f.write(engine_data)

                        # Save metadata separately
-                        with open(cache_metadata_path, 'w') as meta_f:
+                        with open(cache_metadata_path, "w") as meta_f:
                            json.dump(metadata, meta_f, indent=2)

                except (UnicodeDecodeError, json.JSONDecodeError):

@@ -301,13 +332,15 @@ class TensorRTModelRepository:

                except Exception as e:
                    # Any error, rewind and read entire file
-                    logger.warning(f"Error reading engine metadata: {e}, treating as raw TRT engine")
+                    logger.warning(
+                        f"Error reading engine metadata: {e}, treating as raw TRT engine"
+                    )
                    f.seek(0)
                    engine_data = f.read()

            # Cache the engine data (even if it was already raw TRT)
            if not cache_engine_path.exists():
-                with open(cache_engine_path, 'wb') as cache_f:
+                with open(cache_engine_path, "wb") as cache_f:
                    cache_f.write(engine_data)

        engine = runtime.deserialize_cuda_engine(engine_data)

@@ -316,8 +349,9 @@ class TensorRTModelRepository:

        return engine

-    def _extract_metadata(self, engine: trt.ICudaEngine,
-                          file_path: str, file_hash: str) -> ModelMetadata:
+    def _extract_metadata(
+        self, engine: trt.ICudaEngine, file_path: str, file_hash: str
+    ) -> ModelMetadata:
        """
        Extract metadata from TensorRT engine.

@@ -369,15 +403,19 @@ class TensorRTModelRepository:
            input_names=input_names,
            output_names=output_names,
            input_dtypes=input_dtypes,
-            output_dtypes=output_dtypes
+            output_dtypes=output_dtypes,
        )

-    def load_model(self, model_id: str, file_path: str,
-                   num_contexts: Optional[int] = None,
-                   force_reload: bool = False,
-                   pt_input_shapes: Optional[Dict[str, Tuple]] = None,
-                   pt_precision: Optional[torch.dtype] = None,
-                   **pt_conversion_kwargs) -> ModelMetadata:
+    def load_model(
+        self,
+        model_id: str,
+        file_path: str,
+        num_contexts: Optional[int] = None,
+        force_reload: bool = False,
+        pt_input_shapes: Optional[Dict[str, Tuple]] = None,
+        pt_precision: Optional[torch.dtype] = None,
+        **pt_conversion_kwargs,
+    ) -> ModelMetadata:
        """
        Load a TensorRT model with the given ID.

@@ -410,7 +448,7 @@ class TensorRTModelRepository:

        # Check if file is PyTorch model
        file_ext = Path(file_path).suffix.lower()
-        if file_ext in ['.pt', '.pth']:
+        if file_ext in [".pt", ".pth"]:
            if not self.enable_pt_conversion:
                raise ValueError(
                    f"PT file provided but PT conversion is disabled. "

@@ -425,7 +463,7 @@ class TensorRTModelRepository:
                file_path,
                input_shapes=pt_input_shapes,
                precision=pt_precision,
-                **pt_conversion_kwargs
+                **pt_conversion_kwargs,
            )

            # Update file_path to use converted TRT file

@@ -455,8 +493,12 @@ class TensorRTModelRepository:
            # Check if this file is already loaded (deduplication)
            if file_hash in self._shared_engines:
                shared_engine = self._shared_engines[file_hash]
-                print(f"Engine already loaded (hash match), reusing engine and context pool...")
-                print(f"  Existing model_ids using this engine: {shared_engine.model_ids}")
+                print(
+                    f"Engine already loaded (hash match), reusing engine and context pool..."
+                )
+                print(
+                    f"  Existing model_ids using this engine: {shared_engine.model_ids}"
+                )
            else:
                # Load new engine
                print(f"Loading TensorRT engine from {file_path}...")

@@ -472,7 +514,7 @@ class TensorRTModelRepository:
                    file_path=file_path,
                    num_contexts=num_contexts,
                    device=self.device,
-                    metadata=metadata
+                    metadata=metadata,
                )
                self._shared_engines[file_hash] = shared_engine

@@ -485,18 +527,29 @@ class TensorRTModelRepository:
        print(f"Model '{model_id}' loaded successfully")
        print(f"  Inputs: {shared_engine.metadata.input_names}")
        for name in shared_engine.metadata.input_names:
-            print(f"    {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})")
+            print(
+                f"    {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})"
+            )
        print(f"  Outputs: {shared_engine.metadata.output_names}")
        for name in shared_engine.metadata.output_names:
-            print(f"    {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})")
+            print(
+                f"    {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})"
+            )
        print(f"  Context pool size: {num_contexts}")
-        print(f"  Model IDs sharing this engine: {shared_engine.get_reference_count()}")
+        print(
+            f"  Model IDs sharing this engine: {shared_engine.get_reference_count()}"
+        )
        print(f"  Unique engines in VRAM: {len(self._shared_engines)}")

        return shared_engine.metadata

-    def infer(self, model_id: str, inputs: Dict[str, torch.Tensor],
-              synchronize: bool = True, timeout: Optional[float] = 5.0) -> Dict[str, torch.Tensor]:
+    def infer(
+        self,
+        model_id: str,
+        inputs: Dict[str, torch.Tensor],
+        synchronize: bool = True,
+        timeout: Optional[float] = 5.0,
+    ) -> Dict[str, torch.Tensor]:
        """
        Run GPU-to-GPU inference with the specified model using context pooling.

@@ -519,7 +572,9 @@ class TensorRTModelRepository:
        """
        # Get shared engine
        if model_id not in self._model_to_hash:
-            raise KeyError(f"Model '{model_id}' not found. Available: {list(self._model_to_hash.keys())}")
+            raise KeyError(
+                f"Model '{model_id}' not found. Available: {list(self._model_to_hash.keys())}"
+            )

        file_hash = self._model_to_hash[model_id]
        shared_engine = self._shared_engines[file_hash]

@@ -536,7 +591,9 @@ class TensorRTModelRepository:

            # Check device
            if tensor.device != self.device:
-                print(f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}")
+                print(
+                    f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}"
+                )
                inputs[name] = tensor.to(self.device)

        # Acquire context from pool (mutex-based)

@@ -562,9 +619,7 @@ class TensorRTModelRepository:
                output_dtype = metadata.output_dtypes[name]

                output_tensor = torch.empty(
-                    output_shape,
-                    dtype=output_dtype,
-                    device=self.device
+                    output_shape, dtype=output_dtype, device=self.device
                )

                # NOTE: Don't track these tensors - they're returned to caller and consumed

@@ -584,9 +639,23 @@ class TensorRTModelRepository:
            if not success:
                raise RuntimeError(f"Inference failed for model '{model_id}'")

-            # Synchronize if requested
-            if synchronize:
-                exec_ctx.stream.synchronize()
+            # CRITICAL: Always synchronize before releasing context
+            # Even if caller requested async execution, we MUST sync before
+            # releasing the context to prevent race conditions where the next
+            # inference using this context overwrites tensor addresses while
+            # the current batch is still being processed.
+            exec_ctx.stream.synchronize()
+
+            # Clone outputs to new tensors to ensure memory safety
+            # This prevents race conditions where the next batch using this context
+            # could overwrite the output tensor addresses before the caller
+            # finishes processing these results.
+            if not synchronize:
+                # For async mode, clone to decouple from context
+                cloned_outputs = {}
+                for name, tensor in outputs.items():
+                    cloned_outputs[name] = tensor.clone()
+                outputs = cloned_outputs

            return outputs

@@ -594,8 +663,12 @@ class TensorRTModelRepository:
            # Always release context back to pool
            shared_engine.release_context(exec_ctx)

-    def infer_batch(self, model_id: str, batch_inputs: List[Dict[str, torch.Tensor]],
-                    synchronize: bool = True) -> List[Dict[str, torch.Tensor]]:
+    def infer_batch(
+        self,
+        model_id: str,
+        batch_inputs: List[Dict[str, torch.Tensor]],
+        synchronize: bool = True,
+    ) -> List[Dict[str, torch.Tensor]]:
        """
        Run inference on multiple inputs.
        Contexts are borrowed/returned for each input, enabling parallel processing.

@@ -641,9 +714,13 @@ class TensorRTModelRepository:
            if remaining_refs == 0:
                shared_engine.cleanup()
                del self._shared_engines[file_hash]
-                print(f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)")
+                print(
+                    f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)"
+                )
            else:
-                print(f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)")
+                print(
+                    f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)"
+                )

            # Remove from model_id mapping
            del self._model_to_hash[model_id]

@@ -702,26 +779,26 @@ class TensorRTModelRepository:
            metadata = shared_engine.metadata

            return {
-                'model_id': model_id,
-                'file_path': metadata.file_path,
-                'file_hash': metadata.file_hash[:16] + '...',
-                'engine_references': shared_engine.get_reference_count(),
-                'context_pool_size': shared_engine.num_contexts,
-                'shared_with_model_ids': list(shared_engine.model_ids),
-                'inputs': {
+                "model_id": model_id,
+                "file_path": metadata.file_path,
+                "file_hash": metadata.file_hash[:16] + "...",
+                "engine_references": shared_engine.get_reference_count(),
+                "context_pool_size": shared_engine.num_contexts,
+                "shared_with_model_ids": list(shared_engine.model_ids),
+                "inputs": {
                    name: {
-                        'shape': metadata.input_shapes[name],
-                        'dtype': str(metadata.input_dtypes[name])
+                        "shape": metadata.input_shapes[name],
+                        "dtype": str(metadata.input_dtypes[name]),
                    }
                    for name in metadata.input_names
                },
-                'outputs': {
+                "outputs": {
                    name: {
-                        'shape': metadata.output_shapes[name],
-                        'dtype': str(metadata.output_dtypes[name])
+                        "shape": metadata.output_shapes[name],
+                        "dtype": str(metadata.output_dtypes[name]),
                    }
                    for name in metadata.output_names
-                }
+                },
            }

    def get_stats(self) -> Dict[str, Any]:

@@ -733,24 +810,25 @@ class TensorRTModelRepository:
        """
        with self._repo_lock:
            total_contexts = sum(
-                engine.num_contexts
-                for engine in self._shared_engines.values()
+                engine.num_contexts for engine in self._shared_engines.values()
            )

            return {
-                'total_model_ids': len(self._model_to_hash),
-                'unique_engines': len(self._shared_engines),
-                'total_contexts': total_contexts,
-                'memory_efficiency': f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines",
-                'gpu_id': self.gpu_id,
-                'models': list(self._model_to_hash.keys())
+                "total_model_ids": len(self._model_to_hash),
+                "unique_engines": len(self._shared_engines),
+                "total_contexts": total_contexts,
+                "memory_efficiency": f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines",
+                "gpu_id": self.gpu_id,
+                "models": list(self._model_to_hash.keys()),
            }

    def __repr__(self):
        with self._repo_lock:
-            return (f"TensorRTModelRepository(gpu={self.gpu_id}, "
-                    f"model_ids={len(self._model_to_hash)}, "
-                    f"unique_engines={len(self._shared_engines)})")
+            return (
+                f"TensorRTModelRepository(gpu={self.gpu_id}, "
+                f"model_ids={len(self._model_to_hash)}, "
+                f"unique_engines={len(self._shared_engines)})"
+            )

    def __del__(self):
        """Cleanup all models on deletion"""
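The synchronize-and-clone change above is the memory-safety core of this hunk: a pooled execution context must not be handed to the next batch while its output tensors are still being read. A condensed sketch of that pattern, with placeholder names rather than the repository's actual API:

# Illustrative sketch of the synchronize/clone pattern added to infer(); names are placeholders
def run_with_pooled_context(pool, execute):
    ctx = pool.acquire()          # borrow an execution context and its CUDA stream
    try:
        outputs = execute(ctx)    # enqueue inference work on ctx.stream
        ctx.stream.synchronize()  # always finish before the context can be reused
        # clone so the caller's tensors cannot be overwritten by the next batch
        return {name: t.clone() for name, t in outputs.items()}
    finally:
        pool.release(ctx)         # return the context to the pool even on error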
@@ -5,25 +5,28 @@ This module provides high-level connection management for multiple RTSP streams,
coordinating decoders, batched inference, and tracking with callbacks and threading.
"""

-import threading
-import time
-from typing import Dict, Optional, Callable, Tuple, Any, List
-from dataclasses import dataclass
-from enum import Enum
import logging
import queue
+import threading
+import time
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
import torch

-from .model_controller import ModelController
-from .stream_decoder import StreamDecoderFactory
+from .base_model_controller import BaseModelController
from .model_repository import TensorRTModelRepository
+from .stream_decoder import StreamDecoderFactory
+from .tensorrt_model_controller import TensorRTModelController
+from .ultralytics_model_controller import UltralyticsModelController

logger = logging.getLogger(__name__)


class ConnectionStatus(Enum):
    """Status of a stream connection"""
+
    CONNECTING = "connecting"
    CONNECTED = "connected"
    DISCONNECTED = "disconnected"

@@ -33,6 +36,7 @@ class ConnectionStatus(Enum):
@dataclass
class TrackingResult:
    """Result emitted to user callbacks"""
+
    stream_id: str
    timestamp: float
    tracked_objects: List  # List of TrackedObject from TrackingController

@@ -61,7 +65,7 @@ class StreamConnection:
        self,
        stream_id: str,
        decoder,
-        model_controller: ModelController,
+        model_controller: BaseModelController,
        tracking_controller,
        poll_interval: float = 0.01,
    ):

@@ -107,7 +111,9 @@ class StreamConnection:
                break
        else:
            # Timeout - but don't fail hard, let it try to connect in background
-            logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...")
+            logger.warning(
+                f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying..."
+            )
            self.status = ConnectionStatus.CONNECTING

    def stop(self):

@@ -144,28 +150,42 @@ class StreamConnection:
                    self.last_frame_time = time.time()
                    self.frame_count += 1

+                    # CRITICAL: Clone the GPU tensor to decouple from decoder's frame buffer
+                    # The decoder reuses frame buffer memory, so we must copy the tensor
+                    # before submitting to async batched inference to prevent race conditions
+                    # where the decoder overwrites memory while inference is still reading it.
+                    cloned_tensor = frame_ref.rgb_tensor.clone()
+
                    # Submit to model controller for batched inference
                    # Pass the FrameReference in metadata so we can free it later
                    self.model_controller.submit_frame(
                        stream_id=self.stream_id,
-                        frame=frame_ref.rgb_tensor,
+                        frame=cloned_tensor,  # Use cloned tensor, not original
                        metadata={
                            "frame_number": self.frame_count,
-                            "shape": tuple(frame_ref.rgb_tensor.shape),
+                            "shape": tuple(cloned_tensor.shape),
                            "frame_ref": frame_ref,  # Store reference for later cleanup
-                        }
+                        },
                    )

                # Update connection status based on decoder status
-                if self.decoder.is_connected() and self.status != ConnectionStatus.CONNECTED:
+                if (
+                    self.decoder.is_connected()
+                    and self.status != ConnectionStatus.CONNECTED
+                ):
                    logger.info(f"Stream {self.stream_id} reconnected")
                    self.status = ConnectionStatus.CONNECTED
-                elif not self.decoder.is_connected() and self.status == ConnectionStatus.CONNECTED:
+                elif (
+                    not self.decoder.is_connected()
+                    and self.status == ConnectionStatus.CONNECTED
+                ):
                    logger.warning(f"Stream {self.stream_id} disconnected")
                    self.status = ConnectionStatus.DISCONNECTED

            except Exception as e:
-                logger.error(f"Error processing frame for {self.stream_id}: {e}", exc_info=True)
+                logger.error(
+                    f"Error processing frame for {self.stream_id}: {e}", exc_info=True
+                )
                self.error_queue.put(e)
                self.status = ConnectionStatus.ERROR
                # Free the frame on error

@@ -205,7 +225,10 @@ class StreamConnection:
            self.result_queue.put(tracking_result)

        except Exception as e:
-            logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
+            logger.error(
+                f"Error handling inference result for {self.stream_id}: {e}",
+                exc_info=True,
+            )
            self.error_queue.put(e)
        finally:
            # Free the frame reference - this is the last point in the pipeline

@@ -235,12 +258,16 @@ class StreamConnection:
            if confidence < min_confidence:
                continue

-            detection_list.append(Detection(
-                bbox=det[:4].cpu().tolist(),
-                confidence=confidence,
-                class_id=int(det[5]) if det.shape[0] > 5 else 0,
-                class_name=f"class_{int(det[5])}" if det.shape[0] > 5 else "unknown"
-            ))
+            detection_list.append(
+                Detection(
+                    bbox=det[:4].cpu().tolist(),
+                    confidence=confidence,
+                    class_id=int(det[5]) if det.shape[0] > 5 else 0,
+                    class_name=f"class_{int(det[5])}"
+                    if det.shape[0] > 5
+                    else "unknown",
+                )
+            )

        # Update tracker with detections (will scale bboxes to frame space)
        return self.tracking_controller.update(detection_list, frame_shape=frame_shape)

@@ -319,21 +346,38 @@ class StreamConnectionManager:
        force_timeout: float = 0.05,
        poll_interval: float = 0.01,
        enable_pt_conversion: bool = True,
+        backend: str = "tensorrt",  # "tensorrt" or "ultralytics"
    ):
        self.gpu_id = gpu_id
        self.batch_size = batch_size
        self.force_timeout = force_timeout
        self.poll_interval = poll_interval
+        self.backend = backend.lower()

        # Factories
        self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
-        self.model_repository = TensorRTModelRepository(
-            gpu_id=gpu_id,
-            enable_pt_conversion=enable_pt_conversion
-        )
+
+        # Initialize inference engine based on backend
+        self.inference_engine = None
+        self.model_repository = None  # Legacy - will be removed
+
+        if self.backend == "ultralytics":
+            # Use Ultralytics native YOLO inference
+            from .inference_engine import UltralyticsEngine
+
+            self.inference_engine = UltralyticsEngine()
+            logger.info("Using Ultralytics inference engine")
+        else:
+            # Use native TensorRT inference
+            self.model_repository = TensorRTModelRepository(
+                gpu_id=gpu_id, enable_pt_conversion=enable_pt_conversion
+            )
+            logger.info("Using native TensorRT inference engine")

        # Controllers
-        self.model_controller: Optional[ModelController] = None
+        self.model_controller = (
+            None  # Will be TensorRTModelController or UltralyticsModelController
+        )

        # Connections
        self.connections: Dict[str, StreamConnection] = {}

@@ -350,7 +394,7 @@ class StreamConnectionManager:
        num_contexts: int = 4,
        pt_input_shapes: Optional[Dict] = None,
        pt_precision: Optional[Any] = None,
-        **pt_conversion_kwargs
+        **pt_conversion_kwargs,
    ):
        """
        Initialize the manager with a model.

@@ -382,28 +426,58 @@ class StreamConnectionManager:
        )
        """
        logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
+        logger.info(f"Backend: {self.backend}")

-        # Load model (synchronous)
-        self.model_repository.load_model(
-            model_id,
-            model_path,
-            num_contexts=num_contexts,
-            pt_input_shapes=pt_input_shapes,
-            pt_precision=pt_precision,
-            **pt_conversion_kwargs
-        )
-        logger.info(f"Loaded model {model_id} from {model_path}")
-
-        # Create model controller
-        self.model_controller = ModelController(
-            model_repository=self.model_repository,
-            model_id=model_id,
-            batch_size=self.batch_size,
-            force_timeout=self.force_timeout,
-            preprocess_fn=preprocess_fn,
-            postprocess_fn=postprocess_fn,
-        )
-        self.model_controller.start()
+        # Initialize engine based on backend
+        if self.backend == "ultralytics":
+            # Use Ultralytics native inference
+            logger.info("Initializing Ultralytics YOLO engine...")
+            device = torch.device(f"cuda:{self.gpu_id}")
+
+            metadata = self.inference_engine.initialize(
+                model_path=model_path,
+                device=device,
+                batch=self.batch_size,
+                half=False,  # Use FP32 for now
+                imgsz=640,
+                **pt_conversion_kwargs,
+            )
+            logger.info(f"Ultralytics engine initialized: {metadata}")
+
+            # Create Ultralytics model controller
+            self.model_controller = UltralyticsModelController(
+                inference_engine=self.inference_engine,
+                model_id=model_id,
+                batch_size=self.batch_size,
+                force_timeout=self.force_timeout,
+                preprocess_fn=preprocess_fn,
+                postprocess_fn=postprocess_fn,
+            )
+            self.model_controller.start()
+
+        else:
+            # Use native TensorRT with model repository
+            logger.info("Initializing TensorRT engine...")
+            self.model_repository.load_model(
+                model_id,
+                model_path,
+                num_contexts=num_contexts,
+                pt_input_shapes=pt_input_shapes,
+                pt_precision=pt_precision,
+                **pt_conversion_kwargs,
+            )
+            logger.info(f"Loaded model {model_id} from {model_path}")
+
+            # Create TensorRT model controller
+            self.model_controller = TensorRTModelController(
+                model_repository=self.model_repository,
+                model_id=model_id,
+                batch_size=self.batch_size,
+                force_timeout=self.force_timeout,
+                preprocess_fn=preprocess_fn,
+                postprocess_fn=postprocess_fn,
+            )
+            self.model_controller.start()

        # Don't create a shared tracking controller here
        # Each stream will get its own tracking controller to avoid track accumulation

@@ -452,12 +526,13 @@ class StreamConnectionManager:

        # Create lightweight tracker (NO model_repository dependency!)
        from .tracking_controller import ObjectTracker
+
        tracking_controller = ObjectTracker(
            gpu_id=self.gpu_id,
            tracker_type="iou",
            max_age=30,
            iou_threshold=0.3,
-            class_names=None  # TODO: pass class names if available
+            class_names=None,  # TODO: pass class names if available
        )
        logger.info(f"Created lightweight ObjectTracker for stream {stream_id}")

@@ -472,8 +547,7 @@ class StreamConnectionManager:

        # Register callback with model controller
        self.model_controller.register_callback(
-            stream_id,
-            connection._handle_inference_result
+            stream_id, connection._handle_inference_result
        )

        # Start connection

@@ -487,14 +561,12 @@ class StreamConnectionManager:
            threading.Thread(
                target=self._forward_results,
                args=(connection, on_tracking_result),
-                daemon=True
+                daemon=True,
            ).start()

        if on_error:
            threading.Thread(
-                target=self._forward_errors,
-                args=(connection, on_error),
-                daemon=True
+                target=self._forward_errors, args=(connection, on_error), daemon=True
            ).start()

        logger.info(f"Stream {stream_id} connected successfully")

@@ -549,7 +621,10 @@ class StreamConnectionManager:
            for result in connection.tracking_results():
                callback(result)
        except Exception as e:
-            logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)
+            logger.error(
+                f"Error in result forwarding for {connection.stream_id}: {e}",
+                exc_info=True,
+            )

    def _forward_errors(self, connection: StreamConnection, callback: Callable):
        """

@@ -563,7 +638,10 @@ class StreamConnectionManager:
            for error in connection.errors():
                callback(error)
        except Exception as e:
-            logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)
+            logger.error(
+                f"Error in error forwarding for {connection.stream_id}: {e}",
+                exc_info=True,
+            )

    def get_stats(self) -> Dict[str, Any]:
        """

@@ -581,7 +659,9 @@ class StreamConnectionManager:
                "force_timeout": self.force_timeout,
                "poll_interval": self.poll_interval,
            },
-            "model_controller": self.model_controller.get_stats() if self.model_controller else {},
+            "model_controller": self.model_controller.get_stats()
+            if self.model_controller
+            else {},
            "connections": {
                stream_id: conn.get_stats()
                for stream_id, conn in self.connections.items()
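A caller-side sketch of the new backend switch introduced above; the RTSP URL, model path, callback body, and the connect_stream method name are assumptions for illustration, not guaranteed by this diff.

# Hypothetical caller-side sketch; URL, model path, and method names are placeholders
from services.stream_connection_manager import StreamConnectionManager

def on_tracking_result(result):
    print(result.stream_id, len(result.tracked_objects))

manager = StreamConnectionManager(gpu_id=0, batch_size=4, backend="ultralytics")
manager.initialize(model_id="yolo", model_path="models/yolov8n.pt")
manager.connect_stream(
    "rtsp://example/stream1",
    stream_id="cam-1",
    on_tracking_result=on_tracking_result,
)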
182
services/tensorrt_model_controller.py
Normal file
182
services/tensorrt_model_controller.py
Normal file
|
|
@ -0,0 +1,182 @@
"""
TensorRT Model Controller - Native TensorRT inference with batched processing.
"""

import logging
from typing import Any, Callable, Dict, List, Optional

import torch

from .base_model_controller import BaseModelController, BatchFrame

logger = logging.getLogger(__name__)


class TensorRTModelController(BaseModelController):
    """
    Model controller for native TensorRT inference.

    Uses TensorRTModelRepository for GPU-accelerated inference with
    context pooling and deduplication.
    """

    def __init__(
        self,
        model_repository,
        model_id: str,
        batch_size: int = 16,
        force_timeout: float = 0.05,
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
    ):
        super().__init__(
            model_id=model_id,
            batch_size=batch_size,
            force_timeout=force_timeout,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn,
        )
        self.model_repository = model_repository

        # Detect model's actual batch size from input shape
        self.model_batch_size = self._detect_model_batch_size()
        if self.model_batch_size == 1:
            logger.warning(
                f"Model '{model_id}' has fixed batch_size=1. "
                f"Will process frames sequentially."
            )
        else:
            logger.info(
                f"Model '{model_id}' supports batch_size={self.model_batch_size}"
            )

    def _detect_model_batch_size(self) -> int:
        """Detect the model's batch size from its input shape"""
        try:
            metadata = self.model_repository.get_metadata(self.model_id)
            first_input_name = metadata.input_names[0]
            input_shape = metadata.input_shapes[first_input_name]
            batch_dim = input_shape[0]

            if batch_dim == -1:
                return self.batch_size  # Dynamic batch size
            else:
                return batch_dim  # Fixed batch size
        except Exception as e:
            logger.warning(
                f"Could not detect model batch size: {e}. Assuming batch_size=1"
            )
            return 1

    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run TensorRT inference on a batch of frames"""
        if self.model_batch_size == 1:
            return self._run_sequential_inference(batch)
        else:
            return self._run_batched_inference(batch)

    def _run_sequential_inference(
        self, batch: List[BatchFrame]
    ) -> List[Dict[str, Any]]:
        """Run inference sequentially for batch_size=1 models"""
        results = []

        for batch_frame in batch:
            # Preprocess frame
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
            else:
                processed = (
                    batch_frame.frame.unsqueeze(0)
                    if batch_frame.frame.dim() == 3
                    else batch_frame.frame
                )

            # Run inference
            outputs = self.model_repository.infer(
                self.model_id, {"images": processed}, synchronize=True
            )

            # Postprocess
            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(outputs)
                except Exception as e:
                    logger.error(
                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
                    )
                    detections = torch.zeros(
                        (0, 6), device=list(outputs.values())[0].device
                    )
            else:
                detections = outputs

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results

    def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run true batched inference for models that support it"""
        # Preprocess frames
        preprocessed = []
        for batch_frame in batch:
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
                if processed.dim() == 4 and processed.shape[0] == 1:
                    processed = processed.squeeze(0)
            else:
                processed = batch_frame.frame
            preprocessed.append(processed)

        # Stack into batch tensor
        batch_tensor = torch.stack(preprocessed, dim=0)

        # Limit to model's max batch size
        if batch_tensor.shape[0] > self.model_batch_size:
            logger.warning(
                f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}"
            )
            batch_tensor = batch_tensor[: self.model_batch_size]
            batch = batch[: self.model_batch_size]

        # Run inference
        outputs = self.model_repository.infer(
            self.model_id, {"images": batch_tensor}, synchronize=True
        )

        # Postprocess results (split batch back to individual results)
        results = []
        for i, batch_frame in enumerate(batch):
            # Extract single frame output and clone for memory safety
            frame_output = {}
            for k, v in outputs.items():
                frame_output[k] = v[i : i + 1].clone()

            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(frame_output)
                except Exception as e:
                    logger.error(
                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
                    )
                    detections = torch.zeros(
                        (0, 6), device=list(outputs.values())[0].device
                    )
            else:
                detections = frame_output

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results
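To make the per-frame splitting in `_run_batched_inference` concrete, here is a small self-contained illustration (plain tensors, no TensorRT; the output name and shapes are assumed examples, not taken from this commit):

import torch

# Pretend the engine returned one output named "output0" for a batch of 3 frames.
outputs = {"output0": torch.rand(3, 84, 8400)}

per_frame = []
for i in range(3):
    # Same slicing as the controller: keep the batch dim and clone so the view
    # does not keep the whole batch tensor alive.
    frame_output = {k: v[i : i + 1].clone() for k, v in outputs.items()}
    per_frame.append(frame_output)

print(per_frame[0]["output0"].shape)  # torch.Size([1, 84, 8400])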
222
services/ultralytics_exporter.py
Normal file
@ -0,0 +1,222 @@
"""
Ultralytics YOLO Model Exporter with Caching

Exports YOLO .pt models to TensorRT .engine format using Ultralytics library.
Provides proper NMS and postprocessing built into the engine.
Caches exported engines to avoid redundant exports.
"""

import hashlib
import json
import logging
from pathlib import Path
from typing import Optional, Tuple

logger = logging.getLogger(__name__)


class UltralyticsExporter:
    """
    Export YOLO models using Ultralytics with caching.

    Features:
    - Exports .pt models to TensorRT .engine format
    - Caches exported engines by source file hash
    - Saves metadata about exported models
    - Reuses cached engines when available
    """

    def __init__(self, cache_dir: str = ".ultralytics_cache"):
        """
        Initialize exporter.

        Args:
            cache_dir: Directory for caching exported engines
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Ultralytics exporter cache directory: {self.cache_dir}")

    @staticmethod
    def compute_file_hash(file_path: str) -> str:
        """
        Compute SHA256 hash of a file.

        Args:
            file_path: Path to file

        Returns:
            Hexadecimal hash string
        """
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(65536), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def export(
        self,
        model_path: str,
        device: int = 0,
        half: bool = False,
        imgsz: int = 640,
        batch: int = 1,
        **export_kwargs,
    ) -> Tuple[str, str]:
        """
        Export YOLO model to TensorRT engine with caching.

        Args:
            model_path: Path to .pt model file
            device: GPU device ID
            half: Use FP16 precision
            imgsz: Input image size (default: 640)
            batch: Maximum batch size for inference
            **export_kwargs: Additional arguments for Ultralytics export

        Returns:
            Tuple of (engine_hash, engine_path)

        Raises:
            FileNotFoundError: If model file doesn't exist
            RuntimeError: If export fails
        """
        model_path = Path(model_path).resolve()

        if not model_path.exists():
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Compute hash of source model
        logger.info(f"Computing hash for {model_path}...")
        model_hash = self.compute_file_hash(str(model_path))
        logger.info(f"Model hash: {model_hash[:16]}...")

        # Create export config hash (includes export parameters)
        export_config = {
            "model_hash": model_hash,
            "device": device,
            "half": half,
            "imgsz": imgsz,
            "batch": batch,
            **export_kwargs,
        }
        config_str = json.dumps(export_config, sort_keys=True)
        config_hash = hashlib.sha256(config_str.encode()).hexdigest()

        # Check cache
        cache_engine_path = self.cache_dir / f"{config_hash}.engine"
        cache_metadata_path = self.cache_dir / f"{config_hash}_metadata.json"

        if cache_engine_path.exists():
            logger.info(f"Found cached engine: {cache_engine_path}")
            logger.info(f"Reusing cached export (config hash: {config_hash[:16]}...)")

            # Load and return metadata
            if cache_metadata_path.exists():
                with open(cache_metadata_path, "r") as f:
                    metadata = json.load(f)
                logger.info(f"Cached engine metadata: {metadata}")

            return config_hash, str(cache_engine_path)

        # Export using Ultralytics
        logger.info(f"Exporting YOLO model to TensorRT engine...")
        logger.info(f"  Source: {model_path}")
        logger.info(f"  Device: GPU {device}")
        logger.info(f"  Precision: {'FP16' if half else 'FP32'}")
        logger.info(f"  Image size: {imgsz}")
        logger.info(f"  Batch size: {batch}")

        try:
            from ultralytics import YOLO

            # Load model
            model = YOLO(str(model_path))

            # Export to TensorRT
            exported_path = model.export(
                format="engine",
                device=device,
                half=half,
                imgsz=imgsz,
                batch=batch,
                verbose=True,
                **export_kwargs,
            )

            logger.info(f"Export complete: {exported_path}")

            # Copy to cache
            import shutil

            shutil.copy(exported_path, cache_engine_path)
            logger.info(f"Cached engine: {cache_engine_path}")

            # Save metadata
            metadata = {
                "source_model": str(model_path),
                "model_hash": model_hash,
                "config_hash": config_hash,
                "device": device,
                "half": half,
                "imgsz": imgsz,
                "batch": batch,
                "export_kwargs": export_kwargs,
                "exported_path": str(exported_path),
                "cached_path": str(cache_engine_path),
            }

            with open(cache_metadata_path, "w") as f:
                json.dump(metadata, f, indent=2)

            logger.info(f"Saved metadata: {cache_metadata_path}")

            return config_hash, str(cache_engine_path)

        except Exception as e:
            logger.error(f"Export failed: {e}")
            raise RuntimeError(f"Failed to export YOLO model: {e}")

    def get_cached_engine(self, model_path: str, **export_kwargs) -> Optional[str]:
        """
        Get cached engine path if it exists.

        Args:
            model_path: Path to .pt model
            **export_kwargs: Export parameters (must match cached export)

        Returns:
            Path to cached engine or None if not cached
        """
        try:
            model_path = Path(model_path).resolve()

            if not model_path.exists():
                return None

            # Compute hashes
            model_hash = self.compute_file_hash(str(model_path))

            export_config = {"model_hash": model_hash, **export_kwargs}
            config_str = json.dumps(export_config, sort_keys=True)
            config_hash = hashlib.sha256(config_str.encode()).hexdigest()

            cache_engine_path = self.cache_dir / f"{config_hash}.engine"

            if cache_engine_path.exists():
                return str(cache_engine_path)

            return None

        except Exception as e:
            logger.warning(f"Failed to check cache: {e}")
            return None

    def clear_cache(self):
        """Clear all cached engines"""
        import shutil

        if self.cache_dir.exists():
            shutil.rmtree(self.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Cache cleared")
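A minimal usage sketch for the exporter above (not part of this commit; the model path is a placeholder). Note that `get_cached_engine` only finds the cached engine when its keyword arguments match the ones passed to `export`, because both are folded into the config hash:

from services import UltralyticsExporter

exporter = UltralyticsExporter(cache_dir=".ultralytics_cache")

# First call runs the Ultralytics export and copies the .engine into the cache;
# a repeat call with the same file and the same parameters reuses the cached copy.
config_hash, engine_path = exporter.export(
    "models/yolov8n.pt",  # hypothetical model path
    device=0,
    half=True,
    imgsz=640,
    batch=4,
)

# Cache lookup: the kwargs must match the export call for the hashes to line up.
cached = exporter.get_cached_engine(
    "models/yolov8n.pt", device=0, half=True, imgsz=640, batch=4
)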
217
services/ultralytics_model_controller.py
Normal file
@ -0,0 +1,217 @@
"""
Ultralytics Model Controller - YOLO inference with batched processing.
"""

import logging
from typing import Any, Callable, Dict, List, Optional

import torch

from .base_model_controller import BaseModelController, BatchFrame

logger = logging.getLogger(__name__)


class UltralyticsModelController(BaseModelController):
    """
    Model controller for Ultralytics YOLO inference.

    Uses UltralyticsEngine which wraps the Ultralytics YOLO model with
    native TensorRT backend for GPU-accelerated inference.
    """

    def __init__(
        self,
        inference_engine,
        model_id: str,
        batch_size: int = 16,
        force_timeout: float = 0.05,
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
    ):
        # Auto-detect actual batch size from the YOLO engine
        engine_batch_size = self._detect_engine_batch_size(inference_engine)

        # If engine has fixed batch size, use it. Otherwise use user's batch_size
        actual_batch_size = engine_batch_size if engine_batch_size > 0 else batch_size

        super().__init__(
            model_id=model_id,
            batch_size=actual_batch_size,
            force_timeout=force_timeout,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn,
        )
        self.inference_engine = inference_engine
        self.engine_batch_size = engine_batch_size  # Store for padding logic

        if engine_batch_size > 0:
            logger.info(
                f"Ultralytics engine has fixed batch_size={engine_batch_size}, "
                f"will pad batches to match"
            )
        else:
            logger.info(
                f"Ultralytics engine supports dynamic batching, "
                f"using max batch_size={actual_batch_size}"
            )

    def _detect_engine_batch_size(self, inference_engine) -> int:
        """
        Detect the batch size from Ultralytics engine.

        Returns:
            Fixed batch size (e.g., 2, 4, 8) or -1 for dynamic batching
        """
        try:
            # Get engine metadata
            metadata = inference_engine.get_metadata()

            # Check input shape for batch dimension
            if "images" in metadata.input_shapes:
                input_shape = metadata.input_shapes["images"]
                batch_dim = input_shape[0]

                if batch_dim > 0:
                    # Fixed batch size
                    return batch_dim
                else:
                    # Dynamic batch size (-1)
                    return -1

            # Fallback: try to get from model directly
            if (
                hasattr(inference_engine, "_model")
                and inference_engine._model is not None
            ):
                model = inference_engine._model

                # Try to get batch info from Ultralytics model
                if hasattr(model, "predictor") and model.predictor is not None:
                    predictor = model.predictor
                    if hasattr(predictor, "model") and hasattr(
                        predictor.model, "batch"
                    ):
                        return predictor.model.batch

                # Try to get from model.model (for .engine files)
                if hasattr(model, "model"):
                    # For TensorRT engines, check input shape
                    if hasattr(model.model, "get_input_details"):
                        details = model.model.get_input_details()
                        if details and len(details) > 0:
                            shape = details[0].get("shape")
                            if shape and len(shape) > 0:
                                return shape[0] if shape[0] > 0 else -1

        except Exception as e:
            logger.warning(f"Could not detect engine batch size: {e}")

        # Default: assume dynamic batching
        return -1

    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """
        Run Ultralytics YOLO inference on a batch of frames.

        Ultralytics handles batching natively and returns Results objects.
        """
        # Preprocess frames
        preprocessed = []
        for batch_frame in batch:
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
                # Ensure shape is (C, H, W) not (1, C, H, W)
                if processed.dim() == 4 and processed.shape[0] == 1:
                    processed = processed.squeeze(0)
            else:
                processed = batch_frame.frame
            preprocessed.append(processed)

        # Stack into batch tensor: (B, C, H, W)
        batch_tensor = torch.stack(preprocessed, dim=0)
        actual_batch_size = len(batch)

        # Handle fixed batch size engines (pad if needed)
        if self.engine_batch_size > 0:
            # Engine has fixed batch size
            if batch_tensor.shape[0] > self.engine_batch_size:
                # Truncate to engine's max batch size
                logger.warning(
                    f"Batch size {batch_tensor.shape[0]} exceeds engine max {self.engine_batch_size}, truncating"
                )
                batch_tensor = batch_tensor[: self.engine_batch_size]
                batch = batch[: self.engine_batch_size]
                actual_batch_size = self.engine_batch_size
            elif batch_tensor.shape[0] < self.engine_batch_size:
                # Pad to match engine's fixed batch size
                padding_size = self.engine_batch_size - batch_tensor.shape[0]
                # Replicate last frame to pad (cheaper than zeros)
                padding = batch_tensor[-1:].repeat(padding_size, 1, 1, 1)
                batch_tensor = torch.cat([batch_tensor, padding], dim=0)
                logger.debug(
                    f"Padded batch from {actual_batch_size} to {self.engine_batch_size} frames"
                )
        else:
            # Dynamic batching - just limit to max
            if batch_tensor.shape[0] > self.batch_size:
                logger.warning(
                    f"Batch size {batch_tensor.shape[0]} exceeds configured max {self.batch_size}"
                )
                batch_tensor = batch_tensor[: self.batch_size]
                batch = batch[: self.batch_size]
                actual_batch_size = self.batch_size

        # Run Ultralytics inference
        # Input should be (B, 3, H, W) in range [0, 1], RGB format
        outputs = self.inference_engine.infer(
            inputs={"images": batch_tensor},
            conf=0.25,  # Confidence threshold
            iou=0.45,  # NMS IoU threshold
        )

        # Ultralytics returns Results objects in outputs["results"]
        yolo_results = outputs["results"]

        # Convert Results objects to our standard format
        # Only process actual batch size (ignore padded results if any)
        results = []
        for i in range(actual_batch_size):
            batch_frame = batch[i]
            yolo_result = yolo_results[i]
            # Extract detections from YOLO Results object
            # yolo_result.boxes.data has format: [x1, y1, x2, y2, conf, cls]
            if hasattr(yolo_result, "boxes") and yolo_result.boxes is not None:
                detections = yolo_result.boxes.data  # Already a tensor on GPU
            else:
                # No detections
                detections = torch.zeros((0, 6), device=batch_tensor.device)

            # Apply custom postprocessing if provided
            if self.postprocess_fn:
                try:
                    # For Ultralytics, postprocess_fn might do additional filtering
                    # Pass the raw boxes tensor in the same format as TensorRT output
                    detections = self.postprocess_fn(
                        {
                            "output0": detections.unsqueeze(
                                0
                            )  # Add batch dim for compatibility
                        }
                    )
                except Exception as e:
                    logger.error(
                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
                    )
                    detections = torch.zeros((0, 6), device=batch_tensor.device)

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
                "yolo_result": yolo_result,  # Keep original Results object for debugging
            }
            results.append(result)

        return results
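To make the fixed-batch padding rule above concrete, a small self-contained illustration (plain tensors only; the engine batch size of 4 and the frame size are assumed examples):

import torch

# Three frames arrived, but the engine was exported with a fixed batch of 4.
frames = [torch.rand(3, 640, 640) for _ in range(3)]
batch_tensor = torch.stack(frames, dim=0)                    # (3, 3, 640, 640)
padding = batch_tensor[-1:].repeat(4 - batch_tensor.shape[0], 1, 1, 1)
batch_tensor = torch.cat([batch_tensor, padding], dim=0)     # (4, 3, 640, 640)
# Downstream, only the first 3 results are kept; the padded copy is discarded.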
@ -101,14 +101,15 @@ class YOLOv8Utils:

        # Get output tensor (first and only output)
        output_name = list(outputs.keys())[0]
        output = outputs[output_name]  # (1, 84, 8400)
        output = outputs[output_name]  # (1, 4+num_classes, 8400)

        # Transpose to (1, 8400, 84) for easier processing
        # Transpose to (1, 8400, 4+num_classes) for easier processing
        output = output.transpose(1, 2).squeeze(0)  # (8400, 84)
        output = output.transpose(1, 2).squeeze(0)  # (8400, 4+num_classes)

        # Split bbox coordinates and class scores (vectorized)
        # Format: [cx, cy, w, h, class_scores...]
        bboxes = output[:, :4]  # (8400, 4) - (cx, cy, w, h)
        class_scores = output[:, 4:]  # (8400, 80)
        class_scores = output[:, 4:]  # (8400, num_classes) - dynamically sized

        # Get max class score and corresponding class ID for all anchors (vectorized)
        max_scores, class_ids = torch.max(class_scores, dim=1)  # (8400,), (8400,)
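To make the reshaping above concrete for a model that is not 80-class COCO, a small self-contained example (random tensors; the 3-class model is assumed, and the 0.25 threshold mirrors the confidence default used elsewhere in this commit):

import torch

# A 3-class YOLOv8 head: output is (1, 4 + 3, 8400) instead of (1, 84, 8400).
output = torch.rand(1, 7, 8400)
output = output.transpose(1, 2).squeeze(0)        # (8400, 7)
bboxes = output[:, :4]                            # (8400, 4) - (cx, cy, w, h)
class_scores = output[:, 4:]                      # (8400, 3) - num_classes columns
max_scores, class_ids = torch.max(class_scores, dim=1)
keep = max_scores > 0.25                          # confidence filter before NMS
print(bboxes[keep].shape, class_ids[keep].shape)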
@ -9,193 +9,25 @@ This script demonstrates:
- Automatic PT to TensorRT conversion
"""

import time
import os
import torch
import threading
import time

import cv2
import numpy as np
import torch
from dotenv import load_dotenv

from services import (
    StreamConnectionManager,
    YOLOv8Utils,
    COCO_CLASSES,
    StreamConnectionManager,
    UltralyticsExporter,
    YOLOv8Utils,
)

# Load environment variables
load_dotenv()


def main_single_stream():
    """Single stream example with event-driven architecture."""
    print("=" * 80)
    print("Event-Driven GPU-Accelerated Object Tracking - Single Stream")
    print("=" * 80)

    # Configuration
    GPU_ID = 0
    MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"  # Transparent loading: .pt, .engine, or .trt
    STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
    BATCH_SIZE = 4
    FORCE_TIMEOUT = 0.05
    ENABLE_DISPLAY = os.getenv('ENABLE_DISPLAY', 'false').lower() == 'true'  # Set to 'true' to enable OpenCV display
    MAX_FRAMES = int(os.getenv('MAX_FRAMES', '300'))  # Stop after N frames (0 = unlimited)

    print(f"\nConfiguration:")
    print(f"  GPU: {GPU_ID}")
    print(f"  Model: {MODEL_PATH}")
    print(f"  Stream: {STREAM_URL}")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Force timeout: {FORCE_TIMEOUT}s")
    print(f"  Display: {'Enabled' if ENABLE_DISPLAY else 'Disabled (inference only)'}")
    print(f"  Max frames: {MAX_FRAMES if MAX_FRAMES > 0 else 'Unlimited'}\n")

    # Create StreamConnectionManager with PT conversion enabled
    print("[1/3] Creating StreamConnectionManager...")
    manager = StreamConnectionManager(
        gpu_id=GPU_ID,
        batch_size=BATCH_SIZE,
        force_timeout=FORCE_TIMEOUT,
        enable_pt_conversion=True  # Enable PT conversion
    )
    print("✓ Manager created")

    # Initialize with model (transparent loading - no manual parameters needed)
    print("\n[2/3] Initializing model...")
    print("Note: YOLO models auto-convert to native TensorRT .engine (first time only)")
    print("Metadata is auto-detected from model - no manual input_shapes needed!\n")

    try:
        manager.initialize(
            model_path=MODEL_PATH,
            model_id="detector",
            preprocess_fn=YOLOv8Utils.preprocess,
            postprocess_fn=YOLOv8Utils.postprocess,
            num_contexts=4
            # Note: No pt_input_shapes or pt_precision needed for YOLO models!
        )
        print("✓ Manager initialized")
    except Exception as e:
        print(f"✗ Failed to initialize: {e}")
        import traceback
        traceback.print_exc()
        return

    # Connect stream
    print("\n[3/3] Connecting to stream...")
    try:
        connection = manager.connect_stream(
            rtsp_url=STREAM_URL,
            stream_id="camera_1",
            buffer_size=30
        )
        print(f"✓ Stream connected: camera_1")
    except Exception as e:
        print(f"✗ Failed to connect stream: {e}")
        return

    print(f"\n{'=' * 80}")
    print("Event-driven tracking is running!")
    print("Press Ctrl+C to stop")
    print(f"{'=' * 80}\n")

    # Stream results with optional OpenCV visualization
    result_count = 0
    start_time = time.time()

    # Create window only if display is enabled
    if ENABLE_DISPLAY:
        cv2.namedWindow("Object Tracking", cv2.WINDOW_NORMAL)
        cv2.resizeWindow("Object Tracking", 1280, 720)

    try:
        for result in connection.tracking_results():
            result_count += 1

            # Check if we've reached max frames
            if MAX_FRAMES > 0 and result_count >= MAX_FRAMES:
                print(f"\n✓ Reached max frames limit ({MAX_FRAMES})")
                break

            # OpenCV visualization (only if enabled)
            if ENABLE_DISPLAY:
                # Get latest frame from decoder (as CPU numpy array)
                frame = connection.decoder.get_latest_frame_cpu(rgb=True)

                if frame is not None:
                    # Convert to BGR for OpenCV
                    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

                    # Draw tracked objects
                    for obj in result.tracked_objects:
                        # Get bbox coordinates
                        x1, y1, x2, y2 = map(int, obj.bbox)

                        # Draw bounding box
                        cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)

                        # Draw track ID and class name
                        label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
                        label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

                        # Draw label background
                        cv2.rectangle(frame_bgr, (x1, y1 - label_size[1] - 10),
                                      (x1 + label_size[0], y1), (0, 255, 0), -1)

                        # Draw label text
                        cv2.putText(frame_bgr, label, (x1, y1 - 5),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

                    # Draw FPS and object count
                    elapsed = time.time() - start_time
                    fps = result_count / elapsed if elapsed > 0 else 0
                    info_text = f"FPS: {fps:.1f} | Objects: {len(result.tracked_objects)} | Frame: {result_count}"
                    cv2.putText(frame_bgr, info_text, (10, 30),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

                    # Display frame
                    cv2.imshow("Object Tracking", frame_bgr)

                    # Check for 'q' key to quit
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        print(f"\n✓ Quit by user (pressed 'q')")
                        break

            # Print stats every 30 results
            if result_count % 30 == 0:
                elapsed = time.time() - start_time
                fps = result_count / elapsed if elapsed > 0 else 0

                print(f"\nResults: {result_count} | FPS: {fps:.1f}")
                print(f"  Stream: {result.stream_id}")
                print(f"  Objects: {len(result.tracked_objects)}")

                if result.tracked_objects:
                    class_counts = {}
                    for obj in result.tracked_objects:
                        class_counts[obj.class_name] = class_counts.get(obj.class_name, 0) + 1
                    print(f"  Classes: {class_counts}")

    except KeyboardInterrupt:
        print(f"\n✓ Interrupted by user")

    # Cleanup
    print(f"\n{'=' * 80}")
    print("Cleanup")
    print(f"{'=' * 80}")

    # Close OpenCV window if it was opened
    if ENABLE_DISPLAY:
        cv2.destroyAllWindows()

    connection.stop()
    manager.shutdown()
    print("✓ Stopped")

    # Final stats
    elapsed = time.time() - start_time
    avg_fps = result_count / elapsed if elapsed > 0 else 0
    print(f"\nFinal: {result_count} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")


def main_multi_stream():
    """Multi-stream example with batched inference."""
@ -206,14 +38,18 @@ def main_multi_stream():
    # Configuration
    GPU_ID = 0
    MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"  # Transparent loading: .pt, .engine, or .trt
    BATCH_SIZE = 16
    USE_ULTRALYTICS = (
        os.getenv("USE_ULTRALYTICS", "true").lower() == "true"
    )  # Use Ultralytics engine for YOLO
    BATCH_SIZE = 2  # Reduced to 2 to avoid GPU memory issues
    FORCE_TIMEOUT = 0.05
    ENABLE_DISPLAY = os.getenv("ENABLE_DISPLAY", "true").lower() == "true"

    # Load camera URLs
    camera_urls = []
    i = 1
    while True:
        url = os.getenv(f'CAMERA_URL_{i}')
        url = os.getenv(f"CAMERA_URL_{i}")
        if url:
            camera_urls.append((f"camera_{i}", url))
            i += 1
@ -230,13 +66,16 @@ def main_multi_stream():
    print(f"  Streams: {len(camera_urls)}")
    print(f"  Batch size: {BATCH_SIZE}\n")

    # Create manager with PT conversion
    # Create manager with backend selection
    print("[1/3] Creating StreamConnectionManager...")
    backend = "ultralytics"
    print(f"  Backend: {backend}")
    manager = StreamConnectionManager(
        gpu_id=GPU_ID,
        batch_size=BATCH_SIZE,
        force_timeout=FORCE_TIMEOUT,
        enable_pt_conversion=True
        enable_pt_conversion=True,
        backend=backend,
    )
    print("✓ Manager created")

@ -248,30 +87,52 @@ def main_multi_stream():
            model_id="detector",
            preprocess_fn=YOLOv8Utils.preprocess,
            postprocess_fn=YOLOv8Utils.postprocess,
            num_contexts=8
            num_contexts=1,  # Single context to minimize GPU memory usage
            # Note: No pt_input_shapes or pt_precision needed for YOLO models!
        )
        print("✓ Manager initialized")
    except Exception as e:
        print(f"✗ Failed to initialize: {e}")
        import traceback

        traceback.print_exc()
        return

    # Connect all streams
    # Connect all streams in parallel using threads
    print(f"\n[3/3] Connecting {len(camera_urls)} streams...")
    print(f"\n[3/3] Connecting {len(camera_urls)} streams in parallel...")
    connections = {}
    for stream_id, rtsp_url in camera_urls:
    connection_threads = []
    connection_results = {}

    def connect_stream(stream_id, rtsp_url):
        """Thread worker to connect a single stream"""
        try:
            conn = manager.connect_stream(
                rtsp_url=rtsp_url,
                stream_id=stream_id,
                buffer_size=5
                rtsp_url=rtsp_url, stream_id=stream_id, buffer_size=3
            )
            connections[stream_id] = conn
            print(f"✓ Connected: {stream_id}")
            connection_results[stream_id] = ("success", conn)
        except Exception as e:
            print(f"✗ Failed {stream_id}: {e}")
            connection_results[stream_id] = ("error", str(e))

    # Start all connection threads
    for stream_id, rtsp_url in camera_urls:
        thread = threading.Thread(
            target=connect_stream, args=(stream_id, rtsp_url), daemon=True
        )
        thread.start()
        connection_threads.append(thread)

    # Wait for all connections to complete
    for thread in connection_threads:
        thread.join()

    # Collect results
    for stream_id, (status, result) in connection_results.items():
        if status == "success":
            connections[stream_id] = result
            print(f"✓ Connected: {stream_id}")
        else:
            print(f"✗ Failed {stream_id}: {result}")

    if not connections:
        print("No streams connected")
@ -284,10 +145,20 @@ def main_multi_stream():
    print(f"{'=' * 80}\n")

    # Track stats
    stream_stats = {sid: {'count': 0, 'start': time.time()} for sid in connections.keys()}
    stream_stats = {
        sid: {"count": 0, "start": time.time()} for sid in connections.keys()
    }
    total_results = 0
    start_time = time.time()

    # Create windows for each stream if display enabled
    if ENABLE_DISPLAY:
        for stream_id in connections.keys():
            cv2.namedWindow(stream_id, cv2.WINDOW_NORMAL)
            cv2.resizeWindow(
                stream_id, 640, 360
            )  # Smaller windows for multiple streams

    try:
        # Merge all result queues from all connections
        import queue as queue_module
@ -306,27 +177,92 @@ def main_multi_stream():
                    stream_id = result.stream_id

                    if stream_id in stream_stats:
                        stream_stats[stream_id]['count'] += 1
                        stream_stats[stream_id]["count"] += 1

                    # Display visualization if enabled
                    if ENABLE_DISPLAY:
                        # Get latest frame from decoder (already in CPU memory as numpy RGB)
                        frame_rgb = conn.decoder.get_latest_frame_cpu(rgb=True)
                        if frame_rgb is not None:
                            # Convert RGB to BGR for OpenCV
                            frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)

                            # Draw bounding boxes
                            for obj in result.tracked_objects:
                                x1, y1, x2, y2 = map(int, obj.bbox)

                                # Draw box
                                cv2.rectangle(
                                    frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2
                                )

                                # Draw label with ID and class
                                label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
                                (label_w, label_h), _ = cv2.getTextSize(
                                    label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
                                )
                                cv2.rectangle(
                                    frame_bgr,
                                    (x1, y1 - label_h - 10),
                                    (x1 + label_w, y1),
                                    (0, 255, 0),
                                    -1,
                                )
                                cv2.putText(
                                    frame_bgr,
                                    label,
                                    (x1, y1 - 5),
                                    cv2.FONT_HERSHEY_SIMPLEX,
                                    0.5,
                                    (0, 0, 0),
                                    1,
                                )

                            # Show FPS on frame
                            s_elapsed = time.time() - stream_stats[stream_id]["start"]
                            s_fps = (
                                stream_stats[stream_id]["count"] / s_elapsed
                                if s_elapsed > 0
                                else 0
                            )
                            fps_text = f"{stream_id}: {s_fps:.1f} FPS | {len(result.tracked_objects)} objects"
                            cv2.putText(
                                frame_bgr,
                                fps_text,
                                (10, 30),
                                cv2.FONT_HERSHEY_SIMPLEX,
                                0.7,
                                (0, 255, 0),
                                2,
                            )

                            # Display
                            cv2.imshow(stream_id, frame_bgr)

                    # Print stats every 100 results
                    if total_results % 100 == 0:
                        elapsed = time.time() - start_time
                        total_fps = total_results / elapsed if elapsed > 0 else 0

                        print(f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS")
                        print(
                            f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS"
                        )
                        for sid, stats in stream_stats.items():
                            s_elapsed = time.time() - stats['start']
                            s_elapsed = time.time() - stats["start"]
                            s_fps = stats['count'] / s_elapsed if s_elapsed > 0 else 0
                            s_fps = stats["count"] / s_elapsed if s_elapsed > 0 else 0
                            print(f"  {sid}: {stats['count']} ({s_fps:.1f} FPS)")

                except queue_module.Empty:
                    continue

            # Process OpenCV events to keep windows responsive
            if ENABLE_DISPLAY:
                cv2.waitKey(1)

            # Small sleep if no results to avoid busy loop
            if not got_result:
                time.sleep(0.01)

    except KeyboardInterrupt:
        print(f"\n✓ Interrupted")

@ -335,6 +271,10 @@ def main_multi_stream():
    print("Cleanup")
    print(f"{'=' * 80}")

    # Close OpenCV windows if they were opened
    if ENABLE_DISPLAY:
        cv2.destroyAllWindows()

    for conn in connections.values():
        conn.stop()
    manager.shutdown()
@ -347,8 +287,4 @@ def main_multi_stream():


if __name__ == "__main__":
    import sys
    if len(sys.argv) > 1 and sys.argv[1] == "single":
        main_single_stream()
    else:
        main_multi_stream()
    main_multi_stream()