converter system

parent d3dbf9a580
commit 748fb71980

9 changed files with 1012 additions and 14 deletions

174 examples/jpeg_encode.py Executable file

@@ -0,0 +1,174 @@
#!/usr/bin/env python3
"""
Test script for JPEG encoding with nvImageCodec
Tests GPU-accelerated JPEG encoding from RTSP stream frames
"""

import argparse
import sys
import time
import os
from pathlib import Path
from dotenv import load_dotenv
from services import StreamDecoderFactory

# Load environment variables from .env file
load_dotenv()


def main():
    parser = argparse.ArgumentParser(description='Test JPEG encoding from RTSP stream')
    parser.add_argument(
        '--rtsp-url',
        type=str,
        default=None,
        help='RTSP stream URL (defaults to CAMERA_URL_1 from .env)'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='./snapshots',
        help='Output directory for JPEG files'
    )
    parser.add_argument(
        '--num-frames',
        type=int,
        default=10,
        help='Number of frames to capture'
    )
    parser.add_argument(
        '--interval',
        type=float,
        default=1.0,
        help='Interval between captures in seconds'
    )
    parser.add_argument(
        '--quality',
        type=int,
        default=95,
        help='JPEG quality (0-100)'
    )
    parser.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='GPU device ID'
    )

    args = parser.parse_args()

    # Get RTSP URL from command line or environment
    rtsp_url = args.rtsp_url
    if not rtsp_url:
        rtsp_url = os.getenv('CAMERA_URL_1')
        if not rtsp_url:
            print("Error: No RTSP URL provided")
            print("Please either:")
            print(" 1. Use --rtsp-url argument, or")
            print(" 2. Add CAMERA_URL_1 to your .env file")
            sys.exit(1)

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 80)
    print("RTSP Stream JPEG Encoding Test")
    print("=" * 80)
    print(f"RTSP URL: {rtsp_url}")
    print(f"Output Directory: {output_dir}")
    print(f"Number of Frames: {args.num_frames}")
    print(f"Capture Interval: {args.interval}s")
    print(f"JPEG Quality: {args.quality}")
    print(f"GPU ID: {args.gpu_id}")
    print("=" * 80)
    print()

    try:
        # Initialize factory and decoder
        print("[1/3] Initializing StreamDecoderFactory...")
        factory = StreamDecoderFactory(gpu_id=args.gpu_id)
        print("✓ Factory initialized\n")

        print("[2/3] Creating and starting decoder...")
        decoder = factory.create_decoder(
            rtsp_url=rtsp_url,
            buffer_size=30
        )
        decoder.start()
        print("✓ Decoder started\n")

        # Wait for connection
        print("[3/3] Waiting for stream to connect...")
        max_wait = 10
        for i in range(max_wait):
            if decoder.is_connected():
                print("✓ Stream connected\n")
                break
            time.sleep(1)
            print(f" Waiting... {i+1}/{max_wait}s")
        else:
            print("✗ Failed to connect to stream")
            sys.exit(1)

        # Capture frames
        print(f"Capturing {args.num_frames} frames...")
        print("-" * 80)

        captured = 0
        for i in range(args.num_frames):
            # Get frame as JPEG
            start_time = time.time()
            jpeg_bytes = decoder.get_frame_as_jpeg(quality=args.quality)
            encode_time = (time.time() - start_time) * 1000  # ms

            if jpeg_bytes:
                # Save to file
                filename = output_dir / f"frame_{i:04d}.jpg"
                with open(filename, 'wb') as f:
                    f.write(jpeg_bytes)

                size_kb = len(jpeg_bytes) / 1024
                print(f"[{i+1}/{args.num_frames}] Saved {filename.name} "
                      f"({size_kb:.1f} KB, encoded in {encode_time:.2f}ms)")
                captured += 1
            else:
                print(f"[{i+1}/{args.num_frames}] Failed to get frame")

            # Wait before next capture (except for last frame)
            if i < args.num_frames - 1:
                time.sleep(args.interval)

        print("-" * 80)

        # Summary
        print("\n" + "=" * 80)
        print("Capture Complete")
        print("=" * 80)
        print(f"Successfully captured: {captured}/{args.num_frames} frames")
        print(f"Output directory: {output_dir.absolute()}")
        print("=" * 80)

    except KeyboardInterrupt:
        print("\n\n✗ Interrupted by user")
        sys.exit(1)

    except Exception as e:
        print(f"\n\n✗ Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    finally:
        # Cleanup
        if 'decoder' in locals():
            print("\nCleaning up...")
            decoder.stop()
            print("✓ Decoder stopped")

    print("\n✓ Test completed successfully")
    sys.exit(0)


if __name__ == '__main__':
    main()
@@ -6,3 +6,4 @@ av
 cuda-python
 nvidia-nvimgcodec-cu12  # GPU-accelerated JPEG encoding/decoding with nvJPEG
 python-dotenv  # Load environment variables from .env file
+torch_tensorrt
@@ -10,6 +10,8 @@ from .tracking_factory import TrackingFactory
 from .yolo import YOLOv8Utils, COCO_CLASSES
 from .model_controller import ModelController, BatchFrame, BufferState
 from .stream_connection_manager import StreamConnectionManager, StreamConnection, TrackingResult
+from .pt_converter import PTConverter
+from .modelstorage import IModelStorage, FileModelStorage

 __all__ = [
     'StreamDecoderFactory',
@@ -32,4 +34,7 @@ __all__ = [
     'StreamConnectionManager',
     'StreamConnection',
     'TrackingResult',
+    'PTConverter',
+    'IModelStorage',
+    'FileModelStorage',
 ]
@@ -6,6 +6,9 @@ from queue import Queue
 import torch
 import tensorrt as trt
 from dataclasses import dataclass
+import logging
+
+logger = logging.getLogger(__name__)


 @dataclass
@@ -158,17 +161,19 @@ class TensorRTModelRepository:
     # Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
     """

-    def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4):
+    def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True):
         """
         Initialize the model repository.

         Args:
             gpu_id: GPU device ID to use
             default_num_contexts: Default number of execution contexts per unique engine
+            enable_pt_conversion: Enable automatic PyTorch to TensorRT conversion
         """
         self.gpu_id = gpu_id
         self.device = torch.device(f'cuda:{gpu_id}')
         self.default_num_contexts = default_num_contexts
+        self.enable_pt_conversion = enable_pt_conversion

         # Model ID to engine mapping: model_id -> file_hash
         self._model_to_hash: Dict[str, str] = {}
@@ -182,8 +187,22 @@ class TensorRTModelRepository:
         # TensorRT logger
         self.trt_logger = trt.Logger(trt.Logger.WARNING)

+        # PT converter (lazy initialization)
+        self._pt_converter = None
+
         print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
         print(f"Default context pool size: {default_num_contexts} contexts per unique model")
+        if enable_pt_conversion:
+            print(f"PyTorch to TensorRT conversion: enabled")
+
+    @property
+    def pt_converter(self):
+        """Lazy initialization of PT converter"""
+        if self._pt_converter is None and self.enable_pt_conversion:
+            from .pt_converter import PTConverter
+            self._pt_converter = PTConverter(gpu_id=self.gpu_id)
+            logger.info("PT converter initialized")
+        return self._pt_converter

     @staticmethod
     def compute_file_hash(file_path: str) -> str:
@@ -282,18 +301,26 @@ class TensorRTModelRepository:

     def load_model(self, model_id: str, file_path: str,
                    num_contexts: Optional[int] = None,
-                   force_reload: bool = False) -> ModelMetadata:
+                   force_reload: bool = False,
+                   pt_input_shapes: Optional[Dict[str, Tuple]] = None,
+                   pt_precision: Optional[torch.dtype] = None,
+                   **pt_conversion_kwargs) -> ModelMetadata:
         """
         Load a TensorRT model with the given ID.

+        Supports both .trt and .pt files. PT files are automatically converted to TensorRT.
+
         Deduplication: If a model with the same file hash is already loaded, the model_id
         is simply mapped to the existing SharedEngine (no new engine or contexts created).

         Args:
             model_id: User-defined identifier for this model (e.g., "camera_1")
-            file_path: Path to TensorRT engine file (.trt or .engine)
+            file_path: Path to TensorRT engine file (.trt, .engine) or PyTorch file (.pt, .pth)
             num_contexts: Number of execution contexts in pool (None = use default)
             force_reload: If True, reload even if model_id exists
+            pt_input_shapes: Required for .pt files - dict of input shapes (e.g., {"x": (1, 3, 224, 224)})
+            pt_precision: Precision for PT conversion (torch.float16 or torch.float32)
+            **pt_conversion_kwargs: Additional arguments for torch_tensorrt.compile()

         Returns:
             ModelMetadata for the loaded model
@@ -301,13 +328,37 @@ class TensorRTModelRepository:
         Raises:
             FileNotFoundError: If model file doesn't exist
             RuntimeError: If engine loading fails
-            ValueError: If model_id already exists and force_reload is False
+            ValueError: If model_id already exists and force_reload is False, or PT conversion requires input_shapes
         """
         file_path = str(Path(file_path).resolve())

         if not Path(file_path).exists():
             raise FileNotFoundError(f"Model file not found: {file_path}")

+        # Check if file is PyTorch model
+        file_ext = Path(file_path).suffix.lower()
+        if file_ext in ['.pt', '.pth']:
+            if not self.enable_pt_conversion:
+                raise ValueError(
+                    f"PT file provided but PT conversion is disabled. "
+                    f"Enable with enable_pt_conversion=True or provide a .trt file."
+                )
+
+            logger.info(f"Detected PyTorch model file: {file_path}")
+            logger.info("Converting to TensorRT...")
+
+            # Convert PT to TRT
+            trt_hash, trt_path = self.pt_converter.convert(
+                file_path,
+                input_shapes=pt_input_shapes,
+                precision=pt_precision,
+                **pt_conversion_kwargs
+            )
+
+            # Update file_path to use converted TRT file
+            file_path = trt_path
+            logger.info(f"Will load converted TensorRT model from: {file_path}")
+
         if num_contexts is None:
             num_contexts = self.default_num_contexts
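For orientation (not part of the diff): a minimal sketch of the new load path, assuming the repository class is importable as below and that models/yolov8n.pt exists; the model IDs are placeholders. The first load detects the .pt extension, converts it once through pt_converter, and loads the cached TensorRT engine; a second model_id pointing at the same file ends up on the same hash and reuses the shared engine and context pool rather than building a new one.

import torch
from services.model_repository import TensorRTModelRepository  # import path assumed

repo = TensorRTModelRepository(gpu_id=0, default_num_contexts=4, enable_pt_conversion=True)

# First load: .pt detected, converted once, cached .trt engine loaded.
meta = repo.load_model(
    model_id="camera_1",
    file_path="models/yolov8n.pt",
    pt_input_shapes={"images": (1, 3, 640, 640)},  # required for .pt files
    pt_precision=torch.float16,
)

# Same file under another ID: the file hash matches, so the existing SharedEngine
# is reused (the deduplication described in the load_model docstring above).
repo.load_model(
    model_id="camera_2",
    file_path="models/yolov8n.pt",
    pt_input_shapes={"images": (1, 3, 640, 640)},
    pt_precision=torch.float16,
)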
8 services/modelstorage/__init__.py Normal file

@@ -0,0 +1,8 @@
"""
Model storage module for managing TensorRT and PyTorch model files.
"""

from .interface import IModelStorage
from .file_storage import FileModelStorage

__all__ = ['IModelStorage', 'FileModelStorage']
161 services/modelstorage/file_storage.py Normal file

@@ -0,0 +1,161 @@
"""
FileModelStorage - Local filesystem implementation of IModelStorage.

Stores model files in a local directory structure.
"""

import os
from pathlib import Path
from typing import Optional
import logging

from .interface import IModelStorage

logger = logging.getLogger(__name__)


class FileModelStorage(IModelStorage):
    """
    Local filesystem storage for model files.

    Stores files in a directory structure:
        ./models/trtptcache/
            trt/
                <hash1>.trt
                <hash2>.trt
            pt/
                <hash3>.pt
    """

    def __init__(self, base_path: str = "./models/trtptcache"):
        """
        Initialize file storage.

        Args:
            base_path: Base directory for storing files (default: ./models/trtptcache)
        """
        self.base_path = Path(base_path).resolve()
        self._ensure_directories()

    def _ensure_directories(self):
        """Create base directory structure if it doesn't exist"""
        self.base_path.mkdir(parents=True, exist_ok=True)
        logger.info(f"Model storage initialized at: {self.base_path}")

    def _get_full_path(self, key: str) -> Path:
        """
        Get full filesystem path for a key.

        Args:
            key: Storage key (e.g., "trt/hash123.trt")

        Returns:
            Full filesystem path
        """
        return self.base_path / key

    def write(self, key: str, data: bytes) -> None:
        """
        Write data to filesystem.

        Args:
            key: Storage key (e.g., "trt/hash123.trt")
            data: Binary data to write

        Raises:
            IOError: If write operation fails
        """
        file_path = self._get_full_path(key)

        # Ensure parent directory exists
        file_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            with open(file_path, 'wb') as f:
                f.write(data)
            logger.debug(f"Wrote {len(data)} bytes to {file_path}")
        except Exception as e:
            raise IOError(f"Failed to write to {file_path}: {e}")

    def read(self, key: str) -> Optional[bytes]:
        """
        Read data from filesystem.

        Args:
            key: Storage key

        Returns:
            Binary data if found, None otherwise

        Raises:
            IOError: If read operation fails
        """
        file_path = self._get_full_path(key)

        if not file_path.exists():
            return None

        try:
            with open(file_path, 'rb') as f:
                data = f.read()
            logger.debug(f"Read {len(data)} bytes from {file_path}")
            return data
        except Exception as e:
            raise IOError(f"Failed to read from {file_path}: {e}")

    def exists(self, key: str) -> bool:
        """
        Check if file exists.

        Args:
            key: Storage key

        Returns:
            True if file exists, False otherwise
        """
        return self._get_full_path(key).exists()

    def delete(self, key: str) -> bool:
        """
        Delete file from filesystem.

        Args:
            key: Storage key

        Returns:
            True if deleted successfully, False if file didn't exist
        """
        file_path = self._get_full_path(key)

        if not file_path.exists():
            return False

        try:
            file_path.unlink()
            logger.debug(f"Deleted {file_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to delete {file_path}: {e}")
            return False

    def get_local_path(self, key: str) -> Optional[str]:
        """
        Get local filesystem path for a key.

        Args:
            key: Storage key

        Returns:
            Local path if file exists, None otherwise
        """
        file_path = self._get_full_path(key)
        return str(file_path) if file_path.exists() else None

    def get_storage_path(self) -> str:
        """
        Get the base storage path.

        Returns:
            Base path where files are stored
        """
        return str(self.base_path)
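A quick sketch of how the key/path convention above plays out (illustrative only; the key below uses a made-up hash):

from services.modelstorage import FileModelStorage

storage = FileModelStorage(base_path="./models/trtptcache")
storage.write("trt/abc123.trt", b"engine-bytes")   # lands in ./models/trtptcache/trt/abc123.trt
assert storage.exists("trt/abc123.trt")
print(storage.get_local_path("trt/abc123.trt"))    # absolute path, or None if the key is missing
print(storage.read("trt/abc123.trt"))              # b'engine-bytes'
storage.delete("trt/abc123.trt")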
91 services/modelstorage/interface.py Normal file

@@ -0,0 +1,91 @@
"""
IModelStorage - Interface for model file storage.

Defines the contract for storing and retrieving model files (TensorRT, PyTorch, etc.)
"""

from abc import ABC, abstractmethod
from typing import Optional, BinaryIO
from pathlib import Path
import io


class IModelStorage(ABC):
    """
    Interface for model file storage.

    This abstraction allows swapping storage backends (local filesystem, S3, etc.)
    without changing the model conversion and loading logic.
    """

    @abstractmethod
    def write(self, key: str, data: bytes) -> None:
        """
        Write data to storage with the given key.

        Args:
            key: Storage key (e.g., "trt/hash123.trt" or "pt/hash456.pt")
            data: Binary data to write

        Raises:
            IOError: If write operation fails
        """
        pass

    @abstractmethod
    def read(self, key: str) -> Optional[bytes]:
        """
        Read data from storage by key.

        Args:
            key: Storage key

        Returns:
            Binary data if found, None otherwise

        Raises:
            IOError: If read operation fails
        """
        pass

    @abstractmethod
    def exists(self, key: str) -> bool:
        """
        Check if a key exists in storage.

        Args:
            key: Storage key

        Returns:
            True if key exists, False otherwise
        """
        pass

    @abstractmethod
    def delete(self, key: str) -> bool:
        """
        Delete data from storage.

        Args:
            key: Storage key

        Returns:
            True if deleted successfully, False if key didn't exist
        """
        pass

    @abstractmethod
    def get_local_path(self, key: str) -> Optional[str]:
        """
        Get a local filesystem path for the key.

        For local storage, this returns the direct path.
        For remote storage (S3), this may download to a temp location or return None.

        Args:
            key: Storage key

        Returns:
            Local path if available/downloaded, None if not supported
        """
        pass
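To illustrate the pluggable-backend idea, here is a hypothetical test double (not part of this commit): any object implementing these five methods can be handed to PTConverter. Note that PTConverter also calls get_storage_path() on its storage at startup, which FileModelStorage provides but the interface does not declare, so a custom backend should supply it as well.

from typing import Dict, Optional
from services.modelstorage import IModelStorage

class InMemoryModelStorage(IModelStorage):
    """Hypothetical in-memory backend, e.g. for unit tests."""

    def __init__(self):
        self._blobs: Dict[str, bytes] = {}

    def write(self, key: str, data: bytes) -> None:
        self._blobs[key] = data

    def read(self, key: str) -> Optional[bytes]:
        return self._blobs.get(key)

    def exists(self, key: str) -> bool:
        return key in self._blobs

    def delete(self, key: str) -> bool:
        return self._blobs.pop(key, None) is not None

    def get_local_path(self, key: str) -> Optional[str]:
        return None  # purely in-memory, no filesystem path

    def get_storage_path(self) -> str:
        return "memory://"  # not in the interface, but PTConverter logs it at init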
485 services/pt_converter.py Normal file

@@ -0,0 +1,485 @@
"""
PTConverter - Convert PyTorch models to TensorRT using torch_tensorrt.

This service handles conversion of .pt files to TensorRT format with caching
to avoid redundant conversions. It maintains a mapping database between
PT file hashes and their converted TensorRT file hashes.
"""

import hashlib
import json
import logging
from pathlib import Path
from typing import Optional, Dict, Tuple
import torch
import torch_tensorrt

from .modelstorage import IModelStorage, FileModelStorage

logger = logging.getLogger(__name__)


class PTConverter:
    """
    PyTorch to TensorRT converter with intelligent caching.

    Features:
    - Hash-based deduplication: Same PT file only converted once
    - Persistent mapping database: PT hash -> TRT hash mapping
    - Pluggable storage backend: IModelStorage interface
    - Automatic cache management

    Architecture:
    1. Compute hash of input .pt file
    2. Check mapping database for existing conversion
    3. If found, return cached TRT file hash and path
    4. If not found, perform conversion and store mapping
    """

    def __init__(
        self,
        storage: Optional[IModelStorage] = None,
        gpu_id: int = 0,
        default_precision: torch.dtype = torch.float16
    ):
        """
        Initialize PT converter.

        Args:
            storage: Storage backend (defaults to FileModelStorage)
            gpu_id: GPU device ID for conversion
            default_precision: Default precision for TensorRT conversion (fp16, fp32)
        """
        self.storage = storage or FileModelStorage()
        self.gpu_id = gpu_id
        self.device = torch.device(f'cuda:{gpu_id}')
        self.default_precision = default_precision

        # Mapping database: pt_hash -> {"trt_hash": str, "metadata": {...}}
        self.mapping_db_key = "pt_to_trt_mapping.json"
        self.mapping_db: Dict[str, Dict] = self._load_mapping_db()

        logger.info(f"PTConverter initialized on GPU {gpu_id}")
        logger.info(f"Storage backend: {self.storage.__class__.__name__}")
        logger.info(f"Storage path: {self.storage.get_storage_path()}")
        logger.info(f"Loaded {len(self.mapping_db)} cached conversions")

    def _load_mapping_db(self) -> Dict[str, Dict]:
        """
        Load mapping database from storage.

        Returns:
            Mapping dictionary (pt_hash -> metadata)
        """
        try:
            data = self.storage.read(self.mapping_db_key)
            if data:
                db = json.loads(data.decode('utf-8'))
                logger.debug(f"Loaded mapping database with {len(db)} entries")
                return db
            else:
                logger.debug("No existing mapping database found, starting fresh")
                return {}
        except Exception as e:
            logger.warning(f"Failed to load mapping database: {e}. Starting fresh.")
            return {}

    def _save_mapping_db(self):
        """Save mapping database to storage"""
        try:
            data = json.dumps(self.mapping_db, indent=2).encode('utf-8')
            self.storage.write(self.mapping_db_key, data)
            logger.debug(f"Saved mapping database with {len(self.mapping_db)} entries")
        except Exception as e:
            logger.error(f"Failed to save mapping database: {e}")

    @staticmethod
    def compute_file_hash(file_path: str) -> str:
        """
        Compute SHA256 hash of a file.

        Args:
            file_path: Path to file

        Returns:
            Hexadecimal hash string
        """
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(65536), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def get_cached_conversion(self, pt_hash: str) -> Optional[Tuple[str, str]]:
        """
        Check if PT file has already been converted.

        Args:
            pt_hash: SHA256 hash of the PT file

        Returns:
            Tuple of (trt_hash, trt_file_path) if cached, None otherwise
        """
        if pt_hash not in self.mapping_db:
            return None

        mapping = self.mapping_db[pt_hash]
        trt_hash = mapping["trt_hash"]
        trt_key = f"trt/{trt_hash}.trt"

        # Verify TRT file still exists in storage
        if not self.storage.exists(trt_key):
            logger.warning(
                f"Mapping exists for PT hash {pt_hash[:16]}... but TRT file missing. "
                f"Will reconvert."
            )
            # Remove stale mapping
            del self.mapping_db[pt_hash]
            self._save_mapping_db()
            return None

        # Get local path
        trt_path = self.storage.get_local_path(trt_key)
        if trt_path is None:
            logger.error(f"Could not get local path for TRT file {trt_key}")
            return None

        logger.info(
            f"Found cached conversion for PT hash {pt_hash[:16]}... -> "
            f"TRT hash {trt_hash[:16]}..."
        )
        return (trt_hash, trt_path)

    def convert(
        self,
        pt_file_path: str,
        input_shapes: Optional[Dict[str, Tuple]] = None,
        precision: Optional[torch.dtype] = None,
        **conversion_kwargs
    ) -> Tuple[str, str]:
        """
        Convert PyTorch model to TensorRT.

        If this PT file has been converted before (same hash), returns cached result.
        Otherwise, performs conversion and caches the result.

        Args:
            pt_file_path: Path to .pt file
            input_shapes: Dict of input names to shapes (e.g., {"x": (1, 3, 224, 224)})
            precision: Target precision (fp16, fp32) - defaults to self.default_precision
            **conversion_kwargs: Additional arguments for torch_tensorrt.compile()

        Returns:
            Tuple of (trt_hash, trt_file_path)

        Raises:
            FileNotFoundError: If PT file doesn't exist
            RuntimeError: If conversion fails
        """
        pt_path = Path(pt_file_path).resolve()
        if not pt_path.exists():
            raise FileNotFoundError(f"PT file not found: {pt_file_path}")

        # Compute PT file hash
        logger.info(f"Computing hash for {pt_path}...")
        pt_hash = self.compute_file_hash(str(pt_path))
        logger.info(f"PT file hash: {pt_hash[:16]}...")

        # Check cache
        cached = self.get_cached_conversion(pt_hash)
        if cached:
            return cached

        # Perform conversion
        logger.info(f"Converting {pt_path.name} to TensorRT...")
        trt_hash, trt_path = self._perform_conversion(
            str(pt_path),
            pt_hash,
            input_shapes,
            precision or self.default_precision,
            **conversion_kwargs
        )

        # Store mapping
        self.mapping_db[pt_hash] = {
            "trt_hash": trt_hash,
            "pt_file": str(pt_path),
            "input_shapes": str(input_shapes),
            "precision": str(precision or self.default_precision),
        }
        self._save_mapping_db()

        logger.info(f"Conversion complete: PT {pt_hash[:16]}... -> TRT {trt_hash[:16]}...")
        return (trt_hash, trt_path)

    def _is_ultralytics_model(self, model) -> bool:
        """
        Check if model is from ultralytics (YOLO).

        Args:
            model: PyTorch model

        Returns:
            True if ultralytics model, False otherwise
        """
        model_class_name = model.__class__.__name__
        model_module = model.__class__.__module__

        # Check if it's an ultralytics model
        is_ultralytics = (
            'ultralytics' in model_module or
            model_class_name in ['DetectionModel', 'SegmentationModel', 'PoseModel', 'ClassificationModel']
        )

        return is_ultralytics

    def _convert_ultralytics_model(
        self,
        pt_path: str,
        pt_hash: str,
        input_shapes: Optional[Dict[str, Tuple]],
        precision: torch.dtype,
    ) -> Tuple[str, str]:
        """
        Convert ultralytics YOLO model using ONNX → TensorRT pipeline.
        Uses the same approach as scripts/convert_pt_to_tensorrt.py

        Args:
            pt_path: Path to PT file
            pt_hash: PT file hash
            input_shapes: Input tensor shapes
            precision: Target precision

        Returns:
            Tuple of (trt_hash, trt_file_path)
        """
        import tensorrt as trt
        import tempfile
        import os
        import shutil

        logger.info("Detected ultralytics YOLO model, using ONNX → TensorRT pipeline...")

        # Load ultralytics model
        try:
            from ultralytics import YOLO
            model = YOLO(pt_path)
        except ImportError:
            raise ImportError("ultralytics package not found. Install with: pip install ultralytics")

        # Determine input shape
        if not input_shapes:
            raise ValueError("input_shapes required for ultralytics conversion")

        input_key = 'images' if 'images' in input_shapes else list(input_shapes.keys())[0]
        input_shape = input_shapes[input_key]

        # Export to ONNX first
        logger.info(f"Exporting to ONNX (input shape: {input_shape})...")
        with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmp_onnx:
            onnx_path = tmp_onnx.name

        try:
            # Use ultralytics export to ONNX
            model.export(format='onnx', imgsz=input_shape[2], batch=input_shape[0])
            # Ultralytics saves as model_name.onnx in same directory
            pt_dir = os.path.dirname(pt_path)
            pt_name = os.path.splitext(os.path.basename(pt_path))[0]
            onnx_export_path = os.path.join(pt_dir, f"{pt_name}.onnx")

            # Move to our temp location (use shutil.move for cross-device support)
            if os.path.exists(onnx_export_path):
                shutil.move(onnx_export_path, onnx_path)
            else:
                raise RuntimeError(f"ONNX export failed, file not found: {onnx_export_path}")

            logger.info(f"ONNX export complete: {onnx_path}")

            # Build TensorRT engine from ONNX
            logger.info("Building TensorRT engine from ONNX...")
            trt_logger = trt.Logger(trt.Logger.WARNING)
            builder = trt.Builder(trt_logger)
            network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
            parser = trt.OnnxParser(network, trt_logger)

            # Parse ONNX
            with open(onnx_path, 'rb') as f:
                if not parser.parse(f.read()):
                    errors = [parser.get_error(i) for i in range(parser.num_errors)]
                    raise RuntimeError(f"Failed to parse ONNX: {errors}")

            # Configure builder
            config = builder.create_builder_config()
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)  # 4GB

            # Set precision
            if precision == torch.float16:
                if builder.platform_has_fast_fp16:
                    config.set_flag(trt.BuilderFlag.FP16)
                    logger.info("FP16 mode enabled")

            # Build engine
            logger.info("Building TensorRT engine (this may take a few minutes)...")
            serialized_engine = builder.build_serialized_network(network, config)

            if serialized_engine is None:
                raise RuntimeError("Failed to build TensorRT engine")

            # Convert IHostMemory to bytes
            engine_bytes = bytes(serialized_engine)

            # Save to storage
            trt_hash = hashlib.sha256(engine_bytes).hexdigest()
            trt_key = f"trt/{trt_hash}.trt"
            self.storage.write(trt_key, engine_bytes)

            trt_path = self.storage.get_local_path(trt_key)
            if trt_path is None:
                raise RuntimeError("Failed to get local path for TRT file")

            logger.info(f"TensorRT engine built successfully: {trt_path}")
            return (trt_hash, trt_path)

        finally:
            # Cleanup ONNX file
            if os.path.exists(onnx_path):
                os.unlink(onnx_path)

    def _perform_conversion(
        self,
        pt_path: str,
        pt_hash: str,
        input_shapes: Optional[Dict[str, Tuple]],
        precision: torch.dtype,
        **conversion_kwargs
    ) -> Tuple[str, str]:
        """
        Perform the actual PT to TRT conversion.

        Args:
            pt_path: Path to PT file
            pt_hash: PT file hash
            input_shapes: Input tensor shapes
            precision: Target precision
            **conversion_kwargs: Additional torch_tensorrt arguments

        Returns:
            Tuple of (trt_hash, trt_file_path)
        """
        try:
            # Load PyTorch model to check type
            logger.debug(f"Loading PyTorch model from {pt_path}...")
            # Use weights_only=False for models with custom classes (like ultralytics)
            # This is safe for trusted local models
            loaded = torch.load(pt_path, map_location='cpu', weights_only=False)

            # If model is wrapped in a dict, extract the model
            if isinstance(loaded, dict):
                if 'model' in loaded:
                    model = loaded['model']
                elif 'state_dict' in loaded:
                    raise ValueError(
                        "PT file contains state_dict only. "
                        "Please provide a full model or use a different loading method."
                    )
            else:
                model = loaded

            # Check if this is an ultralytics model
            if self._is_ultralytics_model(model):
                logger.info("Detected ultralytics model, using ultralytics export API")
                return self._convert_ultralytics_model(pt_path, pt_hash, input_shapes, precision)

            # For non-ultralytics models, use torch_tensorrt
            logger.info("Using torch_tensorrt for conversion")
            model.eval()

            # Convert model to target precision to avoid mixed precision issues
            if precision == torch.float16:
                model = model.half()
            elif precision == torch.float32:
                model = model.float()

            # Move to GPU
            model = model.to(self.device)

            # Prepare inputs for tracing
            if input_shapes is None:
                raise ValueError(
                    "input_shapes must be provided for TensorRT conversion. "
                    "Example: {'x': (1, 3, 224, 224)}"
                )

            # Create sample inputs with matching precision
            inputs = []
            for name, shape in input_shapes.items():
                sample_input = torch.randn(shape, device=self.device, dtype=precision)
                inputs.append(sample_input)

            # Configure torch_tensorrt
            enabled_precisions = {precision}
            if precision == torch.float16:
                enabled_precisions.add(torch.float32)  # Fallback for unsupported ops

            # Compile to TensorRT
            logger.info(f"Compiling to TensorRT (precision: {precision})...")
            trt_model = torch_tensorrt.compile(
                model,
                inputs=inputs,
                enabled_precisions=enabled_precisions,
                **conversion_kwargs
            )

            # Save TRT model to temporary location
            import tempfile
            with tempfile.NamedTemporaryFile(delete=False, suffix='.trt') as tmp_file:
                tmp_path = tmp_file.name

            torch.jit.save(trt_model, tmp_path)
            logger.debug(f"Saved TRT model to temporary file: {tmp_path}")

            # Compute TRT file hash
            trt_hash = self.compute_file_hash(tmp_path)
            logger.info(f"TRT file hash: {trt_hash[:16]}...")

            # Store in storage backend
            trt_key = f"trt/{trt_hash}.trt"
            with open(tmp_path, 'rb') as f:
                trt_data = f.read()
            self.storage.write(trt_key, trt_data)

            # Get local path
            trt_path = self.storage.get_local_path(trt_key)
            if trt_path is None:
                raise RuntimeError("Failed to get local path for converted TRT file")

            # Cleanup temp file
            Path(tmp_path).unlink()

            logger.info(f"TRT model stored successfully at {trt_path}")
            return (trt_hash, trt_path)

        except Exception as e:
            logger.error(f"Conversion failed: {e}", exc_info=True)
            raise RuntimeError(f"Failed to convert PT to TensorRT: {e}")

    def clear_cache(self):
        """Clear all cached conversions and mapping database"""
        logger.warning("Clearing all cached conversions...")
        self.mapping_db.clear()
        self._save_mapping_db()
        logger.info("Cache cleared")

    def get_stats(self) -> Dict:
        """
        Get conversion statistics.

        Returns:
            Dictionary with cache stats
        """
        return {
            "total_cached_conversions": len(self.mapping_db),
            "storage_path": self.storage.get_storage_path(),
            "gpu_id": self.gpu_id,
            "default_precision": str(self.default_precision),
        }
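A usage sketch of the converter on its own (the path and shapes are placeholders; assumes a CUDA device and the dependencies from requirements). The first call performs the conversion; repeating it with the same .pt file returns the cached engine, because its SHA256 hash already appears in pt_to_trt_mapping.json.

import torch
from services import PTConverter

converter = PTConverter(gpu_id=0, default_precision=torch.float16)

trt_hash, trt_path = converter.convert(
    "models/yolov8n.pt",                        # placeholder path
    input_shapes={"images": (1, 3, 640, 640)},
)
print(trt_path)                # e.g. ./models/trtptcache/trt/<trt_hash>.trt
print(converter.get_stats())   # cache size, storage path, GPU ID, default precision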
@@ -167,9 +167,11 @@ def main():
     """
     Main function for real-time tracking visualization.
     """
+    import torch
+
     # Configuration
     GPU_ID = 0
-    MODEL_PATH = "models/yolov8n.trt"
+    MODEL_PATH = "models/yolov8n.pt"  # Changed to PT file
     RTSP_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
     BUFFER_SIZE = 30
     WINDOW_NAME = "Real-time Object Tracking"
@@ -178,18 +180,24 @@ def main():
     print("Real-time GPU-Accelerated Object Tracking")
     print("=" * 80)

-    # Step 1: Create model repository
+    # Step 1: Create model repository with PT conversion enabled
     print("\n[1/4] Initializing TensorRT Model Repository...")
-    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=4)
+    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=4, enable_pt_conversion=True)

-    # Load detection model
+    # Load detection model (will auto-convert PT to TRT)
     model_id = "yolov8_detector"
     if os.path.exists(MODEL_PATH):
         try:
+            print(f"Loading model from {MODEL_PATH}...")
+            print("Note: First load will convert PT to TensorRT (may take 3-5 minutes)")
+            print("Subsequent loads will use cached TensorRT engine")
+
             metadata = model_repo.load_model(
                 model_id=model_id,
                 file_path=MODEL_PATH,
-                num_contexts=4
+                num_contexts=4,
+                pt_input_shapes={"images": (1, 3, 640, 640)},  # Required for PT conversion
+                pt_precision=torch.float16  # Use FP16 for better performance
             )
             print(f"✓ Model loaded successfully")
             print(f" Input shape: {metadata.input_shapes}")
@@ -197,10 +205,12 @@ def main():
         except Exception as e:
             print(f"✗ Failed to load model: {e}")
             print(f" Please ensure {MODEL_PATH} exists")
+            import traceback
+            traceback.print_exc()
             return
     else:
         print(f"✗ Model file not found: {MODEL_PATH}")
-        print(f" Please provide a valid TensorRT model file")
+        print(f" Please provide a valid PyTorch (.pt) or TensorRT (.trt) model file")
         return

     # Step 2: Create tracking controller
@@ -370,7 +380,7 @@ def main_multi_window():
     with separate OpenCV windows for each stream.
     """
     GPU_ID = 0
-    MODEL_PATH = "models/yolov8n.trt"
+    MODEL_PATH = "models/yolov8n.pt"

     # Load camera URLs from environment
     camera_urls = []
@@ -389,11 +399,23 @@ def main_multi_window():

     print(f"Starting multi-window tracking with {len(camera_urls)} cameras")

-    # Create shared model repository
-    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=8)
+    # Create shared model repository with PT conversion enabled
+    import torch
+    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=8, enable_pt_conversion=True)

     if os.path.exists(MODEL_PATH):
-        model_repo.load_model("detector", MODEL_PATH, num_contexts=8)
+        print(f"Loading model from {MODEL_PATH}...")
+        print("Note: First load will convert PT to TensorRT (may take 3-5 minutes)")
+        print("Subsequent loads will use cached TensorRT engine")
+
+        model_repo.load_model(
+            model_id="detector",
+            file_path=MODEL_PATH,
+            num_contexts=8,
+            pt_input_shapes={"images": (1, 3, 640, 640)},  # Required for PT conversion
+            pt_precision=torch.float16  # Use FP16 for better performance
+        )
+        print("✓ Model loaded successfully")
     else:
         print(f"Model not found: {MODEL_PATH}")
         return