import threading from typing import Optional from collections import deque from enum import Enum import torch import PyNvVideoCodec as nvc import av import numpy as np from cuda.bindings import driver as cuda_driver from .jpeg_encoder import encode_frame_to_jpeg def nv12_to_rgb_gpu(nv12_tensor: torch.Tensor, height: int, width: int) -> torch.Tensor: """ Convert NV12 format to RGB on GPU using PyTorch operations. NV12 format: - Y plane: height x width (luminance) - UV plane: (height/2) x width (interleaved U and V, subsampled by 2) Total tensor size: (height * 3/2) x width Args: nv12_tensor: Input tensor in NV12 format, shape (H*3/2, W) height: Original frame height width: Original frame width Returns: RGB tensor, shape (3, H, W) in range [0, 255] """ device = nv12_tensor.device # Split Y and UV planes y_plane = nv12_tensor[:height, :].float() # (H, W) uv_plane = nv12_tensor[height:, :].float() # (H/2, W) # Reshape UV plane to separate U and V channels # UV is interleaved: U0V0U1V1... we need to deinterleave uv_plane = uv_plane.reshape(height // 2, width // 2, 2) # (H/2, W/2, 2) u_plane = uv_plane[:, :, 0] # (H/2, W/2) v_plane = uv_plane[:, :, 1] # (H/2, W/2) # Upsample U and V to full resolution using bilinear interpolation u_upsampled = torch.nn.functional.interpolate( u_plane.unsqueeze(0).unsqueeze(0), # (1, 1, H/2, W/2) size=(height, width), mode='bilinear', align_corners=False ).squeeze(0).squeeze(0) # (H, W) v_upsampled = torch.nn.functional.interpolate( v_plane.unsqueeze(0).unsqueeze(0), # (1, 1, H/2, W/2) size=(height, width), mode='bilinear', align_corners=False ).squeeze(0).squeeze(0) # (H, W) # YUV to RGB conversion using BT.601 standard # R = Y + 1.402 * (V - 128) # G = Y - 0.344136 * (U - 128) - 0.714136 * (V - 128) # B = Y + 1.772 * (U - 128) y = y_plane u = u_upsampled - 128.0 v = v_upsampled - 128.0 r = y + 1.402 * v g = y - 0.344136 * u - 0.714136 * v b = y + 1.772 * u # Clamp to [0, 255] and convert to uint8 r = torch.clamp(r, 0, 255).to(torch.uint8) g = torch.clamp(g, 0, 255).to(torch.uint8) b = torch.clamp(b, 0, 255).to(torch.uint8) # Stack to (3, H, W) rgb = torch.stack([r, g, b], dim=0) return rgb class ConnectionStatus(Enum): DISCONNECTED = "disconnected" CONNECTING = "connecting" CONNECTED = "connected" ERROR = "error" RECONNECTING = "reconnecting" class StreamDecoderFactory: """ Factory for creating StreamDecoder instances with shared CUDA context. This minimizes VRAM overhead by sharing the CUDA context across all decoders. """ _instance = None _lock = threading.Lock() def __new__(cls, gpu_id: int = 0): if cls._instance is None: with cls._lock: if cls._instance is None: cls._instance = super(StreamDecoderFactory, cls).__new__(cls) cls._instance._initialized = False return cls._instance def __init__(self, gpu_id: int = 0): if self._initialized: return self.gpu_id = gpu_id # Initialize CUDA and get device err, = cuda_driver.cuInit(0) if err != cuda_driver.CUresult.CUDA_SUCCESS: raise RuntimeError(f"Failed to initialize CUDA: {err}") # Get CUDA device err, self.cuda_device = cuda_driver.cuDeviceGet(gpu_id) if err != cuda_driver.CUresult.CUDA_SUCCESS: raise RuntimeError(f"Failed to get CUDA device {gpu_id}: {err}") # Retain primary context (shared across all decoders) err, self.cuda_context = cuda_driver.cuDevicePrimaryCtxRetain(self.cuda_device) if err != cuda_driver.CUresult.CUDA_SUCCESS: raise RuntimeError(f"Failed to retain CUDA primary context: {err}") self._initialized = True print(f"StreamDecoderFactory initialized with shared CUDA context on GPU {gpu_id}") def create_decoder(self, rtsp_url: str, buffer_size: int = 30, codec: str = "h264") -> 'StreamDecoder': """ Create a new StreamDecoder instance with shared CUDA context. Args: rtsp_url: RTSP stream URL buffer_size: Number of frames to buffer in VRAM codec: Video codec (h264, hevc, etc.) Returns: StreamDecoder instance """ return StreamDecoder( rtsp_url=rtsp_url, cuda_context=self.cuda_context, gpu_id=self.gpu_id, buffer_size=buffer_size, codec=codec ) def __del__(self): """Cleanup shared CUDA context on factory destruction""" if hasattr(self, 'cuda_device') and hasattr(self, 'gpu_id'): cuda_driver.cuDevicePrimaryCtxRelease(self.cuda_device) class StreamDecoder: """ Decodes RTSP stream using NVDEC and maintains a ring buffer of frames in GPU VRAM. Thread-safe for concurrent read/write operations. """ def __init__(self, rtsp_url: str, cuda_context, gpu_id: int, buffer_size: int = 30, codec: str = "h264"): """ Initialize StreamDecoder. Args: rtsp_url: RTSP stream URL cuda_context: Shared CUDA context handle gpu_id: GPU device ID buffer_size: Number of frames to keep in ring buffer codec: Video codec type """ self.rtsp_url = rtsp_url self.cuda_context = cuda_context self.gpu_id = gpu_id self.buffer_size = buffer_size self.codec = codec # Connection status self.status = ConnectionStatus.DISCONNECTED self._status_lock = threading.Lock() # Frame buffer (ring buffer) - stores CUDA device pointers self.frame_buffer = deque(maxlen=buffer_size) self._buffer_lock = threading.RLock() # Decoder and container instances self.decoder = None self.container = None # Decode thread self._decode_thread: Optional[threading.Thread] = None self._stop_flag = threading.Event() # Frame metadata self.frame_width: Optional[int] = None self.frame_height: Optional[int] = None self.frame_count: int = 0 def start(self): """Start the RTSP stream decoding in background thread""" if self._decode_thread is not None and self._decode_thread.is_alive(): print(f"Decoder already running for {self.rtsp_url}") return self._stop_flag.clear() self._decode_thread = threading.Thread(target=self._decode_loop, daemon=True) self._decode_thread.start() print(f"Started decoder thread for {self.rtsp_url}") def stop(self): """Stop the decoding thread and cleanup resources""" self._stop_flag.set() if self._decode_thread is not None: self._decode_thread.join(timeout=5.0) self._cleanup() print(f"Stopped decoder for {self.rtsp_url}") def _set_status(self, status: ConnectionStatus): """Thread-safe status update""" with self._status_lock: self.status = status def get_status(self) -> ConnectionStatus: """Get current connection status""" with self._status_lock: return self.status def _init_rtsp_connection(self) -> bool: """Initialize RTSP connection using PyAV + PyNvVideoCodec""" try: self._set_status(ConnectionStatus.CONNECTING) # Open RTSP stream with PyAV options = { 'rtsp_transport': 'tcp', 'max_delay': '500000', # 500ms 'rtsp_flags': 'prefer_tcp', 'timeout': '5000000', # 5 seconds } self.container = av.open(self.rtsp_url, options=options) # Get video stream video_stream = self.container.streams.video[0] self.frame_width = video_stream.width self.frame_height = video_stream.height print(f"RTSP connected: {self.frame_width}x{self.frame_height}") # Map codec name to PyNvVideoCodec codec enum codec_map = { 'h264': nvc.cudaVideoCodec.H264, 'hevc': nvc.cudaVideoCodec.HEVC, 'h265': nvc.cudaVideoCodec.HEVC, } codec_id = codec_map.get(self.codec.lower(), nvc.cudaVideoCodec.H264) # Initialize NVDEC decoder with shared CUDA context self.decoder = nvc.CreateDecoder( gpuid=self.gpu_id, codec=codec_id, cudacontext=self.cuda_context, usedevicememory=True ) self._set_status(ConnectionStatus.CONNECTED) return True except Exception as e: print(f"Failed to connect to RTSP stream {self.rtsp_url}: {e}") self._set_status(ConnectionStatus.ERROR) return False def _decode_loop(self): """Main decode loop running in background thread""" retry_count = 0 max_retries = 5 while not self._stop_flag.is_set(): # Initialize connection if not self._init_rtsp_connection(): retry_count += 1 if retry_count >= max_retries: print(f"Max retries reached for {self.rtsp_url}") self._set_status(ConnectionStatus.ERROR) break self._set_status(ConnectionStatus.RECONNECTING) self._stop_flag.wait(timeout=2.0) continue retry_count = 0 # Reset on successful connection try: # Decode loop - iterate through packets from PyAV for packet in self.container.demux(video=0): if self._stop_flag.is_set(): break if packet.dts is None: continue # Convert packet to numpy array packet_data = np.frombuffer(bytes(packet), dtype=np.uint8) # Create PacketData and pass numpy array pointer pkt = nvc.PacketData() pkt.bsl_data = packet_data.ctypes.data pkt.bsl = len(packet_data) # Decode using NVDEC decoded_frames = self.decoder.Decode(pkt) if not decoded_frames: continue # Add frames to ring buffer (thread-safe) with self._buffer_lock: for frame in decoded_frames: self.frame_buffer.append(frame) self.frame_count += 1 except Exception as e: print(f"Error in decode loop for {self.rtsp_url}: {e}") self._set_status(ConnectionStatus.RECONNECTING) self._cleanup() self._stop_flag.wait(timeout=2.0) def _cleanup(self): """Cleanup resources""" if self.container: try: self.container.close() except: pass self.container = None self.decoder = None with self._buffer_lock: self.frame_buffer.clear() def get_frame(self, index: int = -1, rgb: bool = True) -> Optional[torch.Tensor]: """ Get a frame from the buffer as a CUDA tensor (in VRAM). Args: index: Frame index in buffer (-1 for latest, -2 for second latest, etc.) rgb: If True, convert NV12 to RGB. If False, return raw NV12 format. Returns: torch.Tensor in CUDA memory (device tensor) or None if buffer empty - If rgb=True: Shape (3, H, W) in RGB format, dtype uint8 - If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8 """ with self._buffer_lock: if len(self.frame_buffer) == 0: return None try: decoded_frame = self.frame_buffer[index] # Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy) # This keeps the data in GPU memory nv12_tensor = torch.from_dlpack(decoded_frame) if not rgb: # Return raw NV12 format return nv12_tensor # Convert NV12 to RGB on GPU if self.frame_height is None or self.frame_width is None: print("Frame dimensions not available") return None rgb_tensor = nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width) return rgb_tensor except (IndexError, Exception) as e: print(f"Error getting frame: {e}") return None def get_latest_frame(self, rgb: bool = True) -> Optional[torch.Tensor]: """ Get the most recent decoded frame as CUDA tensor. Args: rgb: If True, convert to RGB. If False, return raw NV12. Returns: torch.Tensor on GPU in RGB (3, H, W) or NV12 (H*3/2, W) format """ return self.get_frame(-1, rgb=rgb) def get_frame_cpu(self, index: int = -1, rgb: bool = True) -> Optional[np.ndarray]: """ Get a frame from the buffer and copy it to CPU memory as numpy array. Args: index: Frame index in buffer (-1 for latest, -2 for second latest, etc.) rgb: If True, convert NV12 to RGB. If False, return raw NV12 format. Returns: numpy.ndarray in CPU memory or None if buffer empty - If rgb=True: Shape (H, W, 3) in RGB format, dtype uint8 (HWC format for easy display) - If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8 """ # Get frame on GPU gpu_frame = self.get_frame(index=index, rgb=rgb) if gpu_frame is None: return None # Transfer from GPU to CPU cpu_tensor = gpu_frame.cpu() # Convert to numpy array if rgb: # Convert from (3, H, W) to (H, W, 3) for standard image format cpu_array = cpu_tensor.permute(1, 2, 0).numpy() else: # Keep NV12 format as-is cpu_array = cpu_tensor.numpy() return cpu_array def get_latest_frame_cpu(self, rgb: bool = True) -> Optional[np.ndarray]: """ Get the most recent decoded frame as CPU numpy array. Args: rgb: If True, convert to RGB. If False, return raw NV12. Returns: numpy.ndarray in CPU memory - If rgb=True: Shape (H, W, 3) in RGB format, dtype uint8 - If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8 """ return self.get_frame_cpu(-1, rgb=rgb) def get_buffer_size(self) -> int: """Get current number of frames in buffer""" with self._buffer_lock: return len(self.frame_buffer) def is_connected(self) -> bool: """Check if stream is actively connected""" return self.get_status() == ConnectionStatus.CONNECTED def get_frame_as_jpeg(self, index: int = -1, quality: int = 95) -> Optional[bytes]: """ Get a frame from the buffer and encode to JPEG. This method: 1. Gets RGB frame from buffer (stays on GPU) 2. Encodes to JPEG using nvJPEG (GPU operation via shared encoder) 3. Transfers JPEG bytes to CPU 4. Returns bytes for saving to disk Args: index: Frame index in buffer (-1 for latest) quality: JPEG quality (0-100, default 95) Returns: JPEG encoded bytes or None if frame unavailable """ # Get RGB frame (on GPU) rgb_frame = self.get_frame(index=index, rgb=True) # Use the shared JPEG encoder from jpeg_encoder module return encode_frame_to_jpeg(rgb_frame, quality=quality) def __repr__(self): return (f"StreamDecoder(url={self.rtsp_url}, status={self.status.value}, " f"buffer={self.get_buffer_size()}/{self.buffer_size}, " f"frames_decoded={self.frame_count})")