event driven system

This commit is contained in:
Siwat Sirichai 2025-11-10 11:51:06 +07:00
parent 0c5f56c8a6
commit 3a47920186
10 changed files with 782 additions and 253 deletions

View file

@ -1,5 +1,5 @@
import threading
from typing import Optional
from typing import Optional, Callable
from collections import deque
from enum import Enum
import torch
@ -10,6 +10,35 @@ from cuda.bindings import driver as cuda_driver
from .jpeg_encoder import encode_frame_to_jpeg
class FrameReference:
    """
    CPU-side reference object for a GPU frame.

    The reference owns a cloned RGB tensor rather than PyNvVideoCodec's
    DecodedFrame, so that holding on to it does not interfere with the
    codec's internal frame pool.

    Args:
        rgb_tensor: Cloned RGB tensor for this frame.
        buffer_index: Identifier the decoder uses to track this frame.
        decoder: Object exposing ``_mark_frame_free(buffer_index)``; it is
            notified once when the frame is released.
    """

    def __init__(self, rgb_tensor: torch.Tensor, buffer_index: int, decoder):
        self.rgb_tensor = rgb_tensor  # Cloned RGB tensor (independent copy)
        self.buffer_index = buffer_index
        self.decoder = decoder  # Notified via _mark_frame_free() on release
        self._freed = False

    def free(self):
        """Mark this frame as no longer in use (idempotent)."""
        if not self._freed:
            # Flip the flag first so the decoder is notified at most once,
            # even if the notification itself raises.
            self._freed = True
            self.decoder._mark_frame_free(self.buffer_index)

    def is_freed(self) -> bool:
        """Return True once ``free()`` has been called."""
        return self._freed

    def __del__(self):
        """Auto-free on garbage collection (best effort, never raises)."""
        # A finalizer can run on a partially constructed object (no
        # ``_freed`` attribute yet) or during interpreter shutdown; treat a
        # missing flag as "nothing to release".
        if getattr(self, "_freed", True):
            return
        try:
            self.free()
        except Exception:
            # Errors cannot be handled meaningfully from a finalizer; the
            # frame is already flagged as freed, so just drop the error.
            pass
