python-rtsp-worker/services/stream_decoder.py

631 lines
23 KiB
Python

import threading
from typing import Optional, Callable
from collections import deque
from enum import Enum
import torch
import PyNvVideoCodec as nvc
import av
import numpy as np
from cuda.bindings import driver as cuda_driver
from .jpeg_encoder import encode_frame_to_jpeg
class FrameReference:
    """
    CPU-side handle for one decoded frame held in the decoder's ring buffer.

    Holds an RGB tensor supplied by the decoder (a cloned copy), so it does not
    depend on PyNvVideoCodec's DecodedFrame lifecycle. The DecodedFrame itself
    is deliberately not kept, to avoid conflicts with PyNvVideoCodec's internal
    frame pool management.

    ``free()`` notifies the owning decoder exactly once through
    ``decoder._mark_frame_free(buffer_index)``; it is also attempted on garbage
    collection.
    """

    def __init__(self, rgb_tensor: torch.Tensor, buffer_index: int, decoder):
        self.rgb_tensor = rgb_tensor  # Cloned RGB tensor (independent copy)
        self.buffer_index = buffer_index  # Id the decoder uses to track this frame
        self.decoder = decoder  # Owner notified when the frame is freed
        self._freed = False

    def free(self) -> None:
        """Mark this frame as no longer in use. Repeated calls are no-ops."""
        if self._freed:
            return
        # Flip the flag first so a re-entrant or failing notification cannot
        # cause a second call into the decoder.
        self._freed = True
        self.decoder._mark_frame_free(self.buffer_index)

    def is_freed(self) -> bool:
        """Return True once ``free()`` has been called."""
        return self._freed

    def __del__(self):
        """Best-effort auto-free on garbage collection; never raises."""
        # If __init__ did not complete, _freed is missing; treat that as
        # "nothing to release".
        if getattr(self, "_freed", True):
            return
        try:
            self.free()
        except Exception:
            # Finalizers must not propagate errors. At interpreter shutdown
            # the decoder (or its lock) may already be torn down.
            pass
