import threading
import weakref
from collections import deque
from enum import Enum
from typing import Callable, Optional

import av
import numpy as np
import PyNvVideoCodec as nvc
import torch
from cuda.bindings import driver as cuda_driver

from .jpeg_encoder import encode_frame_to_jpeg


class FrameReference:
    """
    Handle for one decoded RGB frame shared by the ring buffer and frame callbacks.

    The decoder creates exactly one FrameReference per decoded frame; the ring
    buffer and every callback receive the same object. ``free()`` is one-shot:
    it drops the tensor for *all* holders of this handle and notifies the
    decoder that the frame is no longer in use. If nobody calls ``free()``, it
    runs automatically when the last Python reference to the handle is dropped.
    """

    def __init__(self, rgb_tensor: torch.Tensor, buffer_index: int, decoder: "StreamDecoder"):
        self.rgb_tensor: Optional[torch.Tensor] = rgb_tensor  # RGB frame produced by the decoder
        self.buffer_index = buffer_index  # Monotonic per-decoder frame id
        self.decoder = decoder
        self._freed = False

    def free(self):
        """Release the frame tensor and notify the decoder; later calls are no-ops."""
        if self._freed:
            return
        self._freed = True
        # Dropping our reference lets PyTorch's allocator reuse the GPU memory
        # once no other holder of the tensor object remains.
        self.rgb_tensor = None
        self.decoder._mark_frame_free(self.buffer_index)

    def is_freed(self) -> bool:
        """Return True once free() has run for this frame."""
        return self._freed

    def __del__(self):
        """Auto-free on garbage collection."""
        self.free()


def nv12_to_rgb_gpu(
    nv12_tensor: torch.Tensor,
    height: int,
    width: int,
    full_range: bool = True,
) -> torch.Tensor:
    """
    Convert an NV12 frame to RGB with PyTorch ops (runs on the tensor's device).

    NV12 layout:
    - Y plane: ``height`` x ``width`` luminance rows.
    - UV plane: ``height/2`` rows of interleaved U,V samples (2x2 subsampled).
    Total tensor shape: (height * 3/2, width).

    Args:
        nv12_tensor: NV12 data of shape (H*3/2, W); presumably uint8 as
            produced by NVDEC — TODO confirm for >8-bit streams.
        height: Frame height in pixels.
        width: Frame width in pixels.
        full_range: True (default, original behaviour) treats samples as
            full-range [0, 255]. False treats them as limited/studio range
            (Y in [16, 235], chroma in [16, 240]), which is common for camera
            H.264/HEVC streams, and expands to full range before conversion.

    Returns:
        uint8 RGB tensor of shape (3, H, W), values clamped to [0, 255].
    """
    y = nv12_tensor[:height, :].float()  # (H, W)

    # UV rows hold U0 V0 U1 V1 ...; split into (H/2, W/2, 2), then move the
    # channel axis first so U and V are upsampled by a single interpolate call.
    uv = nv12_tensor[height:, :].float().reshape(height // 2, width // 2, 2)
    uv = uv.permute(2, 0, 1).unsqueeze(0)  # (1, 2, H/2, W/2)
    uv = torch.nn.functional.interpolate(
        uv, size=(height, width), mode="bilinear", align_corners=False
    )[0]  # (2, H, W)

    u = uv[0] - 128.0
    v = uv[1] - 128.0

    if not full_range:
        # BT.601 limited range: rescale Y by 255/219 and chroma by 255/224.
        y = (y - 16.0) * (255.0 / 219.0)
        u = u * (255.0 / 224.0)
        v = v * (255.0 / 224.0)

    # BT.601 YUV -> RGB
    r = y + 1.402 * v
    g = y - 0.344136 * u - 0.714136 * v
    b = y + 1.772 * u

    rgb = torch.stack([r, g, b], dim=0)  # (3, H, W)
    # .to(uint8) truncates toward zero, matching the historical behaviour.
    return torch.clamp(rgb, 0, 255).to(torch.uint8)


class ConnectionStatus(Enum):
    DISCONNECTED = "disconnected"
    CONNECTING = "connecting"
    CONNECTED = "connected"
    ERROR = "error"
    RECONNECTING = "reconnecting"


class StreamDecoderFactory:
    """
    Process-wide singleton that creates StreamDecoder instances sharing one
    CUDA primary context, so each decoder does not pay for its own context.

    The first construction fixes the GPU; later constructions with a
    different ``gpu_id`` return the existing factory and print a warning.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, gpu_id: int = 0):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super(StreamDecoderFactory, cls).__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self, gpu_id: int = 0):
        # __init__ runs on every StreamDecoderFactory(...) call; serialize it so
        # concurrent first use cannot retain the primary context twice.
        with self._lock:
            if self._initialized:
                if gpu_id != self.gpu_id:
                    print(
                        f"Warning: StreamDecoderFactory already initialized on GPU "
                        f"{self.gpu_id}; ignoring gpu_id={gpu_id}"
                    )
                return

            (err,) = cuda_driver.cuInit(0)
            if err != cuda_driver.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to initialize CUDA: {err}")

            err, cuda_device = cuda_driver.cuDeviceGet(gpu_id)
            if err != cuda_driver.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to get CUDA device {gpu_id}: {err}")

            # Retain the primary context (shared across all decoders).
            err, cuda_context = cuda_driver.cuDevicePrimaryCtxRetain(cuda_device)
            if err != cuda_driver.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to retain CUDA primary context: {err}")

            self.gpu_id = gpu_id
            self.cuda_device = cuda_device
            self.cuda_context = cuda_context
            # Set last: __del__ uses this flag to know a retain must be balanced.
            self._initialized = True

        print(
            f"StreamDecoderFactory initialized with shared CUDA context on GPU {gpu_id}"
        )

    def create_decoder(
        self,
        rtsp_url: str,
        buffer_size: int = 30,
        codec: str = "h264",
        full_range: bool = True,
    ) -> "StreamDecoder":
        """
        Create a new StreamDecoder bound to the shared CUDA context.

        Args:
            rtsp_url: RTSP stream URL.
            buffer_size: Number of decoded frames kept in the ring buffer.
            codec: Codec name ("h264", "hevc" or "h265"; others fall back to H.264).
            full_range: Forwarded to nv12_to_rgb_gpu (False for limited-range video).

        Returns:
            A StreamDecoder instance (not yet started).
        """
        return StreamDecoder(
            rtsp_url=rtsp_url,
            cuda_context=self.cuda_context,
            gpu_id=self.gpu_id,
            buffer_size=buffer_size,
            codec=codec,
            full_range=full_range,
        )

    def __del__(self):
        """Release the primary context, but only if it was actually retained."""
        if getattr(self, "_initialized", False):
            cuda_driver.cuDevicePrimaryCtxRelease(self.cuda_device)


class StreamDecoder:
    """
    Decode an RTSP stream with NVDEC and keep the most recent frames in a GPU
    ring buffer.

    A background thread demuxes with PyAV, decodes with PyNvVideoCodec,
    converts NV12 to RGB on the GPU and publishes each frame as a
    FrameReference. Buffer access is guarded by an RLock; callbacks are
    invoked from the decode thread, outside the buffer and callback locks.
    """

    def __init__(
        self,
        rtsp_url: str,
        cuda_context,
        gpu_id: int,
        buffer_size: int = 30,
        codec: str = "h264",
        full_range: bool = True,
    ):
        """
        Initialize StreamDecoder (does not connect; call start()).

        Args:
            rtsp_url: RTSP stream URL.
            cuda_context: Shared CUDA context handle.
            gpu_id: GPU device ID.
            buffer_size: Number of frames to keep in the ring buffer.
            codec: Codec name ("h264", "hevc", "h265").
            full_range: Passed to nv12_to_rgb_gpu; False for limited-range video.
        """
        self.rtsp_url = rtsp_url
        self.cuda_context = cuda_context
        self.gpu_id = gpu_id
        self.buffer_size = buffer_size
        self.codec = codec
        self.full_range = full_range

        # Connection status
        self.status = ConnectionStatus.DISCONNECTED
        self._status_lock = threading.Lock()

        # Ring buffer of FrameReference objects; deque drops the oldest when full.
        self.frame_buffer = deque(maxlen=buffer_size)
        # Re-entrant: FrameReference.__del__ -> _mark_frame_free can fire while
        # this same thread already holds the lock (e.g. on ring-buffer eviction).
        self._buffer_lock = threading.RLock()

        # Decoder and container instances
        self.decoder = None
        self.container = None

        # Decode thread
        self._decode_thread: Optional[threading.Thread] = None
        self._stop_flag = threading.Event()

        # Frame metadata
        self.frame_width: Optional[int] = None
        self.frame_height: Optional[int] = None
        self.frame_count: int = 0

        # Frame callbacks - event-driven notification
        self._frame_callbacks = []
        self._callback_lock = threading.Lock()

        # Frames not yet freed, keyed by buffer_index. Weak values so tracking
        # never keeps a frame (and its GPU tensor) alive on its own.
        self._in_use_frames = weakref.WeakValueDictionary()
        self._frame_index_counter = 0  # Monotonically increasing frame index

    def register_frame_callback(self, callback: Callable):
        """
        Register a callback invoked (on the decode thread) for every new frame.

        Callback signature: ``callback(frame_ref: FrameReference) -> None``.
        The frame tensor is ``frame_ref.rgb_tensor``; call ``frame_ref.free()``
        when done. Note that free() invalidates the tensor for every holder.

        Args:
            callback: Function to call when a new frame arrives.
        """
        with self._callback_lock:
            self._frame_callbacks.append(callback)

    def unregister_frame_callback(self, callback: Callable):
        """
        Unregister a frame callback (safe to call from inside a callback).

        Args:
            callback: The callback function to remove.
        """
        with self._callback_lock:
            if callback in self._frame_callbacks:
                self._frame_callbacks.remove(callback)

    def _mark_frame_free(self, buffer_index: int):
        """
        Stop tracking a frame (called by FrameReference.free()).

        Args:
            buffer_index: The frame id assigned when the frame was decoded.
        """
        with self._buffer_lock:
            self._in_use_frames.pop(buffer_index, None)

    def start(self):
        """Start the RTSP stream decoding in a background thread."""
        if self._decode_thread is not None and self._decode_thread.is_alive():
            print(f"Decoder already running for {self.rtsp_url}")
            return

        self._stop_flag.clear()
        self._decode_thread = threading.Thread(target=self._decode_loop, daemon=True)
        self._decode_thread.start()
        print(f"Started decoder thread for {self.rtsp_url}")

    def stop(self):
        """Stop the decoding thread, release resources and mark as disconnected."""
        self._stop_flag.set()
        if self._decode_thread is not None:
            self._decode_thread.join(timeout=5.0)
        self._cleanup()
        # Without this, is_connected() kept reporting True after stop().
        self._set_status(ConnectionStatus.DISCONNECTED)
        print(f"Stopped decoder for {self.rtsp_url}")

    def _set_status(self, status: ConnectionStatus):
        """Thread-safe status update."""
        with self._status_lock:
            self.status = status

    def get_status(self) -> ConnectionStatus:
        """Get current connection status."""
        with self._status_lock:
            return self.status

    def _init_rtsp_connection(self) -> bool:
        """
        Open the RTSP stream with PyAV and create the NVDEC decoder.

        Returns:
            True on success. On failure the status is set to ERROR, any
            partially opened container is closed, and False is returned.
        """
        try:
            self._set_status(ConnectionStatus.CONNECTING)

            options = {
                "rtsp_transport": "tcp",
                "max_delay": "500000",  # 500ms
                "rtsp_flags": "prefer_tcp",
                "timeout": "5000000",  # 5 seconds
            }
            self.container = av.open(self.rtsp_url, options=options)

            video_stream = self.container.streams.video[0]
            self.frame_width = video_stream.width
            self.frame_height = video_stream.height
            print(f"RTSP connected: {self.frame_width}x{self.frame_height}")

            # Map codec name to PyNvVideoCodec codec enum; unknown names fall
            # back to H.264.
            codec_map = {
                "h264": nvc.cudaVideoCodec.H264,
                "hevc": nvc.cudaVideoCodec.HEVC,
                "h265": nvc.cudaVideoCodec.HEVC,
            }
            codec_id = codec_map.get(self.codec.lower(), nvc.cudaVideoCodec.H264)

            self.decoder = nvc.CreateDecoder(
                gpuid=self.gpu_id,
                codec=codec_id,
                cudacontext=self.cuda_context,
                usedevicememory=True,
            )

            self._set_status(ConnectionStatus.CONNECTED)
            return True

        except Exception as e:
            print(f"Failed to connect to RTSP stream {self.rtsp_url}: {e}")
            # av.open may have succeeded before a later step failed; close it
            # so the next retry does not leak the connection.
            self._close_container()
            self._set_status(ConnectionStatus.ERROR)
            return False

    def _decode_loop(self):
        """Thread body: connect, decode until EOF/error, then reconnect until stopped."""
        torch.cuda.set_device(self.gpu_id)

        retry_count = 0
        max_retries = 5

        while not self._stop_flag.is_set():
            if not self._init_rtsp_connection():
                retry_count += 1
                if retry_count >= max_retries:
                    print(f"Max retries reached for {self.rtsp_url}")
                    self._set_status(ConnectionStatus.ERROR)
                    break
                self._set_status(ConnectionStatus.RECONNECTING)
                self._stop_flag.wait(timeout=2.0)
                continue

            retry_count = 0  # Reset on successful connection

            try:
                self._run_demux()
            except Exception as e:
                print(f"Error in decode loop for {self.rtsp_url}: {e}")

            # Release the connection whether the stream errored or simply
            # ended; previously a clean EOF reconnected without closing it.
            self._cleanup()
            if self._stop_flag.is_set():
                break
            self._set_status(ConnectionStatus.RECONNECTING)
            self._stop_flag.wait(timeout=2.0)

    def _run_demux(self):
        """Pump packets from the open container through NVDEC until EOF or stop."""
        for packet in self.container.demux(video=0):
            if self._stop_flag.is_set():
                break
            if packet.dts is None:
                # PyAV yields a dts-less flush packet at end of stream.
                continue

            for frame in self._decode_packet(packet):
                try:
                    frame_ref = self._ingest_frame(frame)
                except Exception as e:
                    print(f"Error converting frame for callback: {e}")
                    continue
                if frame_ref is not None:
                    self._dispatch_frame(frame_ref)

    def _decode_packet(self, packet) -> list:
        """Feed one PyAV packet to NVDEC and return the decoded frames (maybe empty)."""
        # packet_data must stay alive for the duration of Decode(), because
        # PacketData only stores a raw pointer into it.
        packet_data = np.frombuffer(bytes(packet), dtype=np.uint8)
        pkt = nvc.PacketData()
        pkt.bsl_data = packet_data.ctypes.data
        pkt.bsl = len(packet_data)
        return self.decoder.Decode(pkt) or []

    def _ingest_frame(self, frame) -> Optional[FrameReference]:
        """Convert one decoded frame to RGB and publish it to the ring buffer."""
        if self.frame_height is None or self.frame_width is None:
            return None

        # Zero-copy view of the decoder's NV12 surface.
        nv12_tensor = torch.from_dlpack(frame)
        # nv12_to_rgb_gpu allocates a fresh output tensor, so the result no
        # longer aliases the decoder's surface; no extra clone is needed.
        # NOTE(review): the conversion kernels are enqueued asynchronously;
        # confirm NVDEC does not reuse the surface before they complete.
        rgb_tensor = nv12_to_rgb_gpu(
            nv12_tensor, self.frame_height, self.frame_width, full_range=self.full_range
        )

        with self._buffer_lock:
            frame_ref = FrameReference(
                rgb_tensor=rgb_tensor,
                buffer_index=self._frame_index_counter,
                decoder=self,
            )
            self._frame_index_counter += 1
            # deque(maxlen) evicts the oldest entry when full.
            self.frame_buffer.append(frame_ref)
            self.frame_count += 1
            self._in_use_frames[frame_ref.buffer_index] = frame_ref
        return frame_ref

    def _dispatch_frame(self, frame_ref: FrameReference):
        """Invoke every registered callback with frame_ref, isolating their errors."""
        # Snapshot under the lock, call outside it: a callback may (un)register
        # callbacks without deadlocking, and readers are not blocked meanwhile.
        with self._callback_lock:
            callbacks = list(self._frame_callbacks)
        for callback in callbacks:
            try:
                callback(frame_ref)
            except Exception as e:
                print(f"Error in frame callback: {e}")

    def _close_container(self):
        """Close the PyAV container if one is open (best effort, logged)."""
        container, self.container = self.container, None
        if container is not None:
            try:
                container.close()
            except Exception as e:
                print(f"Error closing stream {self.rtsp_url}: {e}")

    def _cleanup(self):
        """Close the stream, drop the NVDEC decoder and empty the ring buffer."""
        self._close_container()
        self.decoder = None
        with self._buffer_lock:
            self.frame_buffer.clear()

    def get_frame(self, index: int = -1, rgb: bool = True) -> Optional[torch.Tensor]:
        """
        Get a frame from the buffer as a CUDA tensor.

        The returned tensor is the one shared with every holder of the frame's
        FrameReference (not a copy).

        Args:
            index: Position in the buffer (-1 for latest, -2 for second latest, ...).
            rgb: Must be True; NV12 output is not supported and returns None.

        Returns:
            (3, H, W) uint8 RGB tensor, or None if the buffer is empty, the
            index is out of range, rgb is False, or the frame was already freed.
        """
        with self._buffer_lock:
            if len(self.frame_buffer) == 0:
                return None

            if not rgb:
                print(
                    "Warning: NV12 format not supported with FrameReference, only RGB"
                )
                return None

            try:
                return self.frame_buffer[index].rgb_tensor
            except IndexError as e:
                print(f"Error getting frame: {e}")
                return None

    def get_latest_frame(self, rgb: bool = True) -> Optional[torch.Tensor]:
        """
        Get the most recent decoded frame as a CUDA tensor.

        Args:
            rgb: Must be True; NV12 is not supported and yields None.

        Returns:
            (3, H, W) uint8 RGB tensor on the GPU, or None if unavailable.
        """
        return self.get_frame(-1, rgb=rgb)

    def get_frame_cpu(self, index: int = -1, rgb: bool = True) -> Optional[np.ndarray]:
        """
        Get a frame from the buffer and copy it to CPU memory as a numpy array.

        Args:
            index: Position in the buffer (-1 for latest, -2 for second latest, ...).
            rgb: Must be True; NV12 is not supported and yields None.

        Returns:
            (H, W, 3) uint8 RGB array (HWC for easy display), or None if unavailable.
        """
        gpu_frame = self.get_frame(index=index, rgb=rgb)
        if gpu_frame is None:
            return None
        # (3, H, W) -> (H, W, 3)
        return gpu_frame.cpu().permute(1, 2, 0).numpy()

    def get_latest_frame_cpu(self, rgb: bool = True) -> Optional[np.ndarray]:
        """
        Get the most recent decoded frame as a CPU numpy array.

        Args:
            rgb: Must be True; NV12 is not supported and yields None.

        Returns:
            (H, W, 3) uint8 RGB array, or None if unavailable.
        """
        return self.get_frame_cpu(-1, rgb=rgb)

    def get_buffer_size(self) -> int:
        """Get current number of frames in buffer."""
        with self._buffer_lock:
            return len(self.frame_buffer)

    def get_all_frames(self, rgb: bool = True) -> list:
        """
        Drain the buffer: return every buffered frame tensor, then free and
        clear all buffered FrameReferences.

        Freeing invalidates ``rgb_tensor`` for any callback still holding the
        same FrameReference; the tensors returned here remain valid.

        Args:
            rgb: Must be True; NV12 is not supported and yields an empty list.

        Returns:
            List of (3, H, W) uint8 CUDA tensors, oldest first (already-freed
            frames are skipped).
        """
        if not rgb:
            print("Warning: NV12 format not supported with FrameReference, only RGB")
            return []

        with self._buffer_lock:
            frames = [
                ref.rgb_tensor for ref in self.frame_buffer if ref.rgb_tensor is not None
            ]
            for frame_ref in self.frame_buffer:
                frame_ref.free()
            self.frame_buffer.clear()

        return frames

    def get_frame_count(self) -> int:
        """Get total number of frames decoded since construction."""
        return self.frame_count

    def is_connected(self) -> bool:
        """Check if stream is actively connected."""
        return self.get_status() == ConnectionStatus.CONNECTED

    def get_frame_as_jpeg(self, index: int = -1, quality: int = 95) -> Optional[bytes]:
        """
        Get a frame from the buffer and encode it to JPEG via the shared encoder.

        Args:
            index: Position in the buffer (-1 for latest).
            quality: JPEG quality (0-100, default 95).

        Returns:
            JPEG-encoded bytes, or None if no frame is available.
        """
        rgb_frame = self.get_frame(index=index, rgb=True)
        if rgb_frame is None:
            # Honour the documented contract instead of handing None to the encoder.
            return None
        return encode_frame_to_jpeg(rgb_frame, quality=quality)

    def __repr__(self):
        return (
            f"StreamDecoder(url={self.rtsp_url}, status={self.get_status().value}, "
            f"buffer={self.get_buffer_size()}/{self.buffer_size}, "
            f"frames_decoded={self.frame_count})"
        )