feat: inference subsystem and optimization to decoder
This commit is contained in:
commit
3c83a57e44
19 changed files with 3897 additions and 0 deletions
481
services/stream_decoder.py
Normal file
481
services/stream_decoder.py
Normal file
|
|
@ -0,0 +1,481 @@
|
|||
import threading
|
||||
from typing import Optional
|
||||
from collections import deque
|
||||
from enum import Enum
|
||||
import torch
|
||||
import PyNvVideoCodec as nvc
|
||||
import av
|
||||
import numpy as np
|
||||
from cuda.bindings import driver as cuda_driver
|
||||
from .jpeg_encoder import encode_frame_to_jpeg
|
||||
|
||||
|
||||
def nv12_to_rgb_gpu(nv12_tensor: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """
    Convert an NV12 frame to planar RGB using PyTorch operations.

    The conversion runs on whatever device ``nv12_tensor`` lives on.

    NV12 layout expected in ``nv12_tensor``:
        - rows ``[0, height)``: Y plane (luminance), one sample per pixel
        - rows ``[height, height + height // 2)``: interleaved chroma
          (U0 V0 U1 V1 ...), subsampled by 2 in both directions

    Rows or columns beyond that region are ignored, so a tensor that carries
    trailing padding is accepted.
    NOTE(review): the chroma plane is assumed to start at row ``height``;
    confirm this for surfaces whose height is padded by the decoder.

    Colour math uses the BT.601 coefficients without any limited-range
    (16-235) rescaling:
        R = Y + 1.402 * (V - 128)
        G = Y - 0.344136 * (U - 128) - 0.714136 * (V - 128)
        B = Y + 1.772 * (U - 128)

    Args:
        nv12_tensor: 2-D tensor holding the NV12 frame, at least
            (height + height // 2) x width
        height: Frame height in pixels (>= 2)
        width: Frame width in pixels (even, >= 2)

    Returns:
        RGB tensor, shape (3, H, W), dtype uint8, values in [0, 255]

    Raises:
        ValueError: If the tensor is not 2-D, the dimensions are unusable,
            or the tensor is too small for the requested frame size.
    """
    if nv12_tensor.dim() != 2:
        raise ValueError(
            f"NV12 tensor must be 2-D, got shape {tuple(nv12_tensor.shape)}"
        )
    if height < 2 or width < 2 or width % 2 != 0:
        raise ValueError(
            f"Unsupported frame size {width}x{height}: "
            "height must be >= 2 and width must be even and >= 2"
        )

    chroma_h = height // 2
    chroma_w = width // 2
    rows, cols = nv12_tensor.shape
    if rows < height + chroma_h or cols < width:
        raise ValueError(
            f"NV12 tensor of shape {(rows, cols)} is too small for a "
            f"{width}x{height} frame"
        )

    # Split the luma and chroma planes, cropping away any padding
    y = nv12_tensor[:height, :width].float()                      # (H, W)
    uv = nv12_tensor[height:height + chroma_h, :width].float()    # (H/2, W)

    # Deinterleave U0V0U1V1... into a (1, 2, H/2, W/2) batch so that both
    # chroma planes are upsampled by a single interpolate call
    chroma = uv.reshape(chroma_h, chroma_w, 2).permute(2, 0, 1).unsqueeze(0)
    chroma = torch.nn.functional.interpolate(
        chroma,
        size=(height, width),
        mode='bilinear',
        align_corners=False
    ).squeeze(0)                                                  # (2, H, W)

    # Centre chroma around zero
    u = chroma[0] - 128.0
    v = chroma[1] - 128.0

    r = y + 1.402 * v
    g = y - 0.344136 * u - 0.714136 * v
    b = y + 1.772 * u

    # Stack to (3, H, W), then clamp to [0, 255] and convert to uint8
    rgb = torch.stack([r, g, b], dim=0)
    return rgb.clamp_(0, 255).to(torch.uint8)
|
||||
|
||||
|
||||
class ConnectionStatus(Enum):
    """Connection states a stream decoder can report; values are lowercase labels."""

    # No stream is open
    DISCONNECTED = "disconnected"

    # A connection attempt is in progress
    CONNECTING = "connecting"

    # The stream is open and the decoder has been created
    CONNECTED = "connected"

    # A connection attempt failed
    ERROR = "error"

    # Waiting before another connection attempt
    RECONNECTING = "reconnecting"
|
||||
|
||||
|
||||
class StreamDecoderFactory:
    """
    Factory for creating StreamDecoder instances with shared CUDA context.

    The factory is a process-wide singleton: every call to
    ``StreamDecoderFactory(...)`` returns the same object, and the CUDA
    primary context is retained only once, on the first successful
    initialization. All decoders created through the factory receive that
    same context handle.

    Because of the singleton behaviour, the ``gpu_id`` passed to the first
    successful construction wins; a different ``gpu_id`` passed later is
    ignored (a warning is printed).
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, gpu_id: int = 0):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Set the flag before publishing the instance so no
                    # other thread can observe it without `_initialized`
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self, gpu_id: int = 0):
        """
        Initialize CUDA and retain the primary context of ``gpu_id``.

        Runs its body only once per process; subsequent constructions return
        immediately.

        Args:
            gpu_id: CUDA device ordinal

        Raises:
            RuntimeError: If CUDA initialization, device lookup, or context
                retention fails. The factory stays uninitialized so a later
                construction can retry.
        """
        # Serialize initialization so two threads cannot both retain the
        # primary context
        with self._lock:
            if self._initialized:
                if gpu_id != self.gpu_id:
                    print(
                        f"StreamDecoderFactory already initialized on GPU "
                        f"{self.gpu_id}; ignoring requested GPU {gpu_id}"
                    )
                return

            # Initialize CUDA and get device
            err, = cuda_driver.cuInit(0)
            if err != cuda_driver.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to initialize CUDA: {err}")

            # Get CUDA device
            err, cuda_device = cuda_driver.cuDeviceGet(gpu_id)
            if err != cuda_driver.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to get CUDA device {gpu_id}: {err}")

            # Retain primary context (shared across all decoders)
            err, cuda_context = cuda_driver.cuDevicePrimaryCtxRetain(cuda_device)
            if err != cuda_driver.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to retain CUDA primary context: {err}")

            # Publish state only after every step succeeded, so __del__
            # never releases a context that was not retained
            self.gpu_id = gpu_id
            self.cuda_device = cuda_device
            self.cuda_context = cuda_context
            self._initialized = True
            print(f"StreamDecoderFactory initialized with shared CUDA context on GPU {gpu_id}")

    def create_decoder(self, rtsp_url: str, buffer_size: int = 30,
                       codec: str = "h264") -> 'StreamDecoder':
        """
        Create a new StreamDecoder instance with shared CUDA context.

        The decoder is only constructed here; it does not connect until its
        ``start()`` method is called.

        Args:
            rtsp_url: RTSP stream URL
            buffer_size: Number of frames to buffer in VRAM
            codec: Video codec (h264, hevc, etc.)

        Returns:
            StreamDecoder instance
        """
        return StreamDecoder(
            rtsp_url=rtsp_url,
            cuda_context=self.cuda_context,
            gpu_id=self.gpu_id,
            buffer_size=buffer_size,
            codec=codec
        )

    def __del__(self):
        """Release the shared CUDA primary context if it was retained."""
        if not getattr(self, '_initialized', False):
            return
        self._initialized = False
        try:
            cuda_driver.cuDevicePrimaryCtxRelease(self.cuda_device)
        except Exception:
            # Best effort: during interpreter shutdown the CUDA bindings may
            # already be torn down, and raising from __del__ is pointless
            pass
|
||||
|
||||
|
||||
class StreamDecoder:
    """
    Decodes RTSP stream using NVDEC and maintains a ring buffer of frames in GPU VRAM.

    Packets are demuxed with PyAV and decoded with PyNvVideoCodec on a
    background thread. The ring buffer and the connection status are each
    guarded by their own lock, so frames can be read while decoding runs.

    NOTE(review): the ring buffer keeps the frame objects returned by the
    decoder; confirm that the decoder does not recycle the underlying
    surfaces while they are still buffered.
    """

    def __init__(self, rtsp_url: str, cuda_context, gpu_id: int,
                 buffer_size: int = 30, codec: str = "h264"):
        """
        Initialize StreamDecoder.

        No connection is opened here; call ``start()`` to begin decoding.

        Args:
            rtsp_url: RTSP stream URL
            cuda_context: Shared CUDA context handle
            gpu_id: GPU device ID
            buffer_size: Number of frames to keep in ring buffer
            codec: Video codec type
        """
        self.rtsp_url = rtsp_url
        self.cuda_context = cuda_context
        self.gpu_id = gpu_id
        self.buffer_size = buffer_size
        self.codec = codec

        # Connection status
        self.status = ConnectionStatus.DISCONNECTED
        self._status_lock = threading.Lock()

        # Frame buffer (ring buffer) - stores the frame objects returned by
        # the decoder; the oldest frame is dropped once buffer_size is reached
        self.frame_buffer = deque(maxlen=buffer_size)
        self._buffer_lock = threading.RLock()

        # Decoder and container instances
        self.decoder = None
        self.container = None

        # Decode thread
        self._decode_thread: Optional[threading.Thread] = None
        self._stop_flag = threading.Event()

        # Frame metadata
        self.frame_width: Optional[int] = None
        self.frame_height: Optional[int] = None
        self.frame_count: int = 0

    def start(self):
        """Start the RTSP stream decoding in background thread"""
        if self._decode_thread is not None and self._decode_thread.is_alive():
            print(f"Decoder already running for {self.rtsp_url}")
            return

        self._stop_flag.clear()
        self._decode_thread = threading.Thread(target=self._decode_loop, daemon=True)
        self._decode_thread.start()
        print(f"Started decoder thread for {self.rtsp_url}")

    def stop(self):
        """
        Stop the decoding thread and cleanup resources.

        Waits up to 5 seconds for the decode thread to exit, then releases
        the container/decoder, clears the buffer and marks the decoder as
        disconnected.
        """
        self._stop_flag.set()
        worker = self._decode_thread
        if worker is not None:
            worker.join(timeout=5.0)
            if worker.is_alive():
                print(f"Decoder thread for {self.rtsp_url} did not stop within 5 seconds")
        self._cleanup()
        # Without this the status would stay CONNECTED after a stop
        self._set_status(ConnectionStatus.DISCONNECTED)
        print(f"Stopped decoder for {self.rtsp_url}")

    def _set_status(self, status: "ConnectionStatus"):
        """Thread-safe status update"""
        with self._status_lock:
            self.status = status

    def get_status(self) -> "ConnectionStatus":
        """Get current connection status"""
        with self._status_lock:
            return self.status

    def _init_rtsp_connection(self) -> bool:
        """
        Initialize RTSP connection using PyAV + PyNvVideoCodec.

        Opens the stream, records the frame dimensions of the first video
        stream and creates the NVDEC decoder on the shared CUDA context.

        Returns:
            True on success. On any failure the partially opened resources
            are released, the status is set to ERROR and False is returned.
        """
        try:
            self._set_status(ConnectionStatus.CONNECTING)

            # Open RTSP stream with PyAV
            options = {
                'rtsp_transport': 'tcp',
                'max_delay': '500000',  # 500ms
                'rtsp_flags': 'prefer_tcp',
                'timeout': '5000000',  # 5 seconds
            }

            self.container = av.open(self.rtsp_url, options=options)

            # Get video stream
            video_stream = self.container.streams.video[0]
            self.frame_width = video_stream.width
            self.frame_height = video_stream.height

            print(f"RTSP connected: {self.frame_width}x{self.frame_height}")

            # Map codec name to PyNvVideoCodec codec enum
            codec_map = {
                'h264': nvc.cudaVideoCodec.H264,
                'hevc': nvc.cudaVideoCodec.HEVC,
                'h265': nvc.cudaVideoCodec.HEVC,
            }

            codec_key = self.codec.lower()
            if codec_key not in codec_map:
                print(f"Unknown codec '{self.codec}' for {self.rtsp_url}, falling back to h264")
            codec_id = codec_map.get(codec_key, nvc.cudaVideoCodec.H264)

            # Initialize NVDEC decoder with shared CUDA context
            self.decoder = nvc.CreateDecoder(
                gpuid=self.gpu_id,
                codec=codec_id,
                cudacontext=self.cuda_context,
                usedevicememory=True
            )

            self._set_status(ConnectionStatus.CONNECTED)
            return True

        except Exception as e:
            print(f"Failed to connect to RTSP stream {self.rtsp_url}: {e}")
            # Release a container that was opened before the failure so the
            # next attempt does not leak it
            self._cleanup()
            self._set_status(ConnectionStatus.ERROR)
            return False

    def _decode_loop(self):
        """
        Main decode loop running in background thread.

        Connects, decodes until the stream ends or fails, then reconnects
        after a 2 second pause. Gives up with status ERROR after 5
        consecutive failed connection attempts, and exits when the stop flag
        is set.
        """
        retry_count = 0
        max_retries = 5

        while not self._stop_flag.is_set():
            # Initialize connection
            if not self._init_rtsp_connection():
                retry_count += 1
                if retry_count >= max_retries:
                    print(f"Max retries reached for {self.rtsp_url}")
                    self._set_status(ConnectionStatus.ERROR)
                    break

                self._set_status(ConnectionStatus.RECONNECTING)
                self._stop_flag.wait(timeout=2.0)
                continue

            retry_count = 0  # Reset on successful connection

            try:
                self._pump_packets()
            except Exception as e:
                print(f"Error in decode loop for {self.rtsp_url}: {e}")
            finally:
                # Always release the container/decoder, including when the
                # stream simply ends without raising
                self._cleanup()

            if self._stop_flag.is_set():
                break

            # Stream ended or failed: back off before reconnecting
            self._set_status(ConnectionStatus.RECONNECTING)
            self._stop_flag.wait(timeout=2.0)

    def _pump_packets(self):
        """Demux packets from the open container and push decoded frames into the ring buffer."""
        for packet in self.container.demux(video=0):
            if self._stop_flag.is_set():
                break

            if packet.dts is None:
                continue

            # Copy the packet payload into a numpy array; the array must
            # stay alive until Decode() returns because only its address is
            # handed to the decoder
            packet_data = np.frombuffer(bytes(packet), dtype=np.uint8)

            # Create PacketData and pass numpy array pointer
            pkt = nvc.PacketData()
            pkt.bsl_data = packet_data.ctypes.data
            pkt.bsl = len(packet_data)

            # Decode using NVDEC
            decoded_frames = self.decoder.Decode(pkt)

            if not decoded_frames:
                continue

            # Add frames to ring buffer (thread-safe)
            with self._buffer_lock:
                for frame in decoded_frames:
                    self.frame_buffer.append(frame)
                    self.frame_count += 1

    def _cleanup(self):
        """Close the container, drop the decoder and clear the frame buffer."""
        # Detach the handle first so a concurrent caller cannot close the
        # same container twice
        container, self.container = self.container, None
        if container is not None:
            try:
                container.close()
            except Exception as e:
                # Best effort: a failed close must not prevent the rest of
                # the cleanup
                print(f"Error closing stream {self.rtsp_url}: {e}")

        self.decoder = None

        with self._buffer_lock:
            self.frame_buffer.clear()

    def get_frame(self, index: int = -1, rgb: bool = True) -> Optional[torch.Tensor]:
        """
        Get a frame from the buffer as a CUDA tensor (in VRAM).

        Args:
            index: Frame index in buffer (-1 for latest, -2 for second latest, etc.)
            rgb: If True, convert NV12 to RGB. If False, return raw NV12 format.

        Returns:
            torch.Tensor in CUDA memory (device tensor), or None if the
            buffer is empty, the index is out of range, the frame dimensions
            are unknown, or the conversion fails
            - If rgb=True: Shape (3, H, W) in RGB format, dtype uint8
            - If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8
        """
        # Hold the lock only for the lookup so the decode thread is not
        # stalled while the frame is converted
        with self._buffer_lock:
            if not self.frame_buffer:
                return None
            try:
                decoded_frame = self.frame_buffer[index]
            except IndexError as e:
                print(f"Error getting frame: {e}")
                return None

        try:
            # Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
            # This keeps the data in GPU memory
            nv12_tensor = torch.from_dlpack(decoded_frame)

            if not rgb:
                # Return raw NV12 format
                return nv12_tensor

            # Convert NV12 to RGB on GPU
            if self.frame_height is None or self.frame_width is None:
                print("Frame dimensions not available")
                return None

            return nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width)

        except Exception as e:
            print(f"Error getting frame: {e}")
            return None

    def get_latest_frame(self, rgb: bool = True) -> Optional[torch.Tensor]:
        """
        Get the most recent decoded frame as CUDA tensor.

        Args:
            rgb: If True, convert to RGB. If False, return raw NV12.

        Returns:
            torch.Tensor on GPU in RGB (3, H, W) or NV12 (H*3/2, W) format,
            or None if no frame is available
        """
        return self.get_frame(-1, rgb=rgb)

    def get_frame_cpu(self, index: int = -1, rgb: bool = True) -> Optional[np.ndarray]:
        """
        Get a frame from the buffer and copy it to CPU memory as numpy array.

        Args:
            index: Frame index in buffer (-1 for latest, -2 for second latest, etc.)
            rgb: If True, convert NV12 to RGB. If False, return raw NV12 format.

        Returns:
            numpy.ndarray in CPU memory or None if no frame is available
            - If rgb=True: Shape (H, W, 3) in RGB format, dtype uint8 (HWC format for easy display)
            - If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8
        """
        # Get frame on GPU
        gpu_frame = self.get_frame(index=index, rgb=rgb)
        if gpu_frame is None:
            return None

        # Transfer from GPU to CPU
        cpu_tensor = gpu_frame.cpu()

        if rgb:
            # Convert from (3, H, W) to (H, W, 3) for standard image format
            return cpu_tensor.permute(1, 2, 0).numpy()

        # Keep NV12 format as-is
        return cpu_tensor.numpy()

    def get_latest_frame_cpu(self, rgb: bool = True) -> Optional[np.ndarray]:
        """
        Get the most recent decoded frame as CPU numpy array.

        Args:
            rgb: If True, convert to RGB. If False, return raw NV12.

        Returns:
            numpy.ndarray in CPU memory, or None if no frame is available
            - If rgb=True: Shape (H, W, 3) in RGB format, dtype uint8
            - If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8
        """
        return self.get_frame_cpu(-1, rgb=rgb)

    def get_buffer_size(self) -> int:
        """Get current number of frames in buffer"""
        with self._buffer_lock:
            return len(self.frame_buffer)

    def is_connected(self) -> bool:
        """Check if stream is actively connected"""
        return self.get_status() == ConnectionStatus.CONNECTED

    def get_frame_as_jpeg(self, index: int = -1, quality: int = 95) -> Optional[bytes]:
        """
        Get a frame from the buffer and encode to JPEG.

        The RGB frame is fetched with ``get_frame`` and handed to the shared
        ``encode_frame_to_jpeg`` helper from the jpeg_encoder module.

        Args:
            index: Frame index in buffer (-1 for latest)
            quality: JPEG quality (0-100, default 95)

        Returns:
            JPEG encoded bytes or None if frame unavailable
        """
        # Get RGB frame (on GPU)
        rgb_frame = self.get_frame(index=index, rgb=True)
        if rgb_frame is None:
            # Nothing to encode; do not hand None to the encoder
            return None

        # Use the shared JPEG encoder from jpeg_encoder module
        return encode_frame_to_jpeg(rgb_frame, quality=quality)

    def __repr__(self):
        return (f"StreamDecoder(url={self.rtsp_url}, status={self.get_status().value}, "
                f"buffer={self.get_buffer_size()}/{self.buffer_size}, "
                f"frames_decoded={self.frame_count})")
|
||||
Loading…
Add table
Add a link
Reference in a new issue