def nv12_to_rgb_gpu(nv12_tensor: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
Convert NV12 format to RGB on GPU using PyTorch operations.
@ -183,10 +212,13 @@ class StreamDecoder:
self.status = ConnectionStatus.DISCONNECTED
self._status_lock = threading.Lock()
# Frame buffer (ring buffer) - stores CUDA device pointers
# Frame buffer (ring buffer) - stores FrameReference objects
self.frame_buffer = deque(maxlen=buffer_size)
self._buffer_lock = threading.RLock()
# Track which buffer slots are in use (list of FrameReference objects)
self._in_use_frames = [] # List of FrameReference objects currently held by callbacks
# Decoder and container instances
self.decoder = None
self.container = None
@ -200,6 +232,45 @@ class StreamDecoder:
self.frame_height: Optional[int] = None
self.frame_count: int = 0
# Frame callbacks - event-driven notification
self._frame_callbacks = []
self._callback_lock = threading.Lock()
def register_frame_callback(self, callback: Callable):
    """
    Subscribe a listener for newly decoded frames.

    The decode loop calls each registered listener with the frame's RGB
    tensor: ``callback(frame: torch.Tensor) -> None``.

    NOTE(review): the decode loop appears to invoke listeners while holding
    the same non-reentrant callback lock, so a listener that registers or
    unregister callbacks from inside its own body would block - confirm.

    Args:
        callback: Function to invoke for every new frame.
    """
    with self._callback_lock:
        listeners = self._frame_callbacks
        listeners.append(callback)
def unregister_frame_callback(self, callback: Callable):
    """
    Remove a previously registered frame listener.

    Only the first matching entry is removed; asking to remove a callback
    that is not registered is a silent no-op.

    Args:
        callback: The callback function to drop.
    """
    with self._callback_lock:
        try:
            self._frame_callbacks.remove(callback)
        except ValueError:
            # Not registered - nothing to do.
            pass
def _mark_frame_free(self, buffer_index: int):
    """
    Drop a released frame from the in-use tracking list.

    FrameReference.free() calls this when a frame is no longer in use.
    Every tracked entry whose ``buffer_index`` matches is removed, and the
    tracking list is replaced by a new list rather than edited in place.

    Args:
        buffer_index: Tracking index of the frame being released.
    """
    with self._buffer_lock:
        still_held = []
        for ref in self._in_use_frames:
            if ref.buffer_index != buffer_index:
                still_held.append(ref)
        self._in_use_frames = still_held
def start(self):
"""Start the RTSP stream decoding in background thread"""
if self._decode_thread is not None and self._decode_thread.is_alive():
@ -278,6 +349,9 @@ class StreamDecoder:
def _decode_loop(self):
"""Main decode loop running in background thread"""
# Set the CUDA device for this thread
torch.cuda.set_device(self.gpu_id)
retry_count = 0
max_retries = 5
@ -319,11 +393,60 @@ class StreamDecoder:
if not decoded_frames:
continue
# Add frames to ring buffer (thread-safe)
# Add frames to ring buffer and fire callbacks
with self._buffer_lock:
for frame in decoded_frames:
self.frame_buffer.append(frame)
self.frame_count += 1
# Check for buffer overflow - discard oldest if needed
if len(self.frame_buffer) >= self.buffer_size:
# Check if oldest frame is still in use
if len(self._in_use_frames) > 0:
oldest_ref = self.frame_buffer[0] if len(self.frame_buffer) > 0 else None
if oldest_ref and not oldest_ref.is_freed():
# Force free the oldest frame to prevent overflow
print(f"[WARNING] Buffer overflow, force-freeing oldest frame (buffer_index={oldest_ref.buffer_index})")
oldest_ref.free()
# Deque will automatically remove oldest when at maxlen
# Convert to tensor
try:
# Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
nv12_tensor = torch.from_dlpack(frame)
# Convert NV12 to RGB on GPU
if self.frame_height is not None and self.frame_width is not None:
rgb_tensor = nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width)
# CRITICAL: Clone the RGB tensor to break CUDA memory dependency
# The nv12_to_rgb_gpu creates a new tensor, but it still references
# the same CUDA context/stream. We need an independent copy.
rgb_tensor_cloned = rgb_tensor.clone()
# Create FrameReference object for C++-style memory management
# We don't keep the DecodedFrame to avoid conflicts with PyNvVideoCodec's
# internal frame pool - the clone is fully independent
buffer_index = self.frame_count
frame_ref = FrameReference(
rgb_tensor=rgb_tensor_cloned, # Independent cloned tensor
buffer_index=buffer_index,
decoder=self
)
# Add to buffer and in-use tracking
self.frame_buffer.append(frame_ref)
self._in_use_frames.append(frame_ref)
self.frame_count += 1
# Fire callbacks with the cloned RGB tensor from FrameReference
# The tensor is now independent of the DecodedFrame lifecycle
with self._callback_lock:
for callback in self._frame_callbacks:
try:
callback(frame_ref.rgb_tensor)
except Exception as e:
print(f"Error in frame callback: {e}")
except Exception as e:
print(f"Error converting frame for callback: {e}")
except Exception as e:
print(f"Error in decode loop for {self.rtsp_url}: {e}")
@ -351,35 +474,25 @@ class StreamDecoder:
Args:
index: Frame index in buffer (-1 for latest, -2 for second latest, etc.)
rgb: If True, convert NV12 to RGB. If False, return raw NV12 format.
rgb: If True, return RGB tensor. If False, not supported (returns None).
Returns:
torch.Tensor in CUDA memory (device tensor) or None if buffer empty
- If rgb=True: Shape (3, H, W) in RGB format, dtype uint8
- If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8
- If rgb=False: Not supported with FrameReference (returns None)
"""
with self._buffer_lock:
if len(self.frame_buffer) == 0:
return None
if not rgb:
print("Warning: NV12 format not supported with FrameReference, only RGB")
return None
try:
decoded_frame = self.frame_buffer[index]
# Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
# This keeps the data in GPU memory
nv12_tensor = torch.from_dlpack(decoded_frame)
if not rgb:
# Return raw NV12 format
return nv12_tensor
# Convert NV12 to RGB on GPU
if self.frame_height is None or self.frame_width is None:
print("Frame dimensions not available")
return None
rgb_tensor = nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width)
return rgb_tensor
frame_ref = self.frame_buffer[index]
# Return the RGB tensor from FrameReference (cloned, independent)
return frame_ref.rgb_tensor
except (IndexError, Exception) as e:
print(f"Error getting frame: {e}")
@ -448,6 +561,39 @@ class StreamDecoder:
with self._buffer_lock:
return len(self.frame_buffer)
def get_all_frames(self, rgb: bool = True) -> list:
    """
    Drain the frame buffer and hand back every buffered RGB tensor.

    Each buffered FrameReference is freed and the buffer is emptied before
    returning, so a second call yields only frames added in between.

    Args:
        rgb: Must be True. NV12 output is not available with
            FrameReference; passing False prints a warning and returns an
            empty list without touching the buffer.

    Returns:
        List of the ``rgb_tensor`` objects held by the buffered references,
        in buffer order.
    """
    if not rgb:
        print("Warning: NV12 format not supported with FrameReference, only RGB")
        return []

    tensors = []
    with self._buffer_lock:
        # Pass 1: collect tensors, skipping any reference that fails.
        for ref in self.frame_buffer:
            try:
                tensor = ref.rgb_tensor
            except Exception as e:
                print(f"Error getting frame: {e}")
            else:
                tensors.append(tensor)

        # Pass 2: release every reference, then empty the buffer.
        for ref in self.frame_buffer:
            ref.free()
        self.frame_buffer.clear()

    return tensors
def get_frame_count(self) -> int:
    """Return the running count of decoded frames kept in ``self.frame_count``."""
    total = self.frame_count
    return total