event driven system
This commit is contained in:
parent
0c5f56c8a6
commit
3a47920186
10 changed files with 782 additions and 253 deletions
|
|
@ -1,5 +1,5 @@
|
|||
import threading
|
||||
from typing import Optional
|
||||
from typing import Optional, Callable
|
||||
from collections import deque
|
||||
from enum import Enum
|
||||
import torch
|
||||
|
|
@ -10,6 +10,35 @@ from cuda.bindings import driver as cuda_driver
|
|||
from .jpeg_encoder import encode_frame_to_jpeg
|
||||
|
||||
|
||||
class FrameReference:
|
||||
"""
|
||||
CPU-side reference object for a GPU frame.
|
||||
|
||||
This object holds a cloned RGB tensor that is independent of PyNvVideoCodec's
|
||||
DecodedFrame lifecycle. We don't keep the DecodedFrame to avoid conflicts
|
||||
with PyNvVideoCodec's internal frame pool management.
|
||||
"""
|
||||
def __init__(self, rgb_tensor: torch.Tensor, buffer_index: int, decoder):
|
||||
self.rgb_tensor = rgb_tensor # Cloned RGB tensor (independent copy)
|
||||
self.buffer_index = buffer_index
|
||||
self.decoder = decoder # Reference to decoder for marking as free
|
||||
self._freed = False
|
||||
|
||||
def free(self):
|
||||
"""Mark this frame as no longer in use"""
|
||||
if not self._freed:
|
||||
self._freed = True
|
||||
self.decoder._mark_frame_free(self.buffer_index)
|
||||
|
||||
def is_freed(self) -> bool:
|
||||
"""Check if this frame has been freed"""
|
||||
return self._freed
|
||||
|
||||
def __del__(self):
|
||||
"""Auto-free on garbage collection"""
|
||||
self.free()
|
||||
|
||||
|
||||
def nv12_to_rgb_gpu(nv12_tensor: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
Convert NV12 format to RGB on GPU using PyTorch operations.
|
||||
|
|
@ -183,10 +212,13 @@ class StreamDecoder:
|
|||
self.status = ConnectionStatus.DISCONNECTED
|
||||
self._status_lock = threading.Lock()
|
||||
|
||||
# Frame buffer (ring buffer) - stores CUDA device pointers
|
||||
# Frame buffer (ring buffer) - stores FrameReference objects
|
||||
self.frame_buffer = deque(maxlen=buffer_size)
|
||||
self._buffer_lock = threading.RLock()
|
||||
|
||||
# Track which buffer slots are in use (list of FrameReference objects)
|
||||
self._in_use_frames = [] # List of FrameReference objects currently held by callbacks
|
||||
|
||||
# Decoder and container instances
|
||||
self.decoder = None
|
||||
self.container = None
|
||||
|
|
@ -200,6 +232,45 @@ class StreamDecoder:
|
|||
self.frame_height: Optional[int] = None
|
||||
self.frame_count: int = 0
|
||||
|
||||
# Frame callbacks - event-driven notification
|
||||
self._frame_callbacks = []
|
||||
self._callback_lock = threading.Lock()
|
||||
|
||||
def register_frame_callback(self, callback: Callable):
|
||||
"""
|
||||
Register a callback to be called when a new frame is decoded.
|
||||
|
||||
The callback will be called with the decoded frame tensor (GPU) as argument.
|
||||
Callback signature: callback(frame: torch.Tensor) -> None
|
||||
|
||||
Args:
|
||||
callback: Function to call when new frame arrives
|
||||
"""
|
||||
with self._callback_lock:
|
||||
self._frame_callbacks.append(callback)
|
||||
|
||||
def unregister_frame_callback(self, callback: Callable):
|
||||
"""
|
||||
Unregister a frame callback.
|
||||
|
||||
Args:
|
||||
callback: The callback function to remove
|
||||
"""
|
||||
with self._callback_lock:
|
||||
if callback in self._frame_callbacks:
|
||||
self._frame_callbacks.remove(callback)
|
||||
|
||||
def _mark_frame_free(self, buffer_index: int):
|
||||
"""
|
||||
Mark a frame as freed (called by FrameReference when it's no longer in use).
|
||||
|
||||
Args:
|
||||
buffer_index: Index in the buffer for tracking purposes
|
||||
"""
|
||||
with self._buffer_lock:
|
||||
# Remove from in-use tracking
|
||||
self._in_use_frames = [f for f in self._in_use_frames if f.buffer_index != buffer_index]
|
||||
|
||||
def start(self):
|
||||
"""Start the RTSP stream decoding in background thread"""
|
||||
if self._decode_thread is not None and self._decode_thread.is_alive():
|
||||
|
|
@ -278,6 +349,9 @@ class StreamDecoder:
|
|||
|
||||
def _decode_loop(self):
|
||||
"""Main decode loop running in background thread"""
|
||||
# Set the CUDA device for this thread
|
||||
torch.cuda.set_device(self.gpu_id)
|
||||
|
||||
retry_count = 0
|
||||
max_retries = 5
|
||||
|
||||
|
|
@ -319,11 +393,60 @@ class StreamDecoder:
|
|||
if not decoded_frames:
|
||||
continue
|
||||
|
||||
# Add frames to ring buffer (thread-safe)
|
||||
# Add frames to ring buffer and fire callbacks
|
||||
with self._buffer_lock:
|
||||
for frame in decoded_frames:
|
||||
self.frame_buffer.append(frame)
|
||||
self.frame_count += 1
|
||||
# Check for buffer overflow - discard oldest if needed
|
||||
if len(self.frame_buffer) >= self.buffer_size:
|
||||
# Check if oldest frame is still in use
|
||||
if len(self._in_use_frames) > 0:
|
||||
oldest_ref = self.frame_buffer[0] if len(self.frame_buffer) > 0 else None
|
||||
if oldest_ref and not oldest_ref.is_freed():
|
||||
# Force free the oldest frame to prevent overflow
|
||||
print(f"[WARNING] Buffer overflow, force-freeing oldest frame (buffer_index={oldest_ref.buffer_index})")
|
||||
oldest_ref.free()
|
||||
|
||||
# Deque will automatically remove oldest when at maxlen
|
||||
|
||||
# Convert to tensor
|
||||
try:
|
||||
# Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
|
||||
nv12_tensor = torch.from_dlpack(frame)
|
||||
|
||||
# Convert NV12 to RGB on GPU
|
||||
if self.frame_height is not None and self.frame_width is not None:
|
||||
rgb_tensor = nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width)
|
||||
|
||||
# CRITICAL: Clone the RGB tensor to break CUDA memory dependency
|
||||
# The nv12_to_rgb_gpu creates a new tensor, but it still references
|
||||
# the same CUDA context/stream. We need an independent copy.
|
||||
rgb_tensor_cloned = rgb_tensor.clone()
|
||||
|
||||
# Create FrameReference object for C++-style memory management
|
||||
# We don't keep the DecodedFrame to avoid conflicts with PyNvVideoCodec's
|
||||
# internal frame pool - the clone is fully independent
|
||||
buffer_index = self.frame_count
|
||||
frame_ref = FrameReference(
|
||||
rgb_tensor=rgb_tensor_cloned, # Independent cloned tensor
|
||||
buffer_index=buffer_index,
|
||||
decoder=self
|
||||
)
|
||||
|
||||
# Add to buffer and in-use tracking
|
||||
self.frame_buffer.append(frame_ref)
|
||||
self._in_use_frames.append(frame_ref)
|
||||
self.frame_count += 1
|
||||
|
||||
# Fire callbacks with the cloned RGB tensor from FrameReference
|
||||
# The tensor is now independent of the DecodedFrame lifecycle
|
||||
with self._callback_lock:
|
||||
for callback in self._frame_callbacks:
|
||||
try:
|
||||
callback(frame_ref.rgb_tensor)
|
||||
except Exception as e:
|
||||
print(f"Error in frame callback: {e}")
|
||||
except Exception as e:
|
||||
print(f"Error converting frame for callback: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in decode loop for {self.rtsp_url}: {e}")
|
||||
|
|
@ -351,35 +474,25 @@ class StreamDecoder:
|
|||
|
||||
Args:
|
||||
index: Frame index in buffer (-1 for latest, -2 for second latest, etc.)
|
||||
rgb: If True, convert NV12 to RGB. If False, return raw NV12 format.
|
||||
rgb: If True, return RGB tensor. If False, not supported (returns None).
|
||||
|
||||
Returns:
|
||||
torch.Tensor in CUDA memory (device tensor) or None if buffer empty
|
||||
- If rgb=True: Shape (3, H, W) in RGB format, dtype uint8
|
||||
- If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8
|
||||
- If rgb=False: Not supported with FrameReference (returns None)
|
||||
"""
|
||||
with self._buffer_lock:
|
||||
if len(self.frame_buffer) == 0:
|
||||
return None
|
||||
|
||||
if not rgb:
|
||||
print("Warning: NV12 format not supported with FrameReference, only RGB")
|
||||
return None
|
||||
|
||||
try:
|
||||
decoded_frame = self.frame_buffer[index]
|
||||
|
||||
# Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
|
||||
# This keeps the data in GPU memory
|
||||
nv12_tensor = torch.from_dlpack(decoded_frame)
|
||||
|
||||
if not rgb:
|
||||
# Return raw NV12 format
|
||||
return nv12_tensor
|
||||
|
||||
# Convert NV12 to RGB on GPU
|
||||
if self.frame_height is None or self.frame_width is None:
|
||||
print("Frame dimensions not available")
|
||||
return None
|
||||
|
||||
rgb_tensor = nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width)
|
||||
return rgb_tensor
|
||||
frame_ref = self.frame_buffer[index]
|
||||
# Return the RGB tensor from FrameReference (cloned, independent)
|
||||
return frame_ref.rgb_tensor
|
||||
|
||||
except (IndexError, Exception) as e:
|
||||
print(f"Error getting frame: {e}")
|
||||
|
|
@ -448,6 +561,39 @@ class StreamDecoder:
|
|||
with self._buffer_lock:
|
||||
return len(self.frame_buffer)
|
||||
|
||||
def get_all_frames(self, rgb: bool = True) -> list:
|
||||
"""
|
||||
Get all frames currently in the buffer as CUDA tensors.
|
||||
This drains the buffer and returns all frames.
|
||||
|
||||
Args:
|
||||
rgb: If True, return RGB tensors. If False, not supported (returns empty list).
|
||||
|
||||
Returns:
|
||||
List of torch.Tensor objects in CUDA memory
|
||||
"""
|
||||
if not rgb:
|
||||
print("Warning: NV12 format not supported with FrameReference, only RGB")
|
||||
return []
|
||||
|
||||
frames = []
|
||||
with self._buffer_lock:
|
||||
# Get all frames from buffer
|
||||
for frame_ref in self.frame_buffer:
|
||||
try:
|
||||
# Get RGB tensor from FrameReference
|
||||
frames.append(frame_ref.rgb_tensor)
|
||||
except Exception as e:
|
||||
print(f"Error getting frame: {e}")
|
||||
continue
|
||||
|
||||
# Clear the buffer after reading all frames and free all references
|
||||
for frame_ref in self.frame_buffer:
|
||||
frame_ref.free()
|
||||
self.frame_buffer.clear()
|
||||
|
||||
return frames
|
||||
|
||||
def get_frame_count(self) -> int:
|
||||
"""Get total number of frames decoded since start"""
|
||||
return self.frame_count
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue