# Source: python-rtsp-worker/services/stream_decoder.py
# Snapshot: 2025-11-11 02:28:33 +07:00 (656 lines, 22 KiB, Python)

import threading
import weakref
from collections import deque
from enum import Enum
from typing import Callable, Optional

import av
import numpy as np
import PyNvVideoCodec as nvc
import torch
from cuda.bindings import driver as cuda_driver

from .jpeg_encoder import encode_frame_to_jpeg
class FrameReference:
    """
    Shared handle for one decoded RGB frame.

    Several parts of the pipeline can hold the same ``FrameReference`` and
    therefore the same tensor without copying it. Note that no reference
    count is kept: the first call to :meth:`free` drops the tensor for every
    holder and notifies the owning decoder; later calls are no-ops.

    Attributes:
        rgb_tensor: The frame tensor, or ``None`` once the frame was freed.
        buffer_index: Identifier passed back to the decoder on release.
        decoder: Owner object; must provide ``_mark_frame_free(buffer_index)``.
    """

    def __init__(self, rgb_tensor: torch.Tensor, buffer_index: int, decoder):
        self.rgb_tensor = rgb_tensor
        self.buffer_index = buffer_index
        self.decoder = decoder
        self._freed = False

    def free(self):
        """Release the tensor and tell the decoder this frame is done.

        Idempotent: only the first call has any effect.
        """
        if self._freed:
            return
        self._freed = True
        # Dropping our reference lets the tensor's memory be reclaimed as soon
        # as no other holder keeps it alive.
        self.rgb_tensor = None
        if self.decoder is not None:
            self.decoder._mark_frame_free(self.buffer_index)

    def is_freed(self) -> bool:
        """Return True once :meth:`free` has been called."""
        return self._freed

    def __del__(self):
        """Free the frame automatically when the object is garbage collected."""
        # ``_freed`` is missing when __init__ never completed; treat such an
        # object as already freed instead of raising inside the finalizer.
        if getattr(self, "_freed", True):
            return
        try:
            self.free()
        except Exception:
            # Finalizers may run during interpreter shutdown when the decoder
            # is already torn down; there is nobody left to report to.
            pass
