feat: inference subsystem and optimization to decoder
This commit is contained in:
commit
3c83a57e44
19 changed files with 3897 additions and 0 deletions
380
services/README_MODEL_REPOSITORY.md
Normal file
380
services/README_MODEL_REPOSITORY.md
Normal file
|
|
@ -0,0 +1,380 @@
|
|||
# TensorRT Model Repository
|
||||
|
||||
Efficient TensorRT model management with context pooling, deduplication, and GPU-to-GPU inference.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Key Features
|
||||
|
||||
1. **Model Deduplication by File Hash**
|
||||
- Multiple model IDs can point to the same model file
|
||||
- Only one engine loaded in VRAM per unique file
|
||||
- Example: 100 cameras with same model = 1 engine (not 100!)
|
||||
|
||||
2. **Context Pooling for Load Balancing**
|
||||
- Each unique engine has N execution contexts (configurable)
|
||||
- Contexts borrowed/returned via mutex-based queue
|
||||
- Enables concurrent inference without context-per-model overhead
|
||||
- Example: 100 cameras sharing 4 contexts efficiently
|
||||
|
||||
3. **GPU-to-GPU Inference**
|
||||
- All inputs/outputs stay in VRAM (zero CPU transfers)
|
||||
- Integrates seamlessly with StreamDecoder (frames already on GPU)
|
||||
- Maximum performance for video inference pipelines
|
||||
|
||||
4. **Thread-Safe Concurrent Inference**
|
||||
- Mutex-based context acquisition (TensorRT best practice)
|
||||
- No shared IExecutionContext across threads (safe)
|
||||
- Multiple threads can infer concurrently (limited by pool size)
|
||||
|
||||
## Design Rationale
|
||||
|
||||
### Why Context Pooling?
|
||||
|
||||
**Without pooling** (naive approach):
|
||||
```
|
||||
100 cameras → 100 model IDs → 100 execution contexts
|
||||
```
|
||||
- Problem: Each context consumes VRAM (layers, workspace, etc.)
|
||||
- Problem: Context creation overhead per camera
|
||||
- Problem: Doesn't scale to hundreds of cameras
|
||||
|
||||
**With pooling** (our approach):
|
||||
```
|
||||
100 cameras → 100 model IDs → 1 shared engine → 4 contexts (pool)
|
||||
```
|
||||
- Solution: Contexts shared across all cameras using same model
|
||||
- Solution: Borrow/return mechanism with mutex queue
|
||||
- Solution: Scales to any number of cameras with fixed context count
|
||||
|
||||
### Memory Savings Example
|
||||
|
||||
YOLOv8n model (~6MB engine file):
|
||||
|
||||
| Approach | Model IDs | Engines | Contexts | Approx VRAM |
|
||||
|----------|-----------|---------|----------|-------------|
|
||||
| Naive | 100 | 100 | 100 | ~1.5 GB |
|
||||
| **Ours (pooled)** | **100** | **1** | **4** | **~30 MB** |
|
||||
|
||||
**50x memory savings!**
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from services.model_repository import TensorRTModelRepository
|
||||
|
||||
# Initialize repository
|
||||
repo = TensorRTModelRepository(
|
||||
gpu_id=0,
|
||||
default_num_contexts=4 # 4 contexts per unique engine
|
||||
)
|
||||
|
||||
# Load model for camera 1
|
||||
repo.load_model(
|
||||
model_id="camera_1",
|
||||
file_path="models/yolov8n.trt"
|
||||
)
|
||||
|
||||
# Load same model for camera 2 (deduplication happens automatically)
|
||||
repo.load_model(
|
||||
model_id="camera_2",
|
||||
file_path="models/yolov8n.trt" # Same file → shares engine and contexts!
|
||||
)
|
||||
|
||||
# Run inference (GPU-to-GPU)
|
||||
import torch
|
||||
input_tensor = torch.rand(1, 3, 640, 640, device='cuda:0')
|
||||
|
||||
outputs = repo.infer(
|
||||
model_id="camera_1",
|
||||
inputs={"images": input_tensor},
|
||||
synchronize=True,
|
||||
timeout=5.0 # Wait up to 5s for available context
|
||||
)
|
||||
|
||||
# Outputs stay on GPU
|
||||
for name, tensor in outputs.items():
|
||||
print(f"{name}: {tensor.shape} on {tensor.device}")
|
||||
```
|
||||
|
||||
### Multi-Camera Scenario
|
||||
|
||||
```python
|
||||
# Setup multiple cameras
|
||||
cameras = [f"camera_{i}" for i in range(100)]
|
||||
|
||||
# Load same model for all cameras
|
||||
for camera_id in cameras:
|
||||
repo.load_model(
|
||||
model_id=camera_id,
|
||||
file_path="models/yolov8n.trt" # Same file for all
|
||||
)
|
||||
|
||||
# Check efficiency
|
||||
stats = repo.get_stats()
|
||||
print(f"Model IDs: {stats['total_model_ids']}") # 100
|
||||
print(f"Unique engines: {stats['unique_engines']}") # 1
|
||||
print(f"Total contexts: {stats['total_contexts']}") # 4
|
||||
```
|
||||
|
||||
### Integration with RTSP Decoder
|
||||
|
||||
```python
|
||||
from services.stream_decoder import StreamDecoderFactory
|
||||
from services.model_repository import TensorRTModelRepository
|
||||
|
||||
# Setup
|
||||
decoder_factory = StreamDecoderFactory(gpu_id=0)
|
||||
model_repo = TensorRTModelRepository(gpu_id=0)
|
||||
|
||||
# Create decoder for camera
|
||||
decoder = decoder_factory.create_decoder("rtsp://camera.ip/stream")
|
||||
decoder.start()
|
||||
|
||||
# Load inference model
|
||||
model_repo.load_model("camera_main", "models/yolov8n.trt")
|
||||
|
||||
# Process frames (everything on GPU)
|
||||
frame_gpu = decoder.get_latest_frame(rgb=True) # torch.Tensor on CUDA
|
||||
|
||||
# Preprocess (stays on GPU)
|
||||
frame_gpu = frame_gpu.float() / 255.0
|
||||
frame_gpu = frame_gpu.unsqueeze(0) # Add batch dim
|
||||
|
||||
# Inference (GPU-to-GPU, zero copy)
|
||||
outputs = model_repo.infer(
|
||||
model_id="camera_main",
|
||||
inputs={"images": frame_gpu}
|
||||
)
|
||||
|
||||
# Post-process outputs (can stay on GPU)
|
||||
# ... NMS, bounding boxes, etc.
|
||||
```
|
||||
|
||||
### Concurrent Inference
|
||||
|
||||
```python
|
||||
import threading
|
||||
|
||||
def process_camera(camera_id: str, model_id: str):
|
||||
# Get frame from decoder (on GPU)
|
||||
frame = decoder.get_latest_frame(rgb=True)
|
||||
|
||||
# Inference automatically borrows/returns context from pool
|
||||
outputs = repo.infer(
|
||||
model_id=model_id,
|
||||
inputs={"images": frame},
|
||||
timeout=10.0 # Wait for available context
|
||||
)
|
||||
|
||||
# Process outputs...
|
||||
|
||||
# Multiple threads can infer concurrently
|
||||
threads = []
|
||||
for i in range(10): # 10 threads
|
||||
t = threading.Thread(
|
||||
target=process_camera,
|
||||
args=(f"camera_{i}", f"camera_{i}")
|
||||
)
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# With 4 contexts: up to 4 inferences run in parallel
|
||||
# Others wait in queue, contexts auto-balanced
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### TensorRTModelRepository
|
||||
|
||||
#### `__init__(gpu_id=0, default_num_contexts=4)`
|
||||
Initialize the repository.
|
||||
|
||||
**Args:**
|
||||
- `gpu_id`: GPU device ID
|
||||
- `default_num_contexts`: Default context pool size per engine
|
||||
|
||||
#### `load_model(model_id, file_path, num_contexts=None, force_reload=False)`
|
||||
Load a TensorRT model.
|
||||
|
||||
**Args:**
|
||||
- `model_id`: Unique identifier (e.g., "camera_1")
|
||||
- `file_path`: Path to .trt/.engine file
|
||||
- `num_contexts`: Context pool size (None = use default)
|
||||
- `force_reload`: Reload if model_id exists
|
||||
|
||||
**Returns:** `ModelMetadata`
|
||||
|
||||
**Deduplication:** If file hash matches existing model, reuses engine + contexts.
|
||||
|
||||
#### `infer(model_id, inputs, synchronize=True, timeout=5.0)`
|
||||
Run inference.
|
||||
|
||||
**Args:**
|
||||
- `model_id`: Model identifier
|
||||
- `inputs`: Dict mapping input names to CUDA tensors
|
||||
- `synchronize`: Wait for completion
|
||||
- `timeout`: Max wait time for context (seconds)
|
||||
|
||||
**Returns:** Dict mapping output names to CUDA tensors
|
||||
|
||||
**Thread-safe:** Borrows context from pool, returns after inference.
|
||||
|
||||
#### `unload_model(model_id)`
|
||||
Unload a model.
|
||||
|
||||
If last reference to engine, fully unloads from VRAM.
|
||||
|
||||
#### `get_metadata(model_id)`
|
||||
Get model metadata.
|
||||
|
||||
**Returns:** `ModelMetadata` or `None`
|
||||
|
||||
#### `get_model_info(model_id)`
|
||||
Get detailed model information.
|
||||
|
||||
**Returns:** Dict with engine references, context pool size, shared model IDs, etc.
|
||||
|
||||
#### `get_stats()`
|
||||
Get repository statistics.
|
||||
|
||||
**Returns:** Dict with total models, unique engines, contexts, memory efficiency.
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Set Appropriate Context Pool Size
|
||||
|
||||
```python
|
||||
# For 10 cameras with same model, 4 contexts is usually enough
|
||||
repo = TensorRTModelRepository(default_num_contexts=4)
|
||||
|
||||
# For high concurrency, increase pool size
|
||||
repo = TensorRTModelRepository(default_num_contexts=8)
|
||||
```
|
||||
|
||||
**Rule of thumb:** Start with 4 contexts, increase if you see timeout errors.
|
||||
|
||||
### 2. Always Use GPU Tensors
|
||||
|
||||
```python
|
||||
# ✅ Good: Input on GPU
|
||||
input_gpu = torch.rand(1, 3, 640, 640, device='cuda:0')
|
||||
outputs = repo.infer(model_id, {"images": input_gpu})
|
||||
|
||||
# ❌ Bad: Input on CPU (will cause error)
|
||||
input_cpu = torch.rand(1, 3, 640, 640)
|
||||
outputs = repo.infer(model_id, {"images": input_cpu}) # ValueError!
|
||||
```
|
||||
|
||||
### 3. Handle Timeout Gracefully
|
||||
|
||||
```python
|
||||
try:
|
||||
outputs = repo.infer(
|
||||
model_id="camera_1",
|
||||
inputs=inputs,
|
||||
timeout=5.0
|
||||
)
|
||||
except RuntimeError as e:
|
||||
# All contexts busy, increase pool size or add backpressure
|
||||
print(f"Inference timeout: {e}")
|
||||
```
|
||||
|
||||
### 4. Use Same File for Deduplication
|
||||
|
||||
```python
|
||||
# ✅ Good: Same file path → deduplication
|
||||
repo.load_model("cam1", "/models/yolo.trt")
|
||||
repo.load_model("cam2", "/models/yolo.trt") # Shares engine!
|
||||
|
||||
# ❌ Bad: Different paths (even if same content) → no deduplication
|
||||
repo.load_model("cam1", "/models/yolo.trt")
|
||||
repo.load_model("cam2", "/models/yolo_copy.trt") # Separate engine
|
||||
```
|
||||
|
||||
## TensorRT Best Practices Implemented
|
||||
|
||||
Based on NVIDIA documentation and web search findings:
|
||||
|
||||
1. **Separate IExecutionContext per concurrent stream** ✅
|
||||
- Each context has its own CUDA stream
|
||||
- Contexts never shared across threads simultaneously
|
||||
|
||||
2. **Mutex-based context management** ✅
|
||||
- Queue-based borrowing with locks
|
||||
- Thread-safe acquire/release pattern
|
||||
|
||||
3. **GPU memory reuse** ✅
|
||||
- Engines shared by file hash
|
||||
- Contexts pooled and reused
|
||||
|
||||
4. **Zero-copy operations** ✅
|
||||
- All data stays in VRAM
|
||||
- DLPack integration with PyTorch
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "No execution context available within timeout"
|
||||
|
||||
**Cause:** All contexts busy with concurrent inferences.
|
||||
|
||||
**Solutions:**
|
||||
1. Increase context pool size:
|
||||
```python
|
||||
repo.load_model(model_id, file_path, num_contexts=8)
|
||||
```
|
||||
2. Increase timeout:
|
||||
```python
|
||||
outputs = repo.infer(model_id, inputs, timeout=30.0)
|
||||
```
|
||||
3. Add backpressure/throttling to limit concurrent requests
|
||||
|
||||
### Out of Memory (OOM)
|
||||
|
||||
**Cause:** Too many unique engines or large context pools.
|
||||
|
||||
**Solutions:**
|
||||
1. Ensure deduplication working (same file paths)
|
||||
2. Reduce context pool sizes
|
||||
3. Use smaller models or quantization (INT8/FP16)
|
||||
|
||||
### Import Error: "tensorrt could not be resolved"
|
||||
|
||||
**Solution:** Install TensorRT:
|
||||
```bash
|
||||
pip install tensorrt
|
||||
# Or use NVIDIA's wheel for your CUDA version
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Batch Processing:** Process multiple frames before synchronizing
|
||||
```python
|
||||
outputs = repo.infer(model_id, inputs, synchronize=False)
|
||||
# ... more inferences ...
|
||||
torch.cuda.synchronize() # Sync once at end
|
||||
```
|
||||
|
||||
2. **Async Inference:** Don't synchronize if not needed immediately
|
||||
```python
|
||||
outputs = repo.infer(model_id, inputs, synchronize=False)
|
||||
# GPU continues working, CPU continues
|
||||
# Synchronize later when you need results
|
||||
```
|
||||
|
||||
3. **Monitor Context Utilization:**
|
||||
```python
|
||||
stats = repo.get_stats()
|
||||
print(f"Contexts: {stats['total_contexts']}")
|
||||
|
||||
# If timeouts occur frequently, increase pool size
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
Part of python-rtsp-worker project.
|
||||
14
services/__init__.py
Normal file
14
services/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
"""
|
||||
Services package for RTSP stream processing with GPU acceleration.
|
||||
"""
|
||||
|
||||
from .stream_decoder import StreamDecoderFactory, StreamDecoder, ConnectionStatus
|
||||
from .jpeg_encoder import JPEGEncoderFactory, encode_frame_to_jpeg
|
||||
|
||||
__all__ = [
|
||||
'StreamDecoderFactory',
|
||||
'StreamDecoder',
|
||||
'ConnectionStatus',
|
||||
'JPEGEncoderFactory',
|
||||
'encode_frame_to_jpeg',
|
||||
]
|
||||
91
services/jpeg_encoder.py
Normal file
91
services/jpeg_encoder.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
"""
|
||||
JPEG Encoder wrapper for GPU-accelerated JPEG encoding using nvImageCodec/nvJPEG.
|
||||
Provides a shared encoder instance that can be used across multiple streams.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import torch
|
||||
import nvidia.nvimgcodec as nvimgcodec
|
||||
|
||||
|
||||
class JPEGEncoderFactory:
|
||||
"""
|
||||
Factory for creating and managing a shared JPEG encoder instance.
|
||||
Thread-safe singleton pattern for efficient resource sharing.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_encoder = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(JPEGEncoderFactory, cls).__new__(cls)
|
||||
cls._encoder = nvimgcodec.Encoder()
|
||||
print("JPEGEncoderFactory initialized with shared nvJPEG encoder")
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def get_encoder(cls):
|
||||
"""Get the shared JPEG encoder instance"""
|
||||
if cls._encoder is None:
|
||||
cls() # Initialize if not already done
|
||||
return cls._encoder
|
||||
|
||||
|
||||
def encode_frame_to_jpeg(rgb_frame: torch.Tensor, quality: int = 95) -> Optional[bytes]:
|
||||
"""
|
||||
Encode an RGB frame to JPEG on GPU and return JPEG bytes.
|
||||
|
||||
This function:
|
||||
1. Takes RGB frame from GPU (stays on GPU during encoding)
|
||||
2. Converts PyTorch tensor to nvImageCodec image via as_image()
|
||||
3. Encodes to JPEG using nvJPEG (GPU operation)
|
||||
4. Transfers only JPEG bytes to CPU
|
||||
5. Returns bytes for saving to disk
|
||||
|
||||
Args:
|
||||
rgb_frame: RGB tensor on GPU, shape (3, H, W) or (H, W, 3), dtype uint8
|
||||
quality: JPEG quality (0-100, default 95)
|
||||
|
||||
Returns:
|
||||
JPEG encoded bytes or None if encoding fails
|
||||
"""
|
||||
if rgb_frame is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Ensure we have (H, W, C) format and contiguous memory
|
||||
if rgb_frame.dim() == 3:
|
||||
if rgb_frame.shape[0] == 3:
|
||||
# Convert from (C, H, W) to (H, W, C)
|
||||
rgb_hwc = rgb_frame.permute(1, 2, 0).contiguous()
|
||||
else:
|
||||
# Already (H, W, C)
|
||||
rgb_hwc = rgb_frame.contiguous()
|
||||
else:
|
||||
raise ValueError(f"Expected 3D tensor, got shape {rgb_frame.shape}")
|
||||
|
||||
# Get shared encoder
|
||||
encoder = JPEGEncoderFactory.get_encoder()
|
||||
|
||||
# Create encode parameters with quality
|
||||
# Quality is set via quality_value (0-100 scale)
|
||||
jpeg_params = nvimgcodec.JpegEncodeParams(optimized_huffman=True)
|
||||
encode_params = nvimgcodec.EncodeParams(
|
||||
quality_value=float(quality),
|
||||
jpeg_encode_params=jpeg_params
|
||||
)
|
||||
|
||||
# Convert PyTorch GPU tensor to nvImageCodec image using __cuda_array_interface__
|
||||
# This is zero-copy - nvimgcodec reads directly from GPU memory
|
||||
nv_image = nvimgcodec.as_image(rgb_hwc)
|
||||
|
||||
# Encode to JPEG on GPU
|
||||
# The encoding happens on GPU, only compressed JPEG bytes are transferred to CPU
|
||||
jpeg_data = encoder.encode(nv_image, "jpeg", encode_params)
|
||||
|
||||
return bytes(jpeg_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error encoding frame to JPEG: {e}")
|
||||
return None
|
||||
631
services/model_repository.py
Normal file
631
services/model_repository.py
Normal file
|
|
@ -0,0 +1,631 @@
|
|||
import threading
|
||||
import hashlib
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
import torch
|
||||
import tensorrt as trt
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelMetadata:
|
||||
"""Metadata for a loaded TensorRT model"""
|
||||
file_path: str
|
||||
file_hash: str
|
||||
input_shapes: Dict[str, Tuple[int, ...]]
|
||||
output_shapes: Dict[str, Tuple[int, ...]]
|
||||
input_names: List[str]
|
||||
output_names: List[str]
|
||||
input_dtypes: Dict[str, torch.dtype]
|
||||
output_dtypes: Dict[str, torch.dtype]
|
||||
|
||||
|
||||
class ExecutionContext:
|
||||
"""
|
||||
Wrapper for TensorRT execution context with CUDA stream.
|
||||
Used in context pool for load balancing.
|
||||
"""
|
||||
def __init__(self, context: trt.IExecutionContext, stream: torch.cuda.Stream,
|
||||
context_id: int, device: torch.device):
|
||||
self.context = context
|
||||
self.stream = stream
|
||||
self.context_id = context_id
|
||||
self.device = device
|
||||
self.in_use = False
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def __repr__(self):
|
||||
return f"ExecutionContext(id={self.context_id}, in_use={self.in_use})"
|
||||
|
||||
|
||||
class SharedEngine:
|
||||
"""
|
||||
Shared TensorRT engine with context pool for load balancing.
|
||||
|
||||
Architecture:
|
||||
- One engine shared across all model_ids with same file hash
|
||||
- Pool of N execution contexts for concurrent inference
|
||||
- Contexts are borrowed/returned using mutex locks
|
||||
- Load balancing: contexts distributed across requests
|
||||
"""
|
||||
def __init__(self, engine: trt.ICudaEngine, file_hash: str, file_path: str,
|
||||
num_contexts: int, device: torch.device, metadata: ModelMetadata):
|
||||
self.engine = engine
|
||||
self.file_hash = file_hash
|
||||
self.file_path = file_path
|
||||
self.metadata = metadata
|
||||
self.device = device
|
||||
self.num_contexts = num_contexts
|
||||
|
||||
# Create context pool
|
||||
self.context_pool: List[ExecutionContext] = []
|
||||
self.available_contexts: Queue[ExecutionContext] = Queue()
|
||||
|
||||
for i in range(num_contexts):
|
||||
ctx = engine.create_execution_context()
|
||||
if ctx is None:
|
||||
raise RuntimeError(f"Failed to create execution context {i}")
|
||||
|
||||
stream = torch.cuda.Stream(device=device)
|
||||
exec_ctx = ExecutionContext(ctx, stream, i, device)
|
||||
self.context_pool.append(exec_ctx)
|
||||
self.available_contexts.put(exec_ctx)
|
||||
|
||||
# Model IDs referencing this engine
|
||||
self.model_ids: set = set()
|
||||
self.lock = threading.Lock()
|
||||
|
||||
print(f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}...")
|
||||
|
||||
def acquire_context(self, timeout: Optional[float] = None) -> Optional[ExecutionContext]:
|
||||
"""
|
||||
Acquire an available execution context from the pool.
|
||||
Blocks if all contexts are in use.
|
||||
|
||||
Args:
|
||||
timeout: Max time to wait for context (None = wait forever)
|
||||
|
||||
Returns:
|
||||
ExecutionContext or None if timeout
|
||||
"""
|
||||
try:
|
||||
exec_ctx = self.available_contexts.get(timeout=timeout)
|
||||
with exec_ctx.lock:
|
||||
exec_ctx.in_use = True
|
||||
return exec_ctx
|
||||
except:
|
||||
return None
|
||||
|
||||
def release_context(self, exec_ctx: ExecutionContext):
|
||||
"""
|
||||
Return a context to the pool.
|
||||
|
||||
Args:
|
||||
exec_ctx: Context to release
|
||||
"""
|
||||
with exec_ctx.lock:
|
||||
exec_ctx.in_use = False
|
||||
self.available_contexts.put(exec_ctx)
|
||||
|
||||
def add_model_id(self, model_id: str):
|
||||
"""Add a model_id reference to this engine"""
|
||||
with self.lock:
|
||||
self.model_ids.add(model_id)
|
||||
|
||||
def remove_model_id(self, model_id: str) -> int:
|
||||
"""
|
||||
Remove a model_id reference from this engine.
|
||||
Returns the number of remaining references.
|
||||
"""
|
||||
with self.lock:
|
||||
self.model_ids.discard(model_id)
|
||||
return len(self.model_ids)
|
||||
|
||||
def get_reference_count(self) -> int:
|
||||
"""Get number of model_ids using this engine"""
|
||||
with self.lock:
|
||||
return len(self.model_ids)
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup all contexts"""
|
||||
for exec_ctx in self.context_pool:
|
||||
del exec_ctx.context
|
||||
self.context_pool.clear()
|
||||
del self.engine
|
||||
|
||||
|
||||
class TensorRTModelRepository:
|
||||
"""
|
||||
Thread-safe repository for TensorRT models with context pooling and deduplication.
|
||||
|
||||
Architecture:
|
||||
- Deduplication: Multiple model_ids with same file → share one engine
|
||||
- Context Pool: Each unique engine has N execution contexts (configurable)
|
||||
- Load Balancing: Contexts are borrowed/returned via mutex queue
|
||||
- Scalability: Adding 100 cameras with same model = 1 engine + N contexts (not 100 contexts!)
|
||||
|
||||
Best Practices:
|
||||
- GPU-to-GPU: All inputs/outputs stay in VRAM (zero CPU transfers)
|
||||
- Thread Safety: Mutex-based context borrowing (TensorRT best practice)
|
||||
- Memory Efficient: Deduplicate by file hash, share engine across model_ids
|
||||
- Concurrent: N contexts allow N parallel inferences per unique model
|
||||
|
||||
Example:
|
||||
# 100 cameras, same model file
|
||||
for i in range(100):
|
||||
repo.load_model(f"camera_{i}", "yolov8.trt")
|
||||
# Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
|
||||
"""
|
||||
|
||||
def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4):
|
||||
"""
|
||||
Initialize the model repository.
|
||||
|
||||
Args:
|
||||
gpu_id: GPU device ID to use
|
||||
default_num_contexts: Default number of execution contexts per unique engine
|
||||
"""
|
||||
self.gpu_id = gpu_id
|
||||
self.device = torch.device(f'cuda:{gpu_id}')
|
||||
self.default_num_contexts = default_num_contexts
|
||||
|
||||
# Model ID to engine mapping: model_id -> file_hash
|
||||
self._model_to_hash: Dict[str, str] = {}
|
||||
|
||||
# Shared engines with context pools: file_hash -> SharedEngine
|
||||
self._shared_engines: Dict[str, SharedEngine] = {}
|
||||
|
||||
# Locks for thread safety
|
||||
self._repo_lock = threading.RLock()
|
||||
|
||||
# TensorRT logger
|
||||
self.trt_logger = trt.Logger(trt.Logger.WARNING)
|
||||
|
||||
print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
|
||||
print(f"Default context pool size: {default_num_contexts} contexts per unique model")
|
||||
|
||||
@staticmethod
|
||||
def compute_file_hash(file_path: str) -> str:
|
||||
"""
|
||||
Compute SHA256 hash of a file for deduplication.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
Hexadecimal hash string
|
||||
"""
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
# Read in chunks to handle large files efficiently
|
||||
for byte_block in iter(lambda: f.read(65536), b""):
|
||||
sha256_hash.update(byte_block)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
def _load_engine(self, file_path: str) -> trt.ICudaEngine:
|
||||
"""
|
||||
Load TensorRT engine from file.
|
||||
|
||||
Args:
|
||||
file_path: Path to .trt or .engine file
|
||||
|
||||
Returns:
|
||||
TensorRT engine
|
||||
"""
|
||||
runtime = trt.Runtime(self.trt_logger)
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
engine_data = f.read()
|
||||
|
||||
engine = runtime.deserialize_cuda_engine(engine_data)
|
||||
if engine is None:
|
||||
raise RuntimeError(f"Failed to load TensorRT engine from {file_path}")
|
||||
|
||||
return engine
|
||||
|
||||
def _extract_metadata(self, engine: trt.ICudaEngine,
|
||||
file_path: str, file_hash: str) -> ModelMetadata:
|
||||
"""
|
||||
Extract metadata from TensorRT engine.
|
||||
|
||||
Args:
|
||||
engine: TensorRT engine
|
||||
file_path: Path to model file
|
||||
file_hash: SHA256 hash of model file
|
||||
|
||||
Returns:
|
||||
ModelMetadata object
|
||||
"""
|
||||
input_shapes = {}
|
||||
output_shapes = {}
|
||||
input_names = []
|
||||
output_names = []
|
||||
input_dtypes = {}
|
||||
output_dtypes = {}
|
||||
|
||||
# TensorRT dtype to PyTorch dtype mapping
|
||||
trt_to_torch_dtype = {
|
||||
trt.DataType.FLOAT: torch.float32,
|
||||
trt.DataType.HALF: torch.float16,
|
||||
trt.DataType.INT8: torch.int8,
|
||||
trt.DataType.INT32: torch.int32,
|
||||
trt.DataType.BOOL: torch.bool,
|
||||
}
|
||||
|
||||
# Iterate through all tensors (inputs and outputs) - TensorRT 10.x API
|
||||
for i in range(engine.num_io_tensors):
|
||||
name = engine.get_tensor_name(i)
|
||||
shape = tuple(engine.get_tensor_shape(name))
|
||||
dtype = trt_to_torch_dtype.get(engine.get_tensor_dtype(name), torch.float32)
|
||||
mode = engine.get_tensor_mode(name)
|
||||
|
||||
if mode == trt.TensorIOMode.INPUT:
|
||||
input_names.append(name)
|
||||
input_shapes[name] = shape
|
||||
input_dtypes[name] = dtype
|
||||
else:
|
||||
output_names.append(name)
|
||||
output_shapes[name] = shape
|
||||
output_dtypes[name] = dtype
|
||||
|
||||
return ModelMetadata(
|
||||
file_path=file_path,
|
||||
file_hash=file_hash,
|
||||
input_shapes=input_shapes,
|
||||
output_shapes=output_shapes,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
input_dtypes=input_dtypes,
|
||||
output_dtypes=output_dtypes
|
||||
)
|
||||
|
||||
def load_model(self, model_id: str, file_path: str,
|
||||
num_contexts: Optional[int] = None,
|
||||
force_reload: bool = False) -> ModelMetadata:
|
||||
"""
|
||||
Load a TensorRT model with the given ID.
|
||||
|
||||
Deduplication: If a model with the same file hash is already loaded, the model_id
|
||||
is simply mapped to the existing SharedEngine (no new engine or contexts created).
|
||||
|
||||
Args:
|
||||
model_id: User-defined identifier for this model (e.g., "camera_1")
|
||||
file_path: Path to TensorRT engine file (.trt or .engine)
|
||||
num_contexts: Number of execution contexts in pool (None = use default)
|
||||
force_reload: If True, reload even if model_id exists
|
||||
|
||||
Returns:
|
||||
ModelMetadata for the loaded model
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If model file doesn't exist
|
||||
RuntimeError: If engine loading fails
|
||||
ValueError: If model_id already exists and force_reload is False
|
||||
"""
|
||||
file_path = str(Path(file_path).resolve())
|
||||
|
||||
if not Path(file_path).exists():
|
||||
raise FileNotFoundError(f"Model file not found: {file_path}")
|
||||
|
||||
if num_contexts is None:
|
||||
num_contexts = self.default_num_contexts
|
||||
|
||||
with self._repo_lock:
|
||||
# Check if model_id already exists
|
||||
if model_id in self._model_to_hash and not force_reload:
|
||||
raise ValueError(
|
||||
f"Model ID '{model_id}' already exists. "
|
||||
f"Use force_reload=True to reload or choose a different ID."
|
||||
)
|
||||
|
||||
# Unload existing model if force_reload
|
||||
if model_id in self._model_to_hash and force_reload:
|
||||
self.unload_model(model_id)
|
||||
|
||||
# Compute file hash for deduplication
|
||||
print(f"Computing hash for {file_path}...")
|
||||
file_hash = self.compute_file_hash(file_path)
|
||||
print(f"File hash: {file_hash[:16]}...")
|
||||
|
||||
# Check if this file is already loaded (deduplication)
|
||||
if file_hash in self._shared_engines:
|
||||
shared_engine = self._shared_engines[file_hash]
|
||||
print(f"Engine already loaded (hash match), reusing engine and context pool...")
|
||||
print(f" Existing model_ids using this engine: {shared_engine.model_ids}")
|
||||
else:
|
||||
# Load new engine
|
||||
print(f"Loading TensorRT engine from {file_path}...")
|
||||
engine = self._load_engine(file_path)
|
||||
|
||||
# Extract metadata
|
||||
metadata = self._extract_metadata(engine, file_path, file_hash)
|
||||
|
||||
# Create shared engine with context pool
|
||||
shared_engine = SharedEngine(
|
||||
engine=engine,
|
||||
file_hash=file_hash,
|
||||
file_path=file_path,
|
||||
num_contexts=num_contexts,
|
||||
device=self.device,
|
||||
metadata=metadata
|
||||
)
|
||||
self._shared_engines[file_hash] = shared_engine
|
||||
|
||||
# Add this model_id to the shared engine
|
||||
shared_engine.add_model_id(model_id)
|
||||
|
||||
# Map model_id to file_hash
|
||||
self._model_to_hash[model_id] = file_hash
|
||||
|
||||
print(f"Model '{model_id}' loaded successfully")
|
||||
print(f" Inputs: {shared_engine.metadata.input_names}")
|
||||
for name in shared_engine.metadata.input_names:
|
||||
print(f" {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})")
|
||||
print(f" Outputs: {shared_engine.metadata.output_names}")
|
||||
for name in shared_engine.metadata.output_names:
|
||||
print(f" {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})")
|
||||
print(f" Context pool size: {num_contexts}")
|
||||
print(f" Model IDs sharing this engine: {shared_engine.get_reference_count()}")
|
||||
print(f" Unique engines in VRAM: {len(self._shared_engines)}")
|
||||
|
||||
return shared_engine.metadata
|
||||
|
||||
def infer(self, model_id: str, inputs: Dict[str, torch.Tensor],
|
||||
synchronize: bool = True, timeout: Optional[float] = 5.0) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Run GPU-to-GPU inference with the specified model using context pooling.
|
||||
|
||||
All inputs must be CUDA tensors and outputs will be CUDA tensors (stays in VRAM).
|
||||
Thread-safe: Borrows an execution context from the pool with mutex locking.
|
||||
|
||||
Args:
|
||||
model_id: Model identifier
|
||||
inputs: Dictionary mapping input names to CUDA tensors
|
||||
synchronize: If True, wait for inference to complete. If False, async execution.
|
||||
timeout: Max time to wait for available context (seconds)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping output names to CUDA tensors (in VRAM)
|
||||
|
||||
Raises:
|
||||
KeyError: If model_id not found
|
||||
ValueError: If inputs don't match expected shapes or are not on GPU
|
||||
RuntimeError: If no context available within timeout
|
||||
"""
|
||||
# Get shared engine
|
||||
if model_id not in self._model_to_hash:
|
||||
raise KeyError(f"Model '{model_id}' not found. Available: {list(self._model_to_hash.keys())}")
|
||||
|
||||
file_hash = self._model_to_hash[model_id]
|
||||
shared_engine = self._shared_engines[file_hash]
|
||||
metadata = shared_engine.metadata
|
||||
|
||||
# Validate inputs
|
||||
for name in metadata.input_names:
|
||||
if name not in inputs:
|
||||
raise ValueError(f"Missing required input: {name}")
|
||||
|
||||
tensor = inputs[name]
|
||||
if not tensor.is_cuda:
|
||||
raise ValueError(f"Input '{name}' must be a CUDA tensor (on GPU)")
|
||||
|
||||
# Check device
|
||||
if tensor.device != self.device:
|
||||
print(f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}")
|
||||
inputs[name] = tensor.to(self.device)
|
||||
|
||||
# Acquire context from pool (mutex-based)
|
||||
exec_ctx = shared_engine.acquire_context(timeout=timeout)
|
||||
if exec_ctx is None:
|
||||
raise RuntimeError(
|
||||
f"No execution context available for model '{model_id}' within {timeout}s. "
|
||||
f"All {shared_engine.num_contexts} contexts are busy."
|
||||
)
|
||||
|
||||
try:
|
||||
# Prepare output tensors
|
||||
outputs = {}
|
||||
|
||||
# Set input tensors - TensorRT 10.x API
|
||||
for name in metadata.input_names:
|
||||
input_tensor = inputs[name].contiguous()
|
||||
exec_ctx.context.set_tensor_address(name, input_tensor.data_ptr())
|
||||
|
||||
# Allocate and set output tensors
|
||||
for name in metadata.output_names:
|
||||
output_shape = metadata.output_shapes[name]
|
||||
output_dtype = metadata.output_dtypes[name]
|
||||
|
||||
output_tensor = torch.empty(
|
||||
output_shape,
|
||||
dtype=output_dtype,
|
||||
device=self.device
|
||||
)
|
||||
|
||||
outputs[name] = output_tensor
|
||||
exec_ctx.context.set_tensor_address(name, output_tensor.data_ptr())
|
||||
|
||||
# Execute inference on context's stream - TensorRT 10.x API
|
||||
with torch.cuda.stream(exec_ctx.stream):
|
||||
success = exec_ctx.context.execute_async_v3(
|
||||
stream_handle=exec_ctx.stream.cuda_stream
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise RuntimeError(f"Inference failed for model '{model_id}'")
|
||||
|
||||
# Synchronize if requested
|
||||
if synchronize:
|
||||
exec_ctx.stream.synchronize()
|
||||
|
||||
return outputs
|
||||
|
||||
finally:
|
||||
# Always release context back to pool
|
||||
shared_engine.release_context(exec_ctx)
|
||||
|
||||
def infer_batch(self, model_id: str, batch_inputs: List[Dict[str, torch.Tensor]],
|
||||
synchronize: bool = True) -> List[Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Run inference on multiple inputs.
|
||||
Contexts are borrowed/returned for each input, enabling parallel processing.
|
||||
|
||||
Args:
|
||||
model_id: Model identifier
|
||||
batch_inputs: List of input dictionaries
|
||||
synchronize: If True, wait for all inferences to complete
|
||||
|
||||
Returns:
|
||||
List of output dictionaries
|
||||
"""
|
||||
results = []
|
||||
for inputs in batch_inputs:
|
||||
outputs = self.infer(model_id, inputs, synchronize=synchronize)
|
||||
results.append(outputs)
|
||||
|
||||
return results
|
||||
|
||||
def unload_model(self, model_id: str):
|
||||
"""
|
||||
Unload a model from the repository.
|
||||
|
||||
Removes the model_id reference from the shared engine. If this was the last
|
||||
reference, the engine and all its contexts will be fully unloaded from VRAM.
|
||||
|
||||
Args:
|
||||
model_id: Model identifier to unload
|
||||
"""
|
||||
with self._repo_lock:
|
||||
if model_id not in self._model_to_hash:
|
||||
print(f"Warning: Model '{model_id}' not found")
|
||||
return
|
||||
|
||||
file_hash = self._model_to_hash[model_id]
|
||||
|
||||
# Remove model_id from shared engine
|
||||
if file_hash in self._shared_engines:
|
||||
shared_engine = self._shared_engines[file_hash]
|
||||
remaining_refs = shared_engine.remove_model_id(model_id)
|
||||
|
||||
# If no more references, cleanup engine and contexts
|
||||
if remaining_refs == 0:
|
||||
shared_engine.cleanup()
|
||||
del self._shared_engines[file_hash]
|
||||
print(f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)")
|
||||
else:
|
||||
print(f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)")
|
||||
|
||||
# Remove from model_id mapping
|
||||
del self._model_to_hash[model_id]
|
||||
|
||||
def get_metadata(self, model_id: str) -> Optional[ModelMetadata]:
|
||||
"""
|
||||
Get metadata for a loaded model.
|
||||
|
||||
Args:
|
||||
model_id: Model identifier
|
||||
|
||||
Returns:
|
||||
ModelMetadata or None if not found
|
||||
"""
|
||||
if model_id not in self._model_to_hash:
|
||||
return None
|
||||
|
||||
file_hash = self._model_to_hash[model_id]
|
||||
if file_hash not in self._shared_engines:
|
||||
return None
|
||||
|
||||
return self._shared_engines[file_hash].metadata
|
||||
|
||||
def list_models(self) -> Dict[str, ModelMetadata]:
|
||||
"""
|
||||
List all loaded models.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping model_id to ModelMetadata
|
||||
"""
|
||||
with self._repo_lock:
|
||||
result = {}
|
||||
for model_id, file_hash in self._model_to_hash.items():
|
||||
if file_hash in self._shared_engines:
|
||||
result[model_id] = self._shared_engines[file_hash].metadata
|
||||
return result
|
||||
|
||||
def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get detailed information about a loaded model.
|
||||
|
||||
Args:
|
||||
model_id: Model identifier
|
||||
|
||||
Returns:
|
||||
Dictionary with model information or None if not found
|
||||
"""
|
||||
if model_id not in self._model_to_hash:
|
||||
return None
|
||||
|
||||
file_hash = self._model_to_hash[model_id]
|
||||
if file_hash not in self._shared_engines:
|
||||
return None
|
||||
|
||||
shared_engine = self._shared_engines[file_hash]
|
||||
metadata = shared_engine.metadata
|
||||
|
||||
return {
|
||||
'model_id': model_id,
|
||||
'file_path': metadata.file_path,
|
||||
'file_hash': metadata.file_hash[:16] + '...',
|
||||
'engine_references': shared_engine.get_reference_count(),
|
||||
'context_pool_size': shared_engine.num_contexts,
|
||||
'shared_with_model_ids': list(shared_engine.model_ids),
|
||||
'inputs': {
|
||||
name: {
|
||||
'shape': metadata.input_shapes[name],
|
||||
'dtype': str(metadata.input_dtypes[name])
|
||||
}
|
||||
for name in metadata.input_names
|
||||
},
|
||||
'outputs': {
|
||||
name: {
|
||||
'shape': metadata.output_shapes[name],
|
||||
'dtype': str(metadata.output_dtypes[name])
|
||||
}
|
||||
for name in metadata.output_names
|
||||
}
|
||||
}
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get repository statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with stats about loaded models and memory usage
|
||||
"""
|
||||
with self._repo_lock:
|
||||
total_contexts = sum(
|
||||
engine.num_contexts
|
||||
for engine in self._shared_engines.values()
|
||||
)
|
||||
|
||||
return {
|
||||
'total_model_ids': len(self._model_to_hash),
|
||||
'unique_engines': len(self._shared_engines),
|
||||
'total_contexts': total_contexts,
|
||||
'memory_efficiency': f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines",
|
||||
'gpu_id': self.gpu_id,
|
||||
'models': list(self._model_to_hash.keys())
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
with self._repo_lock:
|
||||
return (f"TensorRTModelRepository(gpu={self.gpu_id}, "
|
||||
f"model_ids={len(self._model_to_hash)}, "
|
||||
f"unique_engines={len(self._shared_engines)})")
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup all models on deletion"""
|
||||
with self._repo_lock:
|
||||
model_ids = list(self._model_to_hash.keys())
|
||||
for model_id in model_ids:
|
||||
self.unload_model(model_id)
|
||||
481
services/stream_decoder.py
Normal file
481
services/stream_decoder.py
Normal file
|
|
@ -0,0 +1,481 @@
|
|||
import threading
|
||||
from typing import Optional
|
||||
from collections import deque
|
||||
from enum import Enum
|
||||
import torch
|
||||
import PyNvVideoCodec as nvc
|
||||
import av
|
||||
import numpy as np
|
||||
from cuda.bindings import driver as cuda_driver
|
||||
from .jpeg_encoder import encode_frame_to_jpeg
|
||||
|
||||
|
||||
def nv12_to_rgb_gpu(nv12_tensor: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
Convert NV12 format to RGB on GPU using PyTorch operations.
|
||||
|
||||
NV12 format:
|
||||
- Y plane: height x width (luminance)
|
||||
- UV plane: (height/2) x width (interleaved U and V, subsampled by 2)
|
||||
|
||||
Total tensor size: (height * 3/2) x width
|
||||
|
||||
Args:
|
||||
nv12_tensor: Input tensor in NV12 format, shape (H*3/2, W)
|
||||
height: Original frame height
|
||||
width: Original frame width
|
||||
|
||||
Returns:
|
||||
RGB tensor, shape (3, H, W) in range [0, 255]
|
||||
"""
|
||||
device = nv12_tensor.device
|
||||
|
||||
# Split Y and UV planes
|
||||
y_plane = nv12_tensor[:height, :].float() # (H, W)
|
||||
uv_plane = nv12_tensor[height:, :].float() # (H/2, W)
|
||||
|
||||
# Reshape UV plane to separate U and V channels
|
||||
# UV is interleaved: U0V0U1V1... we need to deinterleave
|
||||
uv_plane = uv_plane.reshape(height // 2, width // 2, 2) # (H/2, W/2, 2)
|
||||
u_plane = uv_plane[:, :, 0] # (H/2, W/2)
|
||||
v_plane = uv_plane[:, :, 1] # (H/2, W/2)
|
||||
|
||||
# Upsample U and V to full resolution using bilinear interpolation
|
||||
u_upsampled = torch.nn.functional.interpolate(
|
||||
u_plane.unsqueeze(0).unsqueeze(0), # (1, 1, H/2, W/2)
|
||||
size=(height, width),
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
).squeeze(0).squeeze(0) # (H, W)
|
||||
|
||||
v_upsampled = torch.nn.functional.interpolate(
|
||||
v_plane.unsqueeze(0).unsqueeze(0), # (1, 1, H/2, W/2)
|
||||
size=(height, width),
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
).squeeze(0).squeeze(0) # (H, W)
|
||||
|
||||
# YUV to RGB conversion using BT.601 standard
|
||||
# R = Y + 1.402 * (V - 128)
|
||||
# G = Y - 0.344136 * (U - 128) - 0.714136 * (V - 128)
|
||||
# B = Y + 1.772 * (U - 128)
|
||||
|
||||
y = y_plane
|
||||
u = u_upsampled - 128.0
|
||||
v = v_upsampled - 128.0
|
||||
|
||||
r = y + 1.402 * v
|
||||
g = y - 0.344136 * u - 0.714136 * v
|
||||
b = y + 1.772 * u
|
||||
|
||||
# Clamp to [0, 255] and convert to uint8
|
||||
r = torch.clamp(r, 0, 255).to(torch.uint8)
|
||||
g = torch.clamp(g, 0, 255).to(torch.uint8)
|
||||
b = torch.clamp(b, 0, 255).to(torch.uint8)
|
||||
|
||||
# Stack to (3, H, W)
|
||||
rgb = torch.stack([r, g, b], dim=0)
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
class ConnectionStatus(Enum):
|
||||
DISCONNECTED = "disconnected"
|
||||
CONNECTING = "connecting"
|
||||
CONNECTED = "connected"
|
||||
ERROR = "error"
|
||||
RECONNECTING = "reconnecting"
|
||||
|
||||
|
||||
class StreamDecoderFactory:
|
||||
"""
|
||||
Factory for creating StreamDecoder instances with shared CUDA context.
|
||||
This minimizes VRAM overhead by sharing the CUDA context across all decoders.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, gpu_id: int = 0):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super(StreamDecoderFactory, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, gpu_id: int = 0):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self.gpu_id = gpu_id
|
||||
|
||||
# Initialize CUDA and get device
|
||||
err, = cuda_driver.cuInit(0)
|
||||
if err != cuda_driver.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError(f"Failed to initialize CUDA: {err}")
|
||||
|
||||
# Get CUDA device
|
||||
err, self.cuda_device = cuda_driver.cuDeviceGet(gpu_id)
|
||||
if err != cuda_driver.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError(f"Failed to get CUDA device {gpu_id}: {err}")
|
||||
|
||||
# Retain primary context (shared across all decoders)
|
||||
err, self.cuda_context = cuda_driver.cuDevicePrimaryCtxRetain(self.cuda_device)
|
||||
if err != cuda_driver.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError(f"Failed to retain CUDA primary context: {err}")
|
||||
|
||||
self._initialized = True
|
||||
print(f"StreamDecoderFactory initialized with shared CUDA context on GPU {gpu_id}")
|
||||
|
||||
def create_decoder(self, rtsp_url: str, buffer_size: int = 30,
|
||||
codec: str = "h264") -> 'StreamDecoder':
|
||||
"""
|
||||
Create a new StreamDecoder instance with shared CUDA context.
|
||||
|
||||
Args:
|
||||
rtsp_url: RTSP stream URL
|
||||
buffer_size: Number of frames to buffer in VRAM
|
||||
codec: Video codec (h264, hevc, etc.)
|
||||
|
||||
Returns:
|
||||
StreamDecoder instance
|
||||
"""
|
||||
return StreamDecoder(
|
||||
rtsp_url=rtsp_url,
|
||||
cuda_context=self.cuda_context,
|
||||
gpu_id=self.gpu_id,
|
||||
buffer_size=buffer_size,
|
||||
codec=codec
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup shared CUDA context on factory destruction"""
|
||||
if hasattr(self, 'cuda_device') and hasattr(self, 'gpu_id'):
|
||||
cuda_driver.cuDevicePrimaryCtxRelease(self.cuda_device)
|
||||
|
||||
|
||||
class StreamDecoder:
|
||||
"""
|
||||
Decodes RTSP stream using NVDEC and maintains a ring buffer of frames in GPU VRAM.
|
||||
Thread-safe for concurrent read/write operations.
|
||||
"""
|
||||
|
||||
def __init__(self, rtsp_url: str, cuda_context, gpu_id: int,
|
||||
buffer_size: int = 30, codec: str = "h264"):
|
||||
"""
|
||||
Initialize StreamDecoder.
|
||||
|
||||
Args:
|
||||
rtsp_url: RTSP stream URL
|
||||
cuda_context: Shared CUDA context handle
|
||||
gpu_id: GPU device ID
|
||||
buffer_size: Number of frames to keep in ring buffer
|
||||
codec: Video codec type
|
||||
"""
|
||||
self.rtsp_url = rtsp_url
|
||||
self.cuda_context = cuda_context
|
||||
self.gpu_id = gpu_id
|
||||
self.buffer_size = buffer_size
|
||||
self.codec = codec
|
||||
|
||||
# Connection status
|
||||
self.status = ConnectionStatus.DISCONNECTED
|
||||
self._status_lock = threading.Lock()
|
||||
|
||||
# Frame buffer (ring buffer) - stores CUDA device pointers
|
||||
self.frame_buffer = deque(maxlen=buffer_size)
|
||||
self._buffer_lock = threading.RLock()
|
||||
|
||||
# Decoder and container instances
|
||||
self.decoder = None
|
||||
self.container = None
|
||||
|
||||
# Decode thread
|
||||
self._decode_thread: Optional[threading.Thread] = None
|
||||
self._stop_flag = threading.Event()
|
||||
|
||||
# Frame metadata
|
||||
self.frame_width: Optional[int] = None
|
||||
self.frame_height: Optional[int] = None
|
||||
self.frame_count: int = 0
|
||||
|
||||
def start(self):
|
||||
"""Start the RTSP stream decoding in background thread"""
|
||||
if self._decode_thread is not None and self._decode_thread.is_alive():
|
||||
print(f"Decoder already running for {self.rtsp_url}")
|
||||
return
|
||||
|
||||
self._stop_flag.clear()
|
||||
self._decode_thread = threading.Thread(target=self._decode_loop, daemon=True)
|
||||
self._decode_thread.start()
|
||||
print(f"Started decoder thread for {self.rtsp_url}")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the decoding thread and cleanup resources"""
|
||||
self._stop_flag.set()
|
||||
if self._decode_thread is not None:
|
||||
self._decode_thread.join(timeout=5.0)
|
||||
self._cleanup()
|
||||
print(f"Stopped decoder for {self.rtsp_url}")
|
||||
|
||||
def _set_status(self, status: ConnectionStatus):
|
||||
"""Thread-safe status update"""
|
||||
with self._status_lock:
|
||||
self.status = status
|
||||
|
||||
def get_status(self) -> ConnectionStatus:
|
||||
"""Get current connection status"""
|
||||
with self._status_lock:
|
||||
return self.status
|
||||
|
||||
def _init_rtsp_connection(self) -> bool:
|
||||
"""Initialize RTSP connection using PyAV + PyNvVideoCodec"""
|
||||
try:
|
||||
self._set_status(ConnectionStatus.CONNECTING)
|
||||
|
||||
# Open RTSP stream with PyAV
|
||||
options = {
|
||||
'rtsp_transport': 'tcp',
|
||||
'max_delay': '500000', # 500ms
|
||||
'rtsp_flags': 'prefer_tcp',
|
||||
'timeout': '5000000', # 5 seconds
|
||||
}
|
||||
|
||||
self.container = av.open(self.rtsp_url, options=options)
|
||||
|
||||
# Get video stream
|
||||
video_stream = self.container.streams.video[0]
|
||||
self.frame_width = video_stream.width
|
||||
self.frame_height = video_stream.height
|
||||
|
||||
print(f"RTSP connected: {self.frame_width}x{self.frame_height}")
|
||||
|
||||
# Map codec name to PyNvVideoCodec codec enum
|
||||
codec_map = {
|
||||
'h264': nvc.cudaVideoCodec.H264,
|
||||
'hevc': nvc.cudaVideoCodec.HEVC,
|
||||
'h265': nvc.cudaVideoCodec.HEVC,
|
||||
}
|
||||
|
||||
codec_id = codec_map.get(self.codec.lower(), nvc.cudaVideoCodec.H264)
|
||||
|
||||
# Initialize NVDEC decoder with shared CUDA context
|
||||
self.decoder = nvc.CreateDecoder(
|
||||
gpuid=self.gpu_id,
|
||||
codec=codec_id,
|
||||
cudacontext=self.cuda_context,
|
||||
usedevicememory=True
|
||||
)
|
||||
|
||||
self._set_status(ConnectionStatus.CONNECTED)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to connect to RTSP stream {self.rtsp_url}: {e}")
|
||||
self._set_status(ConnectionStatus.ERROR)
|
||||
return False
|
||||
|
||||
def _decode_loop(self):
|
||||
"""Main decode loop running in background thread"""
|
||||
retry_count = 0
|
||||
max_retries = 5
|
||||
|
||||
while not self._stop_flag.is_set():
|
||||
# Initialize connection
|
||||
if not self._init_rtsp_connection():
|
||||
retry_count += 1
|
||||
if retry_count >= max_retries:
|
||||
print(f"Max retries reached for {self.rtsp_url}")
|
||||
self._set_status(ConnectionStatus.ERROR)
|
||||
break
|
||||
|
||||
self._set_status(ConnectionStatus.RECONNECTING)
|
||||
self._stop_flag.wait(timeout=2.0)
|
||||
continue
|
||||
|
||||
retry_count = 0 # Reset on successful connection
|
||||
|
||||
try:
|
||||
# Decode loop - iterate through packets from PyAV
|
||||
for packet in self.container.demux(video=0):
|
||||
if self._stop_flag.is_set():
|
||||
break
|
||||
|
||||
if packet.dts is None:
|
||||
continue
|
||||
|
||||
# Convert packet to numpy array
|
||||
packet_data = np.frombuffer(bytes(packet), dtype=np.uint8)
|
||||
|
||||
# Create PacketData and pass numpy array pointer
|
||||
pkt = nvc.PacketData()
|
||||
pkt.bsl_data = packet_data.ctypes.data
|
||||
pkt.bsl = len(packet_data)
|
||||
|
||||
# Decode using NVDEC
|
||||
decoded_frames = self.decoder.Decode(pkt)
|
||||
|
||||
if not decoded_frames:
|
||||
continue
|
||||
|
||||
# Add frames to ring buffer (thread-safe)
|
||||
with self._buffer_lock:
|
||||
for frame in decoded_frames:
|
||||
self.frame_buffer.append(frame)
|
||||
self.frame_count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in decode loop for {self.rtsp_url}: {e}")
|
||||
self._set_status(ConnectionStatus.RECONNECTING)
|
||||
self._cleanup()
|
||||
self._stop_flag.wait(timeout=2.0)
|
||||
|
||||
def _cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if self.container:
|
||||
try:
|
||||
self.container.close()
|
||||
except:
|
||||
pass
|
||||
self.container = None
|
||||
|
||||
self.decoder = None
|
||||
|
||||
with self._buffer_lock:
|
||||
self.frame_buffer.clear()
|
||||
|
||||
def get_frame(self, index: int = -1, rgb: bool = True) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Get a frame from the buffer as a CUDA tensor (in VRAM).
|
||||
|
||||
Args:
|
||||
index: Frame index in buffer (-1 for latest, -2 for second latest, etc.)
|
||||
rgb: If True, convert NV12 to RGB. If False, return raw NV12 format.
|
||||
|
||||
Returns:
|
||||
torch.Tensor in CUDA memory (device tensor) or None if buffer empty
|
||||
- If rgb=True: Shape (3, H, W) in RGB format, dtype uint8
|
||||
- If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8
|
||||
"""
|
||||
with self._buffer_lock:
|
||||
if len(self.frame_buffer) == 0:
|
||||
return None
|
||||
|
||||
try:
|
||||
decoded_frame = self.frame_buffer[index]
|
||||
|
||||
# Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
|
||||
# This keeps the data in GPU memory
|
||||
nv12_tensor = torch.from_dlpack(decoded_frame)
|
||||
|
||||
if not rgb:
|
||||
# Return raw NV12 format
|
||||
return nv12_tensor
|
||||
|
||||
# Convert NV12 to RGB on GPU
|
||||
if self.frame_height is None or self.frame_width is None:
|
||||
print("Frame dimensions not available")
|
||||
return None
|
||||
|
||||
rgb_tensor = nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width)
|
||||
return rgb_tensor
|
||||
|
||||
except (IndexError, Exception) as e:
|
||||
print(f"Error getting frame: {e}")
|
||||
return None
|
||||
|
||||
def get_latest_frame(self, rgb: bool = True) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Get the most recent decoded frame as CUDA tensor.
|
||||
|
||||
Args:
|
||||
rgb: If True, convert to RGB. If False, return raw NV12.
|
||||
|
||||
Returns:
|
||||
torch.Tensor on GPU in RGB (3, H, W) or NV12 (H*3/2, W) format
|
||||
"""
|
||||
return self.get_frame(-1, rgb=rgb)
|
||||
|
||||
def get_frame_cpu(self, index: int = -1, rgb: bool = True) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Get a frame from the buffer and copy it to CPU memory as numpy array.
|
||||
|
||||
Args:
|
||||
index: Frame index in buffer (-1 for latest, -2 for second latest, etc.)
|
||||
rgb: If True, convert NV12 to RGB. If False, return raw NV12 format.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray in CPU memory or None if buffer empty
|
||||
- If rgb=True: Shape (H, W, 3) in RGB format, dtype uint8 (HWC format for easy display)
|
||||
- If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8
|
||||
"""
|
||||
# Get frame on GPU
|
||||
gpu_frame = self.get_frame(index=index, rgb=rgb)
|
||||
|
||||
if gpu_frame is None:
|
||||
return None
|
||||
|
||||
# Transfer from GPU to CPU
|
||||
cpu_tensor = gpu_frame.cpu()
|
||||
|
||||
# Convert to numpy array
|
||||
if rgb:
|
||||
# Convert from (3, H, W) to (H, W, 3) for standard image format
|
||||
cpu_array = cpu_tensor.permute(1, 2, 0).numpy()
|
||||
else:
|
||||
# Keep NV12 format as-is
|
||||
cpu_array = cpu_tensor.numpy()
|
||||
|
||||
return cpu_array
|
||||
|
||||
def get_latest_frame_cpu(self, rgb: bool = True) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Get the most recent decoded frame as CPU numpy array.
|
||||
|
||||
Args:
|
||||
rgb: If True, convert to RGB. If False, return raw NV12.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray in CPU memory
|
||||
- If rgb=True: Shape (H, W, 3) in RGB format, dtype uint8
|
||||
- If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8
|
||||
"""
|
||||
return self.get_frame_cpu(-1, rgb=rgb)
|
||||
|
||||
def get_buffer_size(self) -> int:
|
||||
"""Get current number of frames in buffer"""
|
||||
with self._buffer_lock:
|
||||
return len(self.frame_buffer)
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if stream is actively connected"""
|
||||
return self.get_status() == ConnectionStatus.CONNECTED
|
||||
|
||||
def get_frame_as_jpeg(self, index: int = -1, quality: int = 95) -> Optional[bytes]:
|
||||
"""
|
||||
Get a frame from the buffer and encode to JPEG.
|
||||
|
||||
This method:
|
||||
1. Gets RGB frame from buffer (stays on GPU)
|
||||
2. Encodes to JPEG using nvJPEG (GPU operation via shared encoder)
|
||||
3. Transfers JPEG bytes to CPU
|
||||
4. Returns bytes for saving to disk
|
||||
|
||||
Args:
|
||||
index: Frame index in buffer (-1 for latest)
|
||||
quality: JPEG quality (0-100, default 95)
|
||||
|
||||
Returns:
|
||||
JPEG encoded bytes or None if frame unavailable
|
||||
"""
|
||||
# Get RGB frame (on GPU)
|
||||
rgb_frame = self.get_frame(index=index, rgb=True)
|
||||
|
||||
# Use the shared JPEG encoder from jpeg_encoder module
|
||||
return encode_frame_to_jpeg(rgb_frame, quality=quality)
|
||||
|
||||
def __repr__(self):
|
||||
return (f"StreamDecoder(url={self.rtsp_url}, status={self.status.value}, "
|
||||
f"buffer={self.get_buffer_size()}/{self.buffer_size}, "
|
||||
f"frames_decoded={self.frame_count})")
|
||||
Loading…
Add table
Add a link
Reference in a new issue