diff --git a/examples/jpeg_encode.py b/examples/jpeg_encode.py new file mode 100755 index 0000000..8e35145 --- /dev/null +++ b/examples/jpeg_encode.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +""" +Test script for JPEG encoding with nvImageCodec +Tests GPU-accelerated JPEG encoding from RTSP stream frames +""" + +import argparse +import sys +import time +import os +from pathlib import Path +from dotenv import load_dotenv +from services import StreamDecoderFactory + +# Load environment variables from .env file +load_dotenv() + + +def main(): + parser = argparse.ArgumentParser(description='Test JPEG encoding from RTSP stream') + parser.add_argument( + '--rtsp-url', + type=str, + default=None, + help='RTSP stream URL (defaults to CAMERA_URL_1 from .env)' + ) + parser.add_argument( + '--output-dir', + type=str, + default='./snapshots', + help='Output directory for JPEG files' + ) + parser.add_argument( + '--num-frames', + type=int, + default=10, + help='Number of frames to capture' + ) + parser.add_argument( + '--interval', + type=float, + default=1.0, + help='Interval between captures in seconds' + ) + parser.add_argument( + '--quality', + type=int, + default=95, + help='JPEG quality (0-100)' + ) + parser.add_argument( + '--gpu-id', + type=int, + default=0, + help='GPU device ID' + ) + + args = parser.parse_args() + + # Get RTSP URL from command line or environment + rtsp_url = args.rtsp_url + if not rtsp_url: + rtsp_url = os.getenv('CAMERA_URL_1') + if not rtsp_url: + print("Error: No RTSP URL provided") + print("Please either:") + print(" 1. Use --rtsp-url argument, or") + print(" 2. Add CAMERA_URL_1 to your .env file") + sys.exit(1) + + # Create output directory + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print("=" * 80) + print("RTSP Stream JPEG Encoding Test") + print("=" * 80) + print(f"RTSP URL: {rtsp_url}") + print(f"Output Directory: {output_dir}") + print(f"Number of Frames: {args.num_frames}") + print(f"Capture Interval: {args.interval}s") + print(f"JPEG Quality: {args.quality}") + print(f"GPU ID: {args.gpu_id}") + print("=" * 80) + print() + + try: + # Initialize factory and decoder + print("[1/3] Initializing StreamDecoderFactory...") + factory = StreamDecoderFactory(gpu_id=args.gpu_id) + print("✓ Factory initialized\n") + + print("[2/3] Creating and starting decoder...") + decoder = factory.create_decoder( + rtsp_url=rtsp_url, + buffer_size=30 + ) + decoder.start() + print("✓ Decoder started\n") + + # Wait for connection + print("[3/3] Waiting for stream to connect...") + max_wait = 10 + for i in range(max_wait): + if decoder.is_connected(): + print("✓ Stream connected\n") + break + time.sleep(1) + print(f" Waiting... {i+1}/{max_wait}s") + else: + print("✗ Failed to connect to stream") + sys.exit(1) + + # Capture frames + print(f"Capturing {args.num_frames} frames...") + print("-" * 80) + + captured = 0 + for i in range(args.num_frames): + # Get frame as JPEG + start_time = time.time() + jpeg_bytes = decoder.get_frame_as_jpeg(quality=args.quality) + encode_time = (time.time() - start_time) * 1000 # ms + + if jpeg_bytes: + # Save to file + filename = output_dir / f"frame_{i:04d}.jpg" + with open(filename, 'wb') as f: + f.write(jpeg_bytes) + + size_kb = len(jpeg_bytes) / 1024 + print(f"[{i+1}/{args.num_frames}] Saved {filename.name} " + f"({size_kb:.1f} KB, encoded in {encode_time:.2f}ms)") + captured += 1 + else: + print(f"[{i+1}/{args.num_frames}] Failed to get frame") + + # Wait before next capture (except for last frame) + if i < args.num_frames - 1: + time.sleep(args.interval) + + print("-" * 80) + + # Summary + print("\n" + "=" * 80) + print("Capture Complete") + print("=" * 80) + print(f"Successfully captured: {captured}/{args.num_frames} frames") + print(f"Output directory: {output_dir.absolute()}") + print("=" * 80) + + except KeyboardInterrupt: + print("\n\n✗ Interrupted by user") + sys.exit(1) + + except Exception as e: + print(f"\n\n✗ Error: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + + finally: + # Cleanup + if 'decoder' in locals(): + print("\nCleaning up...") + decoder.stop() + print("✓ Decoder stopped") + + print("\n✓ Test completed successfully") + sys.exit(0) + + +if __name__ == '__main__': + main() diff --git a/requirements.txt b/requirements.txt index 88ab797..e45d119 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ av cuda-python nvidia-nvimgcodec-cu12 # GPU-accelerated JPEG encoding/decoding with nvJPEG python-dotenv # Load environment variables from .env file +torch_tensorrt \ No newline at end of file diff --git a/services/__init__.py b/services/__init__.py index 497e777..7510f61 100644 --- a/services/__init__.py +++ b/services/__init__.py @@ -10,6 +10,8 @@ from .tracking_factory import TrackingFactory from .yolo import YOLOv8Utils, COCO_CLASSES from .model_controller import ModelController, BatchFrame, BufferState from .stream_connection_manager import StreamConnectionManager, StreamConnection, TrackingResult +from .pt_converter import PTConverter +from .modelstorage import IModelStorage, FileModelStorage __all__ = [ 'StreamDecoderFactory', @@ -32,4 +34,7 @@ __all__ = [ 'StreamConnectionManager', 'StreamConnection', 'TrackingResult', + 'PTConverter', + 'IModelStorage', + 'FileModelStorage', ] diff --git a/services/model_repository.py b/services/model_repository.py index ec50401..29c9da8 100644 --- a/services/model_repository.py +++ b/services/model_repository.py @@ -6,6 +6,9 @@ from queue import Queue import torch import tensorrt as trt from dataclasses import dataclass +import logging + +logger = logging.getLogger(__name__) @dataclass @@ -158,17 +161,19 @@ class TensorRTModelRepository: # Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts! """ - def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4): + def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True): """ Initialize the model repository. Args: gpu_id: GPU device ID to use default_num_contexts: Default number of execution contexts per unique engine + enable_pt_conversion: Enable automatic PyTorch to TensorRT conversion """ self.gpu_id = gpu_id self.device = torch.device(f'cuda:{gpu_id}') self.default_num_contexts = default_num_contexts + self.enable_pt_conversion = enable_pt_conversion # Model ID to engine mapping: model_id -> file_hash self._model_to_hash: Dict[str, str] = {} @@ -182,8 +187,22 @@ class TensorRTModelRepository: # TensorRT logger self.trt_logger = trt.Logger(trt.Logger.WARNING) + # PT converter (lazy initialization) + self._pt_converter = None + print(f"TensorRT Model Repository initialized on GPU {gpu_id}") print(f"Default context pool size: {default_num_contexts} contexts per unique model") + if enable_pt_conversion: + print(f"PyTorch to TensorRT conversion: enabled") + + @property + def pt_converter(self): + """Lazy initialization of PT converter""" + if self._pt_converter is None and self.enable_pt_conversion: + from .pt_converter import PTConverter + self._pt_converter = PTConverter(gpu_id=self.gpu_id) + logger.info("PT converter initialized") + return self._pt_converter @staticmethod def compute_file_hash(file_path: str) -> str: @@ -282,18 +301,26 @@ class TensorRTModelRepository: def load_model(self, model_id: str, file_path: str, num_contexts: Optional[int] = None, - force_reload: bool = False) -> ModelMetadata: + force_reload: bool = False, + pt_input_shapes: Optional[Dict[str, Tuple]] = None, + pt_precision: Optional[torch.dtype] = None, + **pt_conversion_kwargs) -> ModelMetadata: """ Load a TensorRT model with the given ID. + Supports both .trt and .pt files. PT files are automatically converted to TensorRT. + Deduplication: If a model with the same file hash is already loaded, the model_id is simply mapped to the existing SharedEngine (no new engine or contexts created). Args: model_id: User-defined identifier for this model (e.g., "camera_1") - file_path: Path to TensorRT engine file (.trt or .engine) + file_path: Path to TensorRT engine file (.trt, .engine) or PyTorch file (.pt, .pth) num_contexts: Number of execution contexts in pool (None = use default) force_reload: If True, reload even if model_id exists + pt_input_shapes: Required for .pt files - dict of input shapes (e.g., {"x": (1, 3, 224, 224)}) + pt_precision: Precision for PT conversion (torch.float16 or torch.float32) + **pt_conversion_kwargs: Additional arguments for torch_tensorrt.compile() Returns: ModelMetadata for the loaded model @@ -301,13 +328,37 @@ class TensorRTModelRepository: Raises: FileNotFoundError: If model file doesn't exist RuntimeError: If engine loading fails - ValueError: If model_id already exists and force_reload is False + ValueError: If model_id already exists and force_reload is False, or PT conversion requires input_shapes """ file_path = str(Path(file_path).resolve()) if not Path(file_path).exists(): raise FileNotFoundError(f"Model file not found: {file_path}") + # Check if file is PyTorch model + file_ext = Path(file_path).suffix.lower() + if file_ext in ['.pt', '.pth']: + if not self.enable_pt_conversion: + raise ValueError( + f"PT file provided but PT conversion is disabled. " + f"Enable with enable_pt_conversion=True or provide a .trt file." + ) + + logger.info(f"Detected PyTorch model file: {file_path}") + logger.info("Converting to TensorRT...") + + # Convert PT to TRT + trt_hash, trt_path = self.pt_converter.convert( + file_path, + input_shapes=pt_input_shapes, + precision=pt_precision, + **pt_conversion_kwargs + ) + + # Update file_path to use converted TRT file + file_path = trt_path + logger.info(f"Will load converted TensorRT model from: {file_path}") + if num_contexts is None: num_contexts = self.default_num_contexts diff --git a/services/modelstorage/__init__.py b/services/modelstorage/__init__.py new file mode 100644 index 0000000..f852a70 --- /dev/null +++ b/services/modelstorage/__init__.py @@ -0,0 +1,8 @@ +""" +Model storage module for managing TensorRT and PyTorch model files. +""" + +from .interface import IModelStorage +from .file_storage import FileModelStorage + +__all__ = ['IModelStorage', 'FileModelStorage'] diff --git a/services/modelstorage/file_storage.py b/services/modelstorage/file_storage.py new file mode 100644 index 0000000..39f8e32 --- /dev/null +++ b/services/modelstorage/file_storage.py @@ -0,0 +1,161 @@ +""" +FileModelStorage - Local filesystem implementation of IModelStorage. + +Stores model files in a local directory structure. +""" + +import os +from pathlib import Path +from typing import Optional +import logging + +from .interface import IModelStorage + +logger = logging.getLogger(__name__) + + +class FileModelStorage(IModelStorage): + """ + Local filesystem storage for model files. + + Stores files in a directory structure: + ./models/trtptcache/ + trt/ + .trt + .trt + pt/ + .pt + """ + + def __init__(self, base_path: str = "./models/trtptcache"): + """ + Initialize file storage. + + Args: + base_path: Base directory for storing files (default: ./models/trtptcache) + """ + self.base_path = Path(base_path).resolve() + self._ensure_directories() + + def _ensure_directories(self): + """Create base directory structure if it doesn't exist""" + self.base_path.mkdir(parents=True, exist_ok=True) + logger.info(f"Model storage initialized at: {self.base_path}") + + def _get_full_path(self, key: str) -> Path: + """ + Get full filesystem path for a key. + + Args: + key: Storage key (e.g., "trt/hash123.trt") + + Returns: + Full filesystem path + """ + return self.base_path / key + + def write(self, key: str, data: bytes) -> None: + """ + Write data to filesystem. + + Args: + key: Storage key (e.g., "trt/hash123.trt") + data: Binary data to write + + Raises: + IOError: If write operation fails + """ + file_path = self._get_full_path(key) + + # Ensure parent directory exists + file_path.parent.mkdir(parents=True, exist_ok=True) + + try: + with open(file_path, 'wb') as f: + f.write(data) + logger.debug(f"Wrote {len(data)} bytes to {file_path}") + except Exception as e: + raise IOError(f"Failed to write to {file_path}: {e}") + + def read(self, key: str) -> Optional[bytes]: + """ + Read data from filesystem. + + Args: + key: Storage key + + Returns: + Binary data if found, None otherwise + + Raises: + IOError: If read operation fails + """ + file_path = self._get_full_path(key) + + if not file_path.exists(): + return None + + try: + with open(file_path, 'rb') as f: + data = f.read() + logger.debug(f"Read {len(data)} bytes from {file_path}") + return data + except Exception as e: + raise IOError(f"Failed to read from {file_path}: {e}") + + def exists(self, key: str) -> bool: + """ + Check if file exists. + + Args: + key: Storage key + + Returns: + True if file exists, False otherwise + """ + return self._get_full_path(key).exists() + + def delete(self, key: str) -> bool: + """ + Delete file from filesystem. + + Args: + key: Storage key + + Returns: + True if deleted successfully, False if file didn't exist + """ + file_path = self._get_full_path(key) + + if not file_path.exists(): + return False + + try: + file_path.unlink() + logger.debug(f"Deleted {file_path}") + return True + except Exception as e: + logger.error(f"Failed to delete {file_path}: {e}") + return False + + def get_local_path(self, key: str) -> Optional[str]: + """ + Get local filesystem path for a key. + + Args: + key: Storage key + + Returns: + Local path if file exists, None otherwise + """ + file_path = self._get_full_path(key) + return str(file_path) if file_path.exists() else None + + def get_storage_path(self) -> str: + """ + Get the base storage path. + + Returns: + Base path where files are stored + """ + return str(self.base_path) diff --git a/services/modelstorage/interface.py b/services/modelstorage/interface.py new file mode 100644 index 0000000..07c6855 --- /dev/null +++ b/services/modelstorage/interface.py @@ -0,0 +1,91 @@ +""" +IModelStorage - Interface for model file storage. + +Defines the contract for storing and retrieving model files (TensorRT, PyTorch, etc.) +""" + +from abc import ABC, abstractmethod +from typing import Optional, BinaryIO +from pathlib import Path +import io + + +class IModelStorage(ABC): + """ + Interface for model file storage. + + This abstraction allows swapping storage backends (local filesystem, S3, etc.) + without changing the model conversion and loading logic. + """ + + @abstractmethod + def write(self, key: str, data: bytes) -> None: + """ + Write data to storage with the given key. + + Args: + key: Storage key (e.g., "trt/hash123.trt" or "pt/hash456.pt") + data: Binary data to write + + Raises: + IOError: If write operation fails + """ + pass + + @abstractmethod + def read(self, key: str) -> Optional[bytes]: + """ + Read data from storage by key. + + Args: + key: Storage key + + Returns: + Binary data if found, None otherwise + + Raises: + IOError: If read operation fails + """ + pass + + @abstractmethod + def exists(self, key: str) -> bool: + """ + Check if a key exists in storage. + + Args: + key: Storage key + + Returns: + True if key exists, False otherwise + """ + pass + + @abstractmethod + def delete(self, key: str) -> bool: + """ + Delete data from storage. + + Args: + key: Storage key + + Returns: + True if deleted successfully, False if key didn't exist + """ + pass + + @abstractmethod + def get_local_path(self, key: str) -> Optional[str]: + """ + Get a local filesystem path for the key. + + For local storage, this returns the direct path. + For remote storage (S3), this may download to a temp location or return None. + + Args: + key: Storage key + + Returns: + Local path if available/downloaded, None if not supported + """ + pass diff --git a/services/pt_converter.py b/services/pt_converter.py new file mode 100644 index 0000000..5a349ab --- /dev/null +++ b/services/pt_converter.py @@ -0,0 +1,485 @@ +""" +PTConverter - Convert PyTorch models to TensorRT using torch_tensorrt. + +This service handles conversion of .pt files to TensorRT format with caching +to avoid redundant conversions. It maintains a mapping database between +PT file hashes and their converted TensorRT file hashes. +""" + +import hashlib +import json +import logging +from pathlib import Path +from typing import Optional, Dict, Tuple +import torch +import torch_tensorrt + +from .modelstorage import IModelStorage, FileModelStorage + +logger = logging.getLogger(__name__) + + +class PTConverter: + """ + PyTorch to TensorRT converter with intelligent caching. + + Features: + - Hash-based deduplication: Same PT file only converted once + - Persistent mapping database: PT hash -> TRT hash mapping + - Pluggable storage backend: IModelStorage interface + - Automatic cache management + + Architecture: + 1. Compute hash of input .pt file + 2. Check mapping database for existing conversion + 3. If found, return cached TRT file hash and path + 4. If not found, perform conversion and store mapping + """ + + def __init__( + self, + storage: Optional[IModelStorage] = None, + gpu_id: int = 0, + default_precision: torch.dtype = torch.float16 + ): + """ + Initialize PT converter. + + Args: + storage: Storage backend (defaults to FileModelStorage) + gpu_id: GPU device ID for conversion + default_precision: Default precision for TensorRT conversion (fp16, fp32) + """ + self.storage = storage or FileModelStorage() + self.gpu_id = gpu_id + self.device = torch.device(f'cuda:{gpu_id}') + self.default_precision = default_precision + + # Mapping database: pt_hash -> {"trt_hash": str, "metadata": {...}} + self.mapping_db_key = "pt_to_trt_mapping.json" + self.mapping_db: Dict[str, Dict] = self._load_mapping_db() + + logger.info(f"PTConverter initialized on GPU {gpu_id}") + logger.info(f"Storage backend: {self.storage.__class__.__name__}") + logger.info(f"Storage path: {self.storage.get_storage_path()}") + logger.info(f"Loaded {len(self.mapping_db)} cached conversions") + + def _load_mapping_db(self) -> Dict[str, Dict]: + """ + Load mapping database from storage. + + Returns: + Mapping dictionary (pt_hash -> metadata) + """ + try: + data = self.storage.read(self.mapping_db_key) + if data: + db = json.loads(data.decode('utf-8')) + logger.debug(f"Loaded mapping database with {len(db)} entries") + return db + else: + logger.debug("No existing mapping database found, starting fresh") + return {} + except Exception as e: + logger.warning(f"Failed to load mapping database: {e}. Starting fresh.") + return {} + + def _save_mapping_db(self): + """Save mapping database to storage""" + try: + data = json.dumps(self.mapping_db, indent=2).encode('utf-8') + self.storage.write(self.mapping_db_key, data) + logger.debug(f"Saved mapping database with {len(self.mapping_db)} entries") + except Exception as e: + logger.error(f"Failed to save mapping database: {e}") + + @staticmethod + def compute_file_hash(file_path: str) -> str: + """ + Compute SHA256 hash of a file. + + Args: + file_path: Path to file + + Returns: + Hexadecimal hash string + """ + sha256_hash = hashlib.sha256() + with open(file_path, "rb") as f: + for byte_block in iter(lambda: f.read(65536), b""): + sha256_hash.update(byte_block) + return sha256_hash.hexdigest() + + def get_cached_conversion(self, pt_hash: str) -> Optional[Tuple[str, str]]: + """ + Check if PT file has already been converted. + + Args: + pt_hash: SHA256 hash of the PT file + + Returns: + Tuple of (trt_hash, trt_file_path) if cached, None otherwise + """ + if pt_hash not in self.mapping_db: + return None + + mapping = self.mapping_db[pt_hash] + trt_hash = mapping["trt_hash"] + trt_key = f"trt/{trt_hash}.trt" + + # Verify TRT file still exists in storage + if not self.storage.exists(trt_key): + logger.warning( + f"Mapping exists for PT hash {pt_hash[:16]}... but TRT file missing. " + f"Will reconvert." + ) + # Remove stale mapping + del self.mapping_db[pt_hash] + self._save_mapping_db() + return None + + # Get local path + trt_path = self.storage.get_local_path(trt_key) + if trt_path is None: + logger.error(f"Could not get local path for TRT file {trt_key}") + return None + + logger.info( + f"Found cached conversion for PT hash {pt_hash[:16]}... -> " + f"TRT hash {trt_hash[:16]}..." + ) + return (trt_hash, trt_path) + + def convert( + self, + pt_file_path: str, + input_shapes: Optional[Dict[str, Tuple]] = None, + precision: Optional[torch.dtype] = None, + **conversion_kwargs + ) -> Tuple[str, str]: + """ + Convert PyTorch model to TensorRT. + + If this PT file has been converted before (same hash), returns cached result. + Otherwise, performs conversion and caches the result. + + Args: + pt_file_path: Path to .pt file + input_shapes: Dict of input names to shapes (e.g., {"x": (1, 3, 224, 224)}) + precision: Target precision (fp16, fp32) - defaults to self.default_precision + **conversion_kwargs: Additional arguments for torch_tensorrt.compile() + + Returns: + Tuple of (trt_hash, trt_file_path) + + Raises: + FileNotFoundError: If PT file doesn't exist + RuntimeError: If conversion fails + """ + pt_path = Path(pt_file_path).resolve() + if not pt_path.exists(): + raise FileNotFoundError(f"PT file not found: {pt_file_path}") + + # Compute PT file hash + logger.info(f"Computing hash for {pt_path}...") + pt_hash = self.compute_file_hash(str(pt_path)) + logger.info(f"PT file hash: {pt_hash[:16]}...") + + # Check cache + cached = self.get_cached_conversion(pt_hash) + if cached: + return cached + + # Perform conversion + logger.info(f"Converting {pt_path.name} to TensorRT...") + trt_hash, trt_path = self._perform_conversion( + str(pt_path), + pt_hash, + input_shapes, + precision or self.default_precision, + **conversion_kwargs + ) + + # Store mapping + self.mapping_db[pt_hash] = { + "trt_hash": trt_hash, + "pt_file": str(pt_path), + "input_shapes": str(input_shapes), + "precision": str(precision or self.default_precision), + } + self._save_mapping_db() + + logger.info(f"Conversion complete: PT {pt_hash[:16]}... -> TRT {trt_hash[:16]}...") + return (trt_hash, trt_path) + + def _is_ultralytics_model(self, model) -> bool: + """ + Check if model is from ultralytics (YOLO). + + Args: + model: PyTorch model + + Returns: + True if ultralytics model, False otherwise + """ + model_class_name = model.__class__.__name__ + model_module = model.__class__.__module__ + + # Check if it's an ultralytics model + is_ultralytics = ( + 'ultralytics' in model_module or + model_class_name in ['DetectionModel', 'SegmentationModel', 'PoseModel', 'ClassificationModel'] + ) + + return is_ultralytics + + def _convert_ultralytics_model( + self, + pt_path: str, + pt_hash: str, + input_shapes: Optional[Dict[str, Tuple]], + precision: torch.dtype, + ) -> Tuple[str, str]: + """ + Convert ultralytics YOLO model using ONNX → TensorRT pipeline. + Uses the same approach as scripts/convert_pt_to_tensorrt.py + + Args: + pt_path: Path to PT file + pt_hash: PT file hash + input_shapes: Input tensor shapes + precision: Target precision + + Returns: + Tuple of (trt_hash, trt_file_path) + """ + import tensorrt as trt + import tempfile + import os + import shutil + + logger.info("Detected ultralytics YOLO model, using ONNX → TensorRT pipeline...") + + # Load ultralytics model + try: + from ultralytics import YOLO + model = YOLO(pt_path) + except ImportError: + raise ImportError("ultralytics package not found. Install with: pip install ultralytics") + + # Determine input shape + if not input_shapes: + raise ValueError("input_shapes required for ultralytics conversion") + + input_key = 'images' if 'images' in input_shapes else list(input_shapes.keys())[0] + input_shape = input_shapes[input_key] + + # Export to ONNX first + logger.info(f"Exporting to ONNX (input shape: {input_shape})...") + with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmp_onnx: + onnx_path = tmp_onnx.name + + try: + # Use ultralytics export to ONNX + model.export(format='onnx', imgsz=input_shape[2], batch=input_shape[0]) + # Ultralytics saves as model_name.onnx in same directory + pt_dir = os.path.dirname(pt_path) + pt_name = os.path.splitext(os.path.basename(pt_path))[0] + onnx_export_path = os.path.join(pt_dir, f"{pt_name}.onnx") + + # Move to our temp location (use shutil.move for cross-device support) + if os.path.exists(onnx_export_path): + shutil.move(onnx_export_path, onnx_path) + else: + raise RuntimeError(f"ONNX export failed, file not found: {onnx_export_path}") + + logger.info(f"ONNX export complete: {onnx_path}") + + # Build TensorRT engine from ONNX + logger.info("Building TensorRT engine from ONNX...") + trt_logger = trt.Logger(trt.Logger.WARNING) + builder = trt.Builder(trt_logger) + network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) + parser = trt.OnnxParser(network, trt_logger) + + # Parse ONNX + with open(onnx_path, 'rb') as f: + if not parser.parse(f.read()): + errors = [parser.get_error(i) for i in range(parser.num_errors)] + raise RuntimeError(f"Failed to parse ONNX: {errors}") + + # Configure builder + config = builder.create_builder_config() + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30) # 4GB + + # Set precision + if precision == torch.float16: + if builder.platform_has_fast_fp16: + config.set_flag(trt.BuilderFlag.FP16) + logger.info("FP16 mode enabled") + + # Build engine + logger.info("Building TensorRT engine (this may take a few minutes)...") + serialized_engine = builder.build_serialized_network(network, config) + + if serialized_engine is None: + raise RuntimeError("Failed to build TensorRT engine") + + # Convert IHostMemory to bytes + engine_bytes = bytes(serialized_engine) + + # Save to storage + trt_hash = hashlib.sha256(engine_bytes).hexdigest() + trt_key = f"trt/{trt_hash}.trt" + self.storage.write(trt_key, engine_bytes) + + trt_path = self.storage.get_local_path(trt_key) + if trt_path is None: + raise RuntimeError("Failed to get local path for TRT file") + + logger.info(f"TensorRT engine built successfully: {trt_path}") + return (trt_hash, trt_path) + + finally: + # Cleanup ONNX file + if os.path.exists(onnx_path): + os.unlink(onnx_path) + + def _perform_conversion( + self, + pt_path: str, + pt_hash: str, + input_shapes: Optional[Dict[str, Tuple]], + precision: torch.dtype, + **conversion_kwargs + ) -> Tuple[str, str]: + """ + Perform the actual PT to TRT conversion. + + Args: + pt_path: Path to PT file + pt_hash: PT file hash + input_shapes: Input tensor shapes + precision: Target precision + **conversion_kwargs: Additional torch_tensorrt arguments + + Returns: + Tuple of (trt_hash, trt_file_path) + """ + try: + # Load PyTorch model to check type + logger.debug(f"Loading PyTorch model from {pt_path}...") + # Use weights_only=False for models with custom classes (like ultralytics) + # This is safe for trusted local models + loaded = torch.load(pt_path, map_location='cpu', weights_only=False) + + # If model is wrapped in a dict, extract the model + if isinstance(loaded, dict): + if 'model' in loaded: + model = loaded['model'] + elif 'state_dict' in loaded: + raise ValueError( + "PT file contains state_dict only. " + "Please provide a full model or use a different loading method." + ) + else: + model = loaded + + # Check if this is an ultralytics model + if self._is_ultralytics_model(model): + logger.info("Detected ultralytics model, using ultralytics export API") + return self._convert_ultralytics_model(pt_path, pt_hash, input_shapes, precision) + + # For non-ultralytics models, use torch_tensorrt + logger.info("Using torch_tensorrt for conversion") + model.eval() + + # Convert model to target precision to avoid mixed precision issues + if precision == torch.float16: + model = model.half() + elif precision == torch.float32: + model = model.float() + + # Move to GPU + model = model.to(self.device) + + # Prepare inputs for tracing + if input_shapes is None: + raise ValueError( + "input_shapes must be provided for TensorRT conversion. " + "Example: {'x': (1, 3, 224, 224)}" + ) + + # Create sample inputs with matching precision + inputs = [] + for name, shape in input_shapes.items(): + sample_input = torch.randn(shape, device=self.device, dtype=precision) + inputs.append(sample_input) + + # Configure torch_tensorrt + enabled_precisions = {precision} + if precision == torch.float16: + enabled_precisions.add(torch.float32) # Fallback for unsupported ops + + # Compile to TensorRT + logger.info(f"Compiling to TensorRT (precision: {precision})...") + trt_model = torch_tensorrt.compile( + model, + inputs=inputs, + enabled_precisions=enabled_precisions, + **conversion_kwargs + ) + + # Save TRT model to temporary location + import tempfile + with tempfile.NamedTemporaryFile(delete=False, suffix='.trt') as tmp_file: + tmp_path = tmp_file.name + + torch.jit.save(trt_model, tmp_path) + logger.debug(f"Saved TRT model to temporary file: {tmp_path}") + + # Compute TRT file hash + trt_hash = self.compute_file_hash(tmp_path) + logger.info(f"TRT file hash: {trt_hash[:16]}...") + + # Store in storage backend + trt_key = f"trt/{trt_hash}.trt" + with open(tmp_path, 'rb') as f: + trt_data = f.read() + self.storage.write(trt_key, trt_data) + + # Get local path + trt_path = self.storage.get_local_path(trt_key) + if trt_path is None: + raise RuntimeError("Failed to get local path for converted TRT file") + + # Cleanup temp file + Path(tmp_path).unlink() + + logger.info(f"TRT model stored successfully at {trt_path}") + return (trt_hash, trt_path) + + except Exception as e: + logger.error(f"Conversion failed: {e}", exc_info=True) + raise RuntimeError(f"Failed to convert PT to TensorRT: {e}") + + def clear_cache(self): + """Clear all cached conversions and mapping database""" + logger.warning("Clearing all cached conversions...") + self.mapping_db.clear() + self._save_mapping_db() + logger.info("Cache cleared") + + def get_stats(self) -> Dict: + """ + Get conversion statistics. + + Returns: + Dictionary with cache stats + """ + return { + "total_cached_conversions": len(self.mapping_db), + "storage_path": self.storage.get_storage_path(), + "gpu_id": self.gpu_id, + "default_precision": str(self.default_precision), + } diff --git a/test_tracking_realtime.py b/test_tracking_realtime.py index e5bc57c..754f326 100644 --- a/test_tracking_realtime.py +++ b/test_tracking_realtime.py @@ -167,9 +167,11 @@ def main(): """ Main function for real-time tracking visualization. """ + import torch + # Configuration GPU_ID = 0 - MODEL_PATH = "models/yolov8n.trt" + MODEL_PATH = "models/yolov8n.pt" # Changed to PT file RTSP_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test') BUFFER_SIZE = 30 WINDOW_NAME = "Real-time Object Tracking" @@ -178,18 +180,24 @@ def main(): print("Real-time GPU-Accelerated Object Tracking") print("=" * 80) - # Step 1: Create model repository + # Step 1: Create model repository with PT conversion enabled print("\n[1/4] Initializing TensorRT Model Repository...") - model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=4) + model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=4, enable_pt_conversion=True) - # Load detection model + # Load detection model (will auto-convert PT to TRT) model_id = "yolov8_detector" if os.path.exists(MODEL_PATH): try: + print(f"Loading model from {MODEL_PATH}...") + print("Note: First load will convert PT to TensorRT (may take 3-5 minutes)") + print("Subsequent loads will use cached TensorRT engine") + metadata = model_repo.load_model( model_id=model_id, file_path=MODEL_PATH, - num_contexts=4 + num_contexts=4, + pt_input_shapes={"images": (1, 3, 640, 640)}, # Required for PT conversion + pt_precision=torch.float16 # Use FP16 for better performance ) print(f"✓ Model loaded successfully") print(f" Input shape: {metadata.input_shapes}") @@ -197,10 +205,12 @@ def main(): except Exception as e: print(f"✗ Failed to load model: {e}") print(f" Please ensure {MODEL_PATH} exists") + import traceback + traceback.print_exc() return else: print(f"✗ Model file not found: {MODEL_PATH}") - print(f" Please provide a valid TensorRT model file") + print(f" Please provide a valid PyTorch (.pt) or TensorRT (.trt) model file") return # Step 2: Create tracking controller @@ -370,7 +380,7 @@ def main_multi_window(): with separate OpenCV windows for each stream. """ GPU_ID = 0 - MODEL_PATH = "models/yolov8n.trt" + MODEL_PATH = "models/yolov8n.pt" # Load camera URLs from environment camera_urls = [] @@ -389,11 +399,23 @@ def main_multi_window(): print(f"Starting multi-window tracking with {len(camera_urls)} cameras") - # Create shared model repository - model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=8) + # Create shared model repository with PT conversion enabled + import torch + model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=8, enable_pt_conversion=True) if os.path.exists(MODEL_PATH): - model_repo.load_model("detector", MODEL_PATH, num_contexts=8) + print(f"Loading model from {MODEL_PATH}...") + print("Note: First load will convert PT to TensorRT (may take 3-5 minutes)") + print("Subsequent loads will use cached TensorRT engine") + + model_repo.load_model( + model_id="detector", + file_path=MODEL_PATH, + num_contexts=8, + pt_input_shapes={"images": (1, 3, 640, 640)}, # Required for PT conversion + pt_precision=torch.float16 # Use FP16 for better performance + ) + print("✓ Model loaded successfully") else: print(f"Model not found: {MODEL_PATH}") return