def nv12_to_rgb_gpu(nv12_tensor: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """
    Convert an NV12 frame to planar RGB using PyTorch operations.

    The conversion runs on whatever device ``nv12_tensor`` lives on (the
    decoder passes CUDA tensors, but nothing here requires CUDA).

    NV12 layout:
    - Y plane: ``height`` rows of ``width`` luminance samples
    - UV plane: ``height / 2`` rows of ``width`` bytes, U and V interleaved
      (U0 V0 U1 V1 ...), i.e. chroma is subsampled by 2 in both directions

    Args:
        nv12_tensor: Input tensor in NV12 format, shape (H*3/2, W)
        height: Frame height H (must be even for the chroma reshape to work)
        width: Frame width W (must be even for the chroma reshape to work)

    Returns:
        RGB tensor of shape (3, H, W), dtype uint8, values in [0, 255].
        Fractional results are truncated, not rounded.
    """
    luma = nv12_tensor[:height, :].float()  # (H, W)

    # De-interleave chroma: (H/2, W) -> (H/2, W/2, 2) -> (1, 2, H/2, W/2) so
    # that U and V become two channels of a single batch entry.
    chroma = (
        nv12_tensor[height:, :]
        .float()
        .reshape(height // 2, width // 2, 2)
        .permute(2, 0, 1)
        .unsqueeze(0)
    )

    # Upsample both chroma channels to full resolution in one call;
    # interpolation is per-channel, so this equals two separate calls.
    chroma = torch.nn.functional.interpolate(
        chroma,
        size=(height, width),
        mode="bilinear",
        align_corners=False,
    ).squeeze(0)  # (2, H, W)

    u = chroma[0] - 128.0
    v = chroma[1] - 128.0

    # YUV -> RGB with the BT.601 coefficients, treating Y as full range:
    #   R = Y + 1.402 * (V - 128)
    #   G = Y - 0.344136 * (U - 128) - 0.714136 * (V - 128)
    #   B = Y + 1.772 * (U - 128)
    # NOTE(review): camera streams are often limited-range (Y in 16..235);
    # confirm the source range if colors look washed out.
    r = luma + 1.402 * v
    g = luma - 0.344136 * u - 0.714136 * v
    b = luma + 1.772 * u

    # Stack to (3, H, W), then clamp and narrow in a single pass.
    rgb = torch.stack([r, g, b], dim=0)
    return torch.clamp(rgb, 0, 255).to(torch.uint8)
class ConnectionStatus(Enum):
    """Lifecycle states reported by a stream decoder for its connection."""

    # No stream is open.
    DISCONNECTED = "disconnected"
    # A connection attempt is in progress.
    CONNECTING = "connecting"
    # The stream is open and a decoder has been created.
    CONNECTED = "connected"
    # The last connection attempt failed.
    ERROR = "error"
    # Waiting to retry after a failure or after the stream ended.
    RECONNECTING = "reconnecting"
class StreamDecoderFactory:
    """
    Process-wide singleton that creates StreamDecoder instances sharing one
    CUDA primary context.

    Every call to ``StreamDecoderFactory(...)`` returns the same object. The
    GPU is fixed by the first successful initialization; a later call that
    asks for a different ``gpu_id`` gets the existing factory and a warning.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, gpu_id: int = 0):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Set the flag before publishing the instance so no other
                    # thread can observe an object without ``_initialized``.
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self, gpu_id: int = 0):
        """
        Initialize CUDA and retain the primary context of ``gpu_id``.

        Runs its body only once per process; subsequent calls return early.

        Args:
            gpu_id: CUDA device ordinal to create decoders on.

        Raises:
            RuntimeError: If CUDA initialization, device lookup, or retaining
                the primary context fails. The factory stays uninitialized,
                so a later call retries.
        """
        # Serialize initialization: without the lock two threads could both
        # retain the primary context, leaving one retain unbalanced.
        with self._lock:
            if self._initialized:
                if gpu_id != self.gpu_id:
                    print(
                        f"Warning: StreamDecoderFactory already initialized on GPU "
                        f"{self.gpu_id}; ignoring requested GPU {gpu_id}"
                    )
                return

            (err,) = cuda_driver.cuInit(0)
            if err != cuda_driver.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to initialize CUDA: {err}")

            err, cuda_device = cuda_driver.cuDeviceGet(gpu_id)
            if err != cuda_driver.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to get CUDA device {gpu_id}: {err}")

            # Retain primary context (shared across all decoders)
            err, cuda_context = cuda_driver.cuDevicePrimaryCtxRetain(cuda_device)
            if err != cuda_driver.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to retain CUDA primary context: {err}")

            # Publish state only after every step succeeded.
            self.gpu_id = gpu_id
            self.cuda_device = cuda_device
            self.cuda_context = cuda_context
            self._initialized = True

        print(
            f"StreamDecoderFactory initialized with shared CUDA context on GPU {gpu_id}"
        )

    def create_decoder(
        self, rtsp_url: str, buffer_size: int = 30, codec: str = "h264"
    ) -> "StreamDecoder":
        """
        Create a new StreamDecoder instance with the shared CUDA context.

        Args:
            rtsp_url: RTSP stream URL
            buffer_size: Number of frames to keep in the decoder's ring buffer
            codec: Video codec name (e.g. "h264", "hevc")

        Returns:
            StreamDecoder instance (not started)
        """
        return StreamDecoder(
            rtsp_url=rtsp_url,
            cuda_context=self.cuda_context,
            gpu_id=self.gpu_id,
            buffer_size=buffer_size,
            codec=codec,
        )

    def __del__(self):
        """Release the primary context retained by this factory, if any."""
        # Only release what was actually retained: ``_initialized`` is set
        # after a successful cuDevicePrimaryCtxRetain and nowhere else.
        if not getattr(self, "_initialized", False):
            return
        self._initialized = False
        try:
            cuda_driver.cuDevicePrimaryCtxRelease(self.cuda_device)
        except Exception:
            # The singleton is normally finalized at interpreter shutdown,
            # when module globals may already be gone; nothing to recover.
            pass
class StreamDecoder:
"""
Decodes RTSP stream using NVDEC and maintains a ring buffer of frames in GPU VRAM.
Thread-safe for concurrent read/write operations.
"""
def __init__(
self,
rtsp_url: str,
cuda_context,
gpu_id: int,
buffer_size: int = 30,
codec: str = "h264",
):
"""
Initialize StreamDecoder.
Args:
rtsp_url: RTSP stream URL
cuda_context: Shared CUDA context handle
gpu_id: GPU device ID
buffer_size: Number of frames to keep in ring buffer
codec: Video codec type
"""
self.rtsp_url = rtsp_url
self.cuda_context = cuda_context
self.gpu_id = gpu_id
self.buffer_size = buffer_size
self.codec = codec
# Connection status
self.status = ConnectionStatus.DISCONNECTED
self._status_lock = threading.Lock()
# Frame buffer (ring buffer) - stores cloned RGB tensors
self.frame_buffer = deque(maxlen=buffer_size)
self._buffer_lock = threading.RLock()
# Decoder and container instances
self.decoder = None
self.container = None
# Decode thread
self._decode_thread: Optional[threading.Thread] = None
self._stop_flag = threading.Event()
# Frame metadata
self.frame_width: Optional[int] = None
self.frame_height: Optional[int] = None
self.frame_count: int = 0
# Frame callbacks - event-driven notification
self._frame_callbacks = []
self._callback_lock = threading.Lock()
# Track frames currently in use (referenced by callbacks/pipeline)
self._in_use_frames = [] # List of FrameReference objects
self._frame_index_counter = 0 # Monotonically increasing frame index
def register_frame_callback(self, callback: Callable):
"""
Register a callback to be called when a new frame is decoded.
The callback will be called with the decoded frame tensor (GPU) as argument.
Callback signature: callback(frame: torch.Tensor) -> None
Args:
callback: Function to call when new frame arrives
"""
with self._callback_lock:
self._frame_callbacks.append(callback)
def unregister_frame_callback(self, callback: Callable):
"""
Unregister a frame callback.
Args:
callback: The callback function to remove
"""
with self._callback_lock:
if callback in self._frame_callbacks:
self._frame_callbacks.remove(callback)
def _mark_frame_free(self, buffer_index: int):
"""
Mark a frame as freed (called by FrameReference when it's no longer in use).
Args:
buffer_index: Index in the buffer for tracking purposes
"""
with self._buffer_lock:
# Remove from in-use tracking
self._in_use_frames = [
f for f in self._in_use_frames if f.buffer_index != buffer_index
]
def start(self):
"""Start the RTSP stream decoding in background thread"""
if self._decode_thread is not None and self._decode_thread.is_alive():
print(f"Decoder already running for {self.rtsp_url}")
return
self._stop_flag.clear()
self._decode_thread = threading.Thread(target=self._decode_loop, daemon=True)
self._decode_thread.start()
print(f"Started decoder thread for {self.rtsp_url}")
def stop(self):
"""Stop the decoding thread and cleanup resources"""
self._stop_flag.set()
if self._decode_thread is not None:
self._decode_thread.join(timeout=5.0)
self._cleanup()
print(f"Stopped decoder for {self.rtsp_url}")
def _set_status(self, status: ConnectionStatus):
"""Thread-safe status update"""
with self._status_lock:
self.status = status
def get_status(self) -> ConnectionStatus:
"""Get current connection status"""
with self._status_lock:
return self.status
def _init_rtsp_connection(self) -> bool:
"""Initialize RTSP connection using PyAV + PyNvVideoCodec"""
try:
self._set_status(ConnectionStatus.CONNECTING)
# Open RTSP stream with PyAV
options = {
"rtsp_transport": "tcp",
"max_delay": "500000", # 500ms
"rtsp_flags": "prefer_tcp",
"timeout": "5000000", # 5 seconds
}
self.container = av.open(self.rtsp_url, options=options)
# Get video stream
video_stream = self.container.streams.video[0]
self.frame_width = video_stream.width
self.frame_height = video_stream.height
print(f"RTSP connected: {self.frame_width}x{self.frame_height}")
# Map codec name to PyNvVideoCodec codec enum
codec_map = {
"h264": nvc.cudaVideoCodec.H264,
"hevc": nvc.cudaVideoCodec.HEVC,
"h265": nvc.cudaVideoCodec.HEVC,
}
codec_id = codec_map.get(self.codec.lower(), nvc.cudaVideoCodec.H264)
# Initialize NVDEC decoder with shared CUDA context
self.decoder = nvc.CreateDecoder(
gpuid=self.gpu_id,
codec=codec_id,
cudacontext=self.cuda_context,
usedevicememory=True,
)
self._set_status(ConnectionStatus.CONNECTED)
return True
except Exception as e:
print(f"Failed to connect to RTSP stream {self.rtsp_url}: {e}")
self._set_status(ConnectionStatus.ERROR)
return False
def _decode_loop(self):
"""Main decode loop running in background thread"""
# Set the CUDA device for this thread
torch.cuda.set_device(self.gpu_id)
retry_count = 0
max_retries = 5
while not self._stop_flag.is_set():
# Initialize connection
if not self._init_rtsp_connection():
retry_count += 1
if retry_count >= max_retries:
print(f"Max retries reached for {self.rtsp_url}")
self._set_status(ConnectionStatus.ERROR)
break
self._set_status(ConnectionStatus.RECONNECTING)
self._stop_flag.wait(timeout=2.0)
continue
retry_count = 0 # Reset on successful connection
try:
# Decode loop - iterate through packets from PyAV
for packet in self.container.demux(video=0):
if self._stop_flag.is_set():
break
if packet.dts is None:
continue
# Convert packet to numpy array
packet_data = np.frombuffer(bytes(packet), dtype=np.uint8)
# Create PacketData and pass numpy array pointer
pkt = nvc.PacketData()
pkt.bsl_data = packet_data.ctypes.data
pkt.bsl = len(packet_data)
# Decode using NVDEC
decoded_frames = self.decoder.Decode(pkt)
if not decoded_frames:
continue
# Add frames to ring buffer and fire callbacks
with self._buffer_lock:
for frame in decoded_frames:
# Convert to tensor immediately after NVDEC
try:
# Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
nv12_tensor = torch.from_dlpack(frame)
# Convert NV12 to RGB on GPU
if (
self.frame_height is not None
and self.frame_width is not None
):
rgb_tensor = nv12_to_rgb_gpu(
nv12_tensor, self.frame_height, self.frame_width
)
# CLONE ONCE into our post-decode buffer
# This breaks the dependency on PyNvVideoCodec's DecodedFrame
# After this, the tensor is fully ours and can be used throughout the pipeline
rgb_cloned = rgb_tensor.clone()
# Create FrameReference for reference counting
frame_ref = FrameReference(
rgb_tensor=rgb_cloned,
buffer_index=self._frame_index_counter,
decoder=self,
)
self._frame_index_counter += 1
# Add FrameReference to ring buffer (deque automatically removes oldest when full)
self.frame_buffer.append(frame_ref)
self.frame_count += 1
# Track this frame as in-use
self._in_use_frames.append(frame_ref)
# Fire callbacks with the FrameReference
# The callback receivers should call .free() when done
with self._callback_lock:
for callback in self._frame_callbacks:
try:
callback(frame_ref)
except Exception as e:
print(f"Error in frame callback: {e}")
except Exception as e:
print(f"Error converting frame for callback: {e}")
except Exception as e:
print(f"Error in decode loop for {self.rtsp_url}: {e}")
self._set_status(ConnectionStatus.RECONNECTING)
self._cleanup()
self._stop_flag.wait(timeout=2.0)
def _cleanup(self):
"""Cleanup resources"""
if self.container:
try:
self.container.close()
except:
pass
self.container = None
self.decoder = None
with self._buffer_lock:
self.frame_buffer.clear()
def get_frame(self, index: int = -1, rgb: bool = True) -> Optional[torch.Tensor]:
"""
Get a frame from the buffer as a CUDA tensor (in VRAM).
Args:
index: Frame index in buffer (-1 for latest, -2 for second latest, etc.)
rgb: If True, return RGB tensor. If False, not supported (returns None).
Returns:
torch.Tensor in CUDA memory (device tensor) or None if buffer empty
- If rgb=True: Shape (3, H, W) in RGB format, dtype uint8
- If rgb=False: Not supported with FrameReference (returns None)
"""
with self._buffer_lock:
if len(self.frame_buffer) == 0:
return None
if not rgb:
print(
"Warning: NV12 format not supported with FrameReference, only RGB"
)
return None
try:
frame_ref = self.frame_buffer[index]
# Return the RGB tensor from FrameReference (cloned, independent)
return frame_ref.rgb_tensor
except (IndexError, Exception) as e:
print(f"Error getting frame: {e}")
return None
def get_latest_frame(self, rgb: bool = True) -> Optional[torch.Tensor]:
"""
Get the most recent decoded frame as CUDA tensor.
Args:
rgb: If True, convert to RGB. If False, return raw NV12.
Returns:
torch.Tensor on GPU in RGB (3, H, W) or NV12 (H*3/2, W) format
"""
return self.get_frame(-1, rgb=rgb)
def get_frame_cpu(self, index: int = -1, rgb: bool = True) -> Optional[np.ndarray]:
"""
Get a frame from the buffer and copy it to CPU memory as numpy array.
Args:
index: Frame index in buffer (-1 for latest, -2 for second latest, etc.)
rgb: If True, convert NV12 to RGB. If False, return raw NV12 format.
Returns:
numpy.ndarray in CPU memory or None if buffer empty
- If rgb=True: Shape (H, W, 3) in RGB format, dtype uint8 (HWC format for easy display)
- If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8
"""
# Get frame on GPU
gpu_frame = self.get_frame(index=index, rgb=rgb)
if gpu_frame is None:
return None
# Transfer from GPU to CPU
cpu_tensor = gpu_frame.cpu()
# Convert to numpy array
if rgb:
# Convert from (3, H, W) to (H, W, 3) for standard image format
cpu_array = cpu_tensor.permute(1, 2, 0).numpy()
else:
# Keep NV12 format as-is
cpu_array = cpu_tensor.numpy()
return cpu_array
def get_latest_frame_cpu(self, rgb: bool = True) -> Optional[np.ndarray]:
"""
Get the most recent decoded frame as CPU numpy array.
Args:
rgb: If True, convert to RGB. If False, return raw NV12.
Returns:
numpy.ndarray in CPU memory
- If rgb=True: Shape (H, W, 3) in RGB format, dtype uint8
- If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8
"""
return self.get_frame_cpu(-1, rgb=rgb)
def get_buffer_size(self) -> int:
"""Get current number of frames in buffer"""
with self._buffer_lock:
return len(self.frame_buffer)
def get_all_frames(self, rgb: bool = True) -> list:
"""
Get all frames currently in the buffer as CUDA tensors.
This drains the buffer and returns all frames.
Args:
rgb: If True, return RGB tensors. If False, not supported (returns empty list).
Returns:
List of torch.Tensor objects in CUDA memory
"""
if not rgb:
print("Warning: NV12 format not supported with FrameReference, only RGB")
return []
frames = []
with self._buffer_lock:
# Get all frames from buffer
for frame_ref in self.frame_buffer:
try:
# Get RGB tensor from FrameReference
frames.append(frame_ref.rgb_tensor)
except Exception as e:
print(f"Error getting frame: {e}")
continue
# Clear the buffer after reading all frames and free all references
for frame_ref in self.frame_buffer:
frame_ref.free()
self.frame_buffer.clear()
return frames
def get_frame_count(self) -> int:
"""Get total number of frames decoded since start"""
return self.frame_count
def is_connected(self) -> bool:
"""Check if stream is actively connected"""
return self.get_status() == ConnectionStatus.CONNECTED
def get_frame_as_jpeg(self, index: int = -1, quality: int = 95) -> Optional[bytes]:
"""
Get a frame from the buffer and encode to JPEG.
This method:
1. Gets RGB frame from buffer (stays on GPU)
2. Encodes to JPEG using nvJPEG (GPU operation via shared encoder)
3. Transfers JPEG bytes to CPU
4. Returns bytes for saving to disk
Args:
index: Frame index in buffer (-1 for latest)
quality: JPEG quality (0-100, default 95)
Returns:
JPEG encoded bytes or None if frame unavailable
"""
# Get RGB frame (on GPU)
rgb_frame = self.get_frame(index=index, rgb=True)
# Use the shared JPEG encoder from jpeg_encoder module
return encode_frame_to_jpeg(rgb_frame, quality=quality)
def __repr__(self):
return (
f"StreamDecoder(url={self.rtsp_url}, status={self.status.value}, "
f"buffer={self.get_buffer_size()}/{self.buffer_size}, "
f"frames_decoded={self.frame_count})"
)