def nv12_to_rgb_gpu(nv12_tensor: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """
    Convert an NV12 frame to RGB using PyTorch ops (on the input's device).

    NV12 layout:
    - Y plane: ``height`` rows x ``width`` columns (luminance)
    - UV plane: the following rows, U and V interleaved (U0 V0 U1 V1 ...),
      subsampled by 2 horizontally and vertically

    Args:
        nv12_tensor: 2-D tensor of shape (H*3/2, W_stride) with W_stride >= width.
            Columns beyond ``width`` (row pitch/padding) are ignored.
        height: Frame height in pixels
        width: Frame width in pixels (expected to be even)

    Returns:
        RGB tensor of shape (3, H, W), dtype uint8, values in [0, 255].

    NOTE(review): uses full-range BT.601 coefficients. If the decoder emits
    limited-range (16-235) video, colors will be slightly off; confirm.
    """
    uv_rows = (height + 1) // 2

    # Drop any row padding beyond the visible width.
    y = nv12_tensor[:height, :width].float()  # (H, W)
    uv = nv12_tensor[height:height + uv_rows, :width].float()  # (H/2, W)

    # Deinterleave: (H/2, W) -> (H/2, W/2, 2), last axis is (U, V).
    uv = uv.reshape(uv_rows, width // 2, 2)

    # (H/2, W/2, 2) -> (1, 2, H/2, W/2): one bilinear upsample for both
    # chroma planes. Channels are interpolated independently.
    uv_full = torch.nn.functional.interpolate(
        uv.permute(2, 0, 1).unsqueeze(0),
        size=(height, width),
        mode='bilinear',
        align_corners=False,
    )[0]  # (2, H, W)

    u = uv_full[0] - 128.0
    v = uv_full[1] - 128.0

    # BT.601:
    #   R = Y + 1.402 * (V - 128)
    #   G = Y - 0.344136 * (U - 128) - 0.714136 * (V - 128)
    #   B = Y + 1.772 * (U - 128)
    r = y + 1.402 * v
    g = y - 0.344136 * u - 0.714136 * v
    b = y + 1.772 * u

    rgb = torch.stack([r, g, b], dim=0)  # (3, H, W)
    return rgb.clamp(0, 255).to(torch.uint8)
class ConnectionStatus(Enum):
    """Connection lifecycle states reported by ``StreamDecoder.get_status()``."""

    # Initial state of a freshly constructed StreamDecoder.
    DISCONNECTED = "disconnected"
    # Opening the RTSP stream and creating the NVDEC decoder.
    CONNECTING = "connecting"
    # Stream opened and decoder created.
    CONNECTED = "connected"
    # A connection attempt failed, or the retry limit was reached.
    ERROR = "error"
    # Waiting before the next connection attempt.
    RECONNECTING = "reconnecting"
class StreamDecoderFactory:
    """
    Process-wide singleton that creates StreamDecoder instances sharing one CUDA context.

    The CUDA primary context of ``gpu_id`` is retained once and passed to every
    decoder, which minimizes VRAM overhead compared to one context per decoder.

    Only the ``gpu_id`` of the first successful initialization takes effect.
    Later constructions return the same factory; a warning is printed if they
    ask for a different GPU.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, gpu_id: int = 0):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Fully prepare the object before publishing it, so a
                    # thread that skips the lock never sees it without
                    # `_initialized`.
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self, gpu_id: int = 0):
        # __init__ runs on every StreamDecoderFactory(...) call. Only the
        # first one (per process) does the CUDA setup.
        if self._initialized:
            self._warn_if_gpu_mismatch(gpu_id)
            return
        with self._lock:
            # Re-check under the lock: concurrent first calls must not retain
            # the primary context twice.
            if self._initialized:
                self._warn_if_gpu_mismatch(gpu_id)
                return

            self.gpu_id = gpu_id

            # Initialize CUDA driver API
            err, = cuda_driver.cuInit(0)
            if err != cuda_driver.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to initialize CUDA: {err}")

            # Get CUDA device
            err, self.cuda_device = cuda_driver.cuDeviceGet(gpu_id)
            if err != cuda_driver.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to get CUDA device {gpu_id}: {err}")

            # Retain primary context (shared across all decoders)
            err, self.cuda_context = cuda_driver.cuDevicePrimaryCtxRetain(self.cuda_device)
            if err != cuda_driver.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to retain CUDA primary context: {err}")

            self._initialized = True
        print(f"StreamDecoderFactory initialized with shared CUDA context on GPU {gpu_id}")

    def _warn_if_gpu_mismatch(self, gpu_id: int) -> None:
        """Warn when a caller asks the already-initialized singleton for another GPU."""
        current = getattr(self, "gpu_id", gpu_id)
        if gpu_id != current:
            print(f"[WARNING] StreamDecoderFactory already initialized on GPU {current}; "
                  f"ignoring gpu_id={gpu_id}")

    def create_decoder(self, rtsp_url: str, buffer_size: int = 30,
                       codec: str = "h264") -> 'StreamDecoder':
        """
        Create a new StreamDecoder instance with the shared CUDA context.

        Args:
            rtsp_url: RTSP stream URL
            buffer_size: Number of frames to buffer in VRAM
            codec: Video codec (h264, hevc, etc.)

        Returns:
            StreamDecoder instance (not started)
        """
        return StreamDecoder(
            rtsp_url=rtsp_url,
            cuda_context=self.cuda_context,
            gpu_id=self.gpu_id,
            buffer_size=buffer_size,
            codec=codec
        )

    def __del__(self):
        """Release the retained primary context (best effort)."""
        # Only release a context that __init__ actually retained.
        if not getattr(self, "_initialized", False):
            return
        try:
            cuda_driver.cuDevicePrimaryCtxRelease(self.cuda_device)
        except Exception:
            # The singleton normally lives until interpreter shutdown, when
            # module globals (cuda_driver) may already be gone. Finalizers
            # must not raise.
            pass
class StreamDecoder:
    """
    Decode an RTSP stream with NVDEC and keep the latest frames in a GPU ring buffer.

    A background thread (``start``/``stop``) demuxes packets with PyAV, decodes
    them with PyNvVideoCodec, converts each frame to an RGB CUDA tensor and
    stores it in a bounded ring buffer of ``FrameReference`` objects. Buffer and
    callback bookkeeping are lock-protected, so the getters may be called from
    other threads while decoding runs.
    """

    def __init__(self, rtsp_url: str, cuda_context, gpu_id: int,
                 buffer_size: int = 30, codec: str = "h264"):
        """
        Initialize StreamDecoder. No connection is made until ``start()``.

        Args:
            rtsp_url: RTSP stream URL
            cuda_context: Shared CUDA context handle passed to NVDEC
            gpu_id: GPU device ID
            buffer_size: Maximum number of frames kept in the ring buffer
            codec: Video codec name ("h264", "hevc" or "h265"; others fall back to H264)
        """
        self.rtsp_url = rtsp_url
        self.cuda_context = cuda_context
        self.gpu_id = gpu_id
        self.buffer_size = buffer_size
        self.codec = codec

        # Connection status (guarded by _status_lock)
        self.status = ConnectionStatus.DISCONNECTED
        self._status_lock = threading.Lock()

        # Ring buffer of FrameReference objects, oldest on the left.
        # RLock: FrameReference.free() re-enters via _mark_frame_free() while
        # the lock is already held.
        self.frame_buffer = deque(maxlen=buffer_size)
        self._buffer_lock = threading.RLock()
        # FrameReference objects not yet freed (guarded by _buffer_lock)
        self._in_use_frames = []

        # PyNvVideoCodec decoder and PyAV container (set while connected)
        self.decoder = None
        self.container = None

        # Decode thread
        self._decode_thread: Optional[threading.Thread] = None
        self._stop_flag = threading.Event()

        # Frame metadata
        self.frame_width: Optional[int] = None
        self.frame_height: Optional[int] = None
        self.frame_count: int = 0

        # Frame callbacks - event-driven notification
        self._frame_callbacks = []
        self._callback_lock = threading.Lock()

    # ------------------------------------------------------------------
    # Callbacks
    # ------------------------------------------------------------------
    def register_frame_callback(self, callback: Callable):
        """
        Register a callback invoked for every newly decoded frame.

        Callbacks run on the decoder's background thread and receive the RGB
        CUDA tensor of shape (3, H, W), dtype uint8. A callback may register
        or unregister callbacks (including itself); the change takes effect
        from the next frame.

        Callback signature: callback(frame: torch.Tensor) -> None

        Args:
            callback: Function to call when a new frame arrives
        """
        with self._callback_lock:
            self._frame_callbacks.append(callback)

    def unregister_frame_callback(self, callback: Callable):
        """
        Unregister a frame callback (no-op if it is not registered).

        Args:
            callback: The callback function to remove
        """
        with self._callback_lock:
            if callback in self._frame_callbacks:
                self._frame_callbacks.remove(callback)

    def _fire_callbacks(self, rgb_tensor: torch.Tensor) -> None:
        """Invoke every registered callback with ``rgb_tensor``, isolating their errors."""
        # Snapshot under the lock, then call without holding it. Otherwise a
        # callback that (un)registers callbacks would deadlock on the
        # non-reentrant _callback_lock.
        with self._callback_lock:
            callbacks = list(self._frame_callbacks)
        for callback in callbacks:
            try:
                callback(rgb_tensor)
            except Exception as e:
                print(f"Error in frame callback: {e}")

    # ------------------------------------------------------------------
    # Frame bookkeeping
    # ------------------------------------------------------------------
    def _mark_frame_free(self, buffer_index: int):
        """
        Stop tracking a frame (called by FrameReference.free()).

        Args:
            buffer_index: Tracking id of the FrameReference being freed
        """
        with self._buffer_lock:
            self._in_use_frames = [f for f in self._in_use_frames if f.buffer_index != buffer_index]

    def _store_frame(self, frame_ref) -> None:
        """Append a FrameReference to the ring buffer, freeing any evicted frame."""
        with self._buffer_lock:
            # Normal ring-buffer eviction: explicitly free the frames that
            # fall off the left end, so they leave _in_use_frames as well.
            while self.frame_buffer and len(self.frame_buffer) >= self.buffer_size:
                self.frame_buffer.popleft().free()
            self.frame_buffer.append(frame_ref)
            self._in_use_frames.append(frame_ref)
            self.frame_count += 1

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def start(self):
        """Start RTSP decoding in a background daemon thread (no-op if already running)."""
        if self._decode_thread is not None and self._decode_thread.is_alive():
            print(f"Decoder already running for {self.rtsp_url}")
            return

        self._stop_flag.clear()
        self._decode_thread = threading.Thread(target=self._decode_loop, daemon=True)
        self._decode_thread.start()
        print(f"Started decoder thread for {self.rtsp_url}")

    def stop(self):
        """Stop the decoding thread and release the stream, decoder and buffered frames."""
        self._stop_flag.set()
        thread = self._decode_thread
        # Joining the current thread (stop() called from a frame callback)
        # would raise RuntimeError.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=5.0)

        if thread is not None and thread.is_alive():
            # Cleaning up now would race with the running loop. The loop
            # checks the stop flag and runs _cleanup() itself before exiting.
            print(f"[WARNING] Decoder thread for {self.rtsp_url} is still running; "
                  f"it will release its resources when it exits")
        else:
            self._cleanup()
            self._set_status(ConnectionStatus.DISCONNECTED)
        print(f"Stopped decoder for {self.rtsp_url}")

    def _set_status(self, status: "ConnectionStatus"):
        """Thread-safe status update"""
        with self._status_lock:
            self.status = status

    def get_status(self) -> "ConnectionStatus":
        """Get current connection status"""
        with self._status_lock:
            return self.status

    def _init_rtsp_connection(self) -> bool:
        """
        Open the RTSP stream with PyAV and create the NVDEC decoder.

        Returns:
            True on success. False on any failure, in which case status is
            ERROR and nothing stays open.
        """
        try:
            self._set_status(ConnectionStatus.CONNECTING)

            # Open RTSP stream with PyAV
            options = {
                'rtsp_transport': 'tcp',
                'max_delay': '500000',  # 500ms
                'rtsp_flags': 'prefer_tcp',
                'timeout': '5000000',  # 5 seconds
            }
            self.container = av.open(self.rtsp_url, options=options)

            # Get video stream
            video_stream = self.container.streams.video[0]
            self.frame_width = video_stream.width
            self.frame_height = video_stream.height
            print(f"RTSP connected: {self.frame_width}x{self.frame_height}")

            # Map codec name to PyNvVideoCodec codec enum
            codec_map = {
                'h264': nvc.cudaVideoCodec.H264,
                'hevc': nvc.cudaVideoCodec.HEVC,
                'h265': nvc.cudaVideoCodec.HEVC,
            }
            codec_id = codec_map.get(self.codec.lower())
            if codec_id is None:
                print(f"[WARNING] Unknown codec '{self.codec}', falling back to H264")
                codec_id = nvc.cudaVideoCodec.H264

            # Initialize NVDEC decoder with shared CUDA context
            self.decoder = nvc.CreateDecoder(
                gpuid=self.gpu_id,
                codec=codec_id,
                cudacontext=self.cuda_context,
                usedevicememory=True
            )

            self._set_status(ConnectionStatus.CONNECTED)
            return True

        except Exception as e:
            print(f"Failed to connect to RTSP stream {self.rtsp_url}: {e}")
            # Don't leak a container that opened before a later step failed;
            # the retry would otherwise overwrite it.
            self._close_container()
            self.decoder = None
            self._set_status(ConnectionStatus.ERROR)
            return False

    def _decode_loop(self):
        """Background thread body: connect, decode until the stream ends or fails, reconnect."""
        # Set the CUDA device for this thread
        torch.cuda.set_device(self.gpu_id)

        retry_count = 0
        max_retries = 5  # consecutive failed connection attempts before giving up

        while not self._stop_flag.is_set():
            if not self._init_rtsp_connection():
                retry_count += 1
                if retry_count >= max_retries:
                    print(f"Max retries reached for {self.rtsp_url}")
                    self._set_status(ConnectionStatus.ERROR)
                    break
                self._set_status(ConnectionStatus.RECONNECTING)
                self._stop_flag.wait(timeout=2.0)
                continue

            retry_count = 0  # Reset on successful connection

            try:
                self._decode_stream()
            except Exception as e:
                print(f"Error in decode loop for {self.rtsp_url}: {e}")
            finally:
                self._cleanup()

            if self._stop_flag.is_set():
                break
            # Stream ended or failed: wait, then reconnect.
            self._set_status(ConnectionStatus.RECONNECTING)
            self._stop_flag.wait(timeout=2.0)

        # Keep ERROR from an exhausted retry budget; otherwise we exited on request.
        if self._stop_flag.is_set():
            self._set_status(ConnectionStatus.DISCONNECTED)

    def _decode_stream(self) -> None:
        """Feed demuxed packets to NVDEC until the stream ends or stop is requested."""
        for packet in self.container.demux(video=0):
            if self._stop_flag.is_set():
                break
            if packet.dts is None:
                continue

            # The numpy array must stay alive until Decode() returns: NVDEC
            # reads the packet through its raw pointer.
            packet_data = np.frombuffer(bytes(packet), dtype=np.uint8)
            pkt = nvc.PacketData()
            pkt.bsl_data = packet_data.ctypes.data
            pkt.bsl = len(packet_data)

            for frame in self.decoder.Decode(pkt) or ():
                self._handle_decoded_frame(frame)

    def _handle_decoded_frame(self, frame) -> None:
        """Convert one decoded NV12 frame to RGB, store it, then notify callbacks."""
        try:
            # Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
            nv12_tensor = torch.from_dlpack(frame)
            if self.frame_height is None or self.frame_width is None:
                return
            rgb_tensor = nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width)
            # NOTE(review): nv12_to_rgb_gpu already returns a freshly allocated
            # tensor, so this clone may be redundant. It is kept so the stored
            # tensor is independent of the DecodedFrame's memory; the
            # DecodedFrame itself is not kept.
            frame_ref = FrameReference(
                rgb_tensor=rgb_tensor.clone(),
                buffer_index=self.frame_count,
                decoder=self
            )
            self._store_frame(frame_ref)
        except Exception as e:
            print(f"Error converting frame for callback: {e}")
            return
        # Notify outside the buffer lock so slow callbacks don't block readers.
        self._fire_callbacks(frame_ref.rgb_tensor)

    def _close_container(self) -> None:
        """Close and forget the PyAV container, if any."""
        container, self.container = self.container, None
        if container is None:
            return
        try:
            container.close()
        except Exception as e:
            print(f"Error closing container for {self.rtsp_url}: {e}")

    def _cleanup(self):
        """Release the container and decoder, and free every buffered frame."""
        self._close_container()
        self.decoder = None
        with self._buffer_lock:
            # Free (not just drop) buffered frames. Otherwise they stay
            # referenced from _in_use_frames and their GPU tensors leak on
            # every reconnect/stop.
            for frame_ref in list(self.frame_buffer):
                frame_ref.free()
            self.frame_buffer.clear()
            self._in_use_frames.clear()

    # ------------------------------------------------------------------
    # Frame access
    # ------------------------------------------------------------------
    def get_frame(self, index: int = -1, rgb: bool = True) -> Optional[torch.Tensor]:
        """
        Get a frame from the buffer as a CUDA tensor (in VRAM).

        Args:
            index: Frame index in buffer (-1 for latest, -2 for second latest, etc.)
            rgb: Must be True. NV12 output is not supported; rgb=False returns None.

        Returns:
            RGB tensor of shape (3, H, W), dtype uint8, or None if the buffer is
            empty, the index is out of range, or rgb=False.
        """
        with self._buffer_lock:
            if len(self.frame_buffer) == 0:
                return None

            if not rgb:
                print("Warning: NV12 format not supported with FrameReference, only RGB")
                return None

            try:
                # Return the RGB tensor from FrameReference (cloned, independent)
                return self.frame_buffer[index].rgb_tensor
            except (IndexError, TypeError) as e:
                print(f"Error getting frame: {e}")
                return None

    def get_latest_frame(self, rgb: bool = True) -> Optional[torch.Tensor]:
        """
        Get the most recent decoded frame as a CUDA tensor.

        Args:
            rgb: Must be True; rgb=False (NV12) is not supported and returns None.

        Returns:
            RGB tensor of shape (3, H, W), dtype uint8, or None if unavailable.
        """
        return self.get_frame(-1, rgb=rgb)

    def get_frame_cpu(self, index: int = -1, rgb: bool = True) -> Optional[np.ndarray]:
        """
        Get a frame from the buffer copied to CPU memory as a numpy array.

        Args:
            index: Frame index in buffer (-1 for latest, -2 for second latest, etc.)
            rgb: Must be True; rgb=False (NV12) is not supported and returns None.

        Returns:
            numpy.ndarray of shape (H, W, 3), dtype uint8 (HWC for easy display),
            or None if the frame is unavailable.
        """
        gpu_frame = self.get_frame(index=index, rgb=rgb)
        if gpu_frame is None:
            return None

        cpu_tensor = gpu_frame.cpu()
        if rgb:
            # (3, H, W) -> (H, W, 3) for standard image layout
            return cpu_tensor.permute(1, 2, 0).numpy()
        return cpu_tensor.numpy()

    def get_latest_frame_cpu(self, rgb: bool = True) -> Optional[np.ndarray]:
        """
        Get the most recent decoded frame as a CPU numpy array.

        Args:
            rgb: Must be True; rgb=False (NV12) is not supported and returns None.

        Returns:
            numpy.ndarray of shape (H, W, 3), dtype uint8, or None if unavailable.
        """
        return self.get_frame_cpu(-1, rgb=rgb)

    def get_buffer_size(self) -> int:
        """Get current number of frames in buffer"""
        with self._buffer_lock:
            return len(self.frame_buffer)

    def get_all_frames(self, rgb: bool = True) -> list:
        """
        Drain the buffer: return every buffered frame and free all references.

        Args:
            rgb: Must be True; rgb=False (NV12) is not supported and returns [].

        Returns:
            List of RGB CUDA tensors, oldest first. The buffer is empty afterwards.
        """
        if not rgb:
            print("Warning: NV12 format not supported with FrameReference, only RGB")
            return []

        with self._buffer_lock:
            frames = [frame_ref.rgb_tensor for frame_ref in self.frame_buffer]
            for frame_ref in self.frame_buffer:
                frame_ref.free()
            self.frame_buffer.clear()
        return frames

    def get_frame_count(self) -> int:
        """Get total number of frames decoded since this decoder was created"""
        return self.frame_count

    def is_connected(self) -> bool:
        """Check if stream is actively connected"""
        return self.get_status() == ConnectionStatus.CONNECTED

    def get_frame_as_jpeg(self, index: int = -1, quality: int = 95) -> Optional[bytes]:
        """
        Get a frame from the buffer and encode it to JPEG.

        The RGB frame stays on the GPU and is passed to the shared encoder from
        the jpeg_encoder module (``encode_frame_to_jpeg``).

        Args:
            index: Frame index in buffer (-1 for latest)
            quality: JPEG quality (0-100, default 95)

        Returns:
            JPEG encoded bytes, or None if the frame is unavailable.
        """
        rgb_frame = self.get_frame(index=index, rgb=True)
        if rgb_frame is None:
            # Nothing to encode; don't hand None to the encoder.
            return None
        return encode_frame_to_jpeg(rgb_frame, quality=quality)

    def __repr__(self):
        return (f"StreamDecoder(url={self.rtsp_url}, status={self.get_status().value}, "
                f"buffer={self.get_buffer_size()}/{self.buffer_size}, "
                f"frames_decoded={self.frame_count})")