Event-driven system

Siwat Sirichai 2025-11-10 11:51:06 +07:00
parent 0c5f56c8a6
commit 3a47920186
10 changed files with 782 additions and 253 deletions


@ -1,17 +1,18 @@
"""
ModelController - Async batching layer with ping-pong buffers for inference.
ModelController - Event-driven batching layer with ping-pong buffers for inference.
This module provides batched inference coordination using ping-pong circular buffers
with force-switch timeout mechanism.
with force-switch timeout mechanism using threading and callbacks.
"""
import asyncio
import threading
import torch
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass, field
from enum import Enum
import time
import logging
import queue
logger = logging.getLogger(__name__)
@ -43,7 +44,7 @@ class ModelController:
Features:
- Ping-pong circular buffers (BufferA/BufferB)
- Force-switch timeout to prevent batch starvation
- Async event-driven processing
- Event-driven processing with callbacks
- Thread-safe frame submission
Args:
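For orientation, a hedged usage sketch of the threaded API described above; the constructor arguments are elided because they are not fully visible in this diff:

controller = ModelController(...)             # model_repository, model_id, batch_size, ... (not shown here)

def on_result(result):                        # synchronous callback, invoked from a processor thread
    print(result["detections"])               # results carry a "detections" key (see the result handling below)

controller.register_callback("cam-1", on_result)
controller.start()                            # spawns the timeout-monitor thread and one processor thread per buffer

# frame: RGB uint8 tensor of shape (3, H, W), already resident on the GPU
controller.submit_frame(stream_id="cam-1", frame=frame, metadata={"frame_number": 1})

controller.stop()                             # joins the threads and drains any remaining buffered frames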
@ -90,14 +91,15 @@ class ModelController:
self.buffer_a_state = BufferState.IDLE
self.buffer_b_state = BufferState.IDLE
# Async coordination
self.buffer_lock = asyncio.Lock()
# Threading coordination
self.buffer_lock = threading.RLock()
self.last_submit_time = time.time()
# Tasks
self.timeout_task: Optional[asyncio.Task] = None
self.processor_task: Optional[asyncio.Task] = None
# Threads
self.timeout_thread: Optional[threading.Thread] = None
self.processor_threads: Dict[str, threading.Thread] = {}
self.running = False
self.stop_event = threading.Event()
# Result callbacks (stream_id -> callback)
self.result_callbacks: Dict[str, Callable] = {}
@ -130,42 +132,46 @@ class ModelController:
logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
return 1
async def start(self):
"""Start the controller background tasks"""
def start(self):
"""Start the controller background threads"""
if self.running:
logger.warning("ModelController already running")
return
self.running = True
self.timeout_task = asyncio.create_task(self._timeout_monitor())
self.processor_task = asyncio.create_task(self._batch_processor())
self.stop_event.clear()
# Start timeout monitor thread
self.timeout_thread = threading.Thread(target=self._timeout_monitor, daemon=True)
self.timeout_thread.start()
# Start processor threads for each buffer
self.processor_threads['A'] = threading.Thread(target=self._batch_processor, args=('A',), daemon=True)
self.processor_threads['B'] = threading.Thread(target=self._batch_processor, args=('B',), daemon=True)
self.processor_threads['A'].start()
self.processor_threads['B'].start()
logger.info("ModelController started")
async def stop(self):
def stop(self):
"""Stop the controller and cleanup"""
if not self.running:
return
logger.info("Stopping ModelController...")
self.running = False
self.stop_event.set()
# Cancel tasks
if self.timeout_task:
self.timeout_task.cancel()
try:
await self.timeout_task
except asyncio.CancelledError:
pass
# Wait for threads to finish
if self.timeout_thread and self.timeout_thread.is_alive():
self.timeout_thread.join(timeout=2.0)
if self.processor_task:
self.processor_task.cancel()
try:
await self.processor_task
except asyncio.CancelledError:
pass
for thread in self.processor_threads.values():
if thread and thread.is_alive():
thread.join(timeout=2.0)
# Process any remaining frames
await self._process_remaining_buffers()
self._process_remaining_buffers()
logger.info("ModelController stopped")
def register_callback(self, stream_id: str, callback: Callable):
@ -189,7 +195,7 @@ class ModelController:
self.result_callbacks.pop(stream_id, None)
logger.debug(f"Unregistered callback for stream: {stream_id}")
async def submit_frame(
def submit_frame(
self,
stream_id: str,
frame: torch.Tensor,
@ -203,7 +209,7 @@ class ModelController:
frame: GPU tensor (3, H, W) or (C, H, W)
metadata: Optional metadata to attach to the frame
"""
async with self.buffer_lock:
with self.buffer_lock:
batch_frame = BatchFrame(
stream_id=stream_id,
frame=frame,
@ -225,23 +231,21 @@ class ModelController:
# Check if we should immediately swap (batch full)
if buffer_size >= self.batch_size:
await self._try_swap_buffers()
self._try_swap_buffers()
async def _timeout_monitor(self):
def _timeout_monitor(self):
"""Monitor force-switch timeout"""
while self.running:
await asyncio.sleep(0.01) # Check every 10ms
async with self.buffer_lock:
while self.running and not self.stop_event.wait(0.01): # Check every 10ms
with self.buffer_lock:
time_since_submit = time.time() - self.last_submit_time
# Check if timeout expired and we have frames waiting
if time_since_submit >= self.force_timeout:
active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
if len(active_buffer) > 0:
await self._try_swap_buffers()
self._try_swap_buffers()
async def _try_swap_buffers(self):
def _try_swap_buffers(self):
"""
Attempt to swap ping-pong buffers.
Only swaps if the inactive buffer is not currently processing.
@ -266,20 +270,22 @@ class ModelController:
logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")
async def _batch_processor(self):
"""Background task that processes batches when available"""
while self.running:
await asyncio.sleep(0.001) # Check every 1ms
def _batch_processor(self, buffer_name: str):
"""Background thread that processes a specific buffer when available"""
while self.running and not self.stop_event.is_set():
time.sleep(0.001) # Check every 1ms
# Check if buffer A needs processing
if self.buffer_a_state == BufferState.PROCESSING:
await self._process_buffer("A")
# Check if this buffer needs processing
with self.buffer_lock:
if buffer_name == "A":
should_process = self.buffer_a_state == BufferState.PROCESSING
else:
should_process = self.buffer_b_state == BufferState.PROCESSING
# Check if buffer B needs processing
if self.buffer_b_state == BufferState.PROCESSING:
await self._process_buffer("B")
if should_process:
self._process_buffer(buffer_name)
async def _process_buffer(self, buffer_name: str):
def _process_buffer(self, buffer_name: str):
"""
Process a buffer through inference.
@ -287,7 +293,7 @@ class ModelController:
buffer_name: "A" or "B"
"""
# Extract buffer to process
async with self.buffer_lock:
with self.buffer_lock:
if buffer_name == "A":
batch = self.buffer_a.copy()
self.buffer_a.clear()
@ -297,7 +303,7 @@ class ModelController:
if len(batch) == 0:
# Mark as idle and return
async with self.buffer_lock:
with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
@ -307,7 +313,7 @@ class ModelController:
# Process batch (outside lock to allow concurrent submissions)
try:
start_time = time.time()
results = await self._run_batch_inference(batch)
results = self._run_batch_inference(batch)
inference_time = time.time() - start_time
# Update statistics
@ -323,27 +329,24 @@ class ModelController:
for batch_frame, result in zip(batch, results):
callback = self.result_callbacks.get(batch_frame.stream_id)
if callback:
# Schedule callback asynchronously
if asyncio.iscoroutinefunction(callback):
asyncio.create_task(callback(result))
else:
# Run sync callback in executor to avoid blocking
loop = asyncio.get_event_loop()
loop.call_soon(lambda cb=callback, r=result: cb(r))
# Call callback directly (synchronous)
try:
callback(result)
except Exception as e:
logger.error(f"Error in callback for {batch_frame.stream_id}: {e}", exc_info=True)
except Exception as e:
logger.error(f"Error processing batch: {e}", exc_info=True)
# TODO: Emit error events to streams
finally:
# Mark buffer as idle
async with self.buffer_lock:
with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""
Run inference on a batch of frames.
@ -353,17 +356,15 @@ class ModelController:
Returns:
List of detection results (one per frame)
"""
loop = asyncio.get_event_loop()
# Check if model supports batching
if self.model_batch_size == 1:
# Process frames one at a time for batch_size=1 models
return await self._run_sequential_inference(batch, loop)
return self._run_sequential_inference(batch)
else:
# Use true batching for models that support it
return await self._run_batched_inference(batch, loop)
return self._run_batched_inference(batch)
async def _run_sequential_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
def _run_sequential_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""Run inference sequentially for batch_size=1 models"""
results = []
@ -376,13 +377,10 @@ class ModelController:
processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame
# Run inference for this frame
outputs = await loop.run_in_executor(
None,
lambda p=processed: self.model_repository.infer(
self.model_id,
{"images": p},
synchronize=True
)
outputs = self.model_repository.infer(
self.model_id,
{"images": processed},
synchronize=True
)
# Postprocess
@ -406,7 +404,7 @@ class ModelController:
return results
async def _run_batched_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""Run true batched inference for models that support it"""
# Preprocess frames (on GPU)
preprocessed = []
@ -434,13 +432,10 @@ class ModelController:
batch = batch[:self.model_batch_size]
# Run inference
outputs = await loop.run_in_executor(
None,
lambda: self.model_repository.infer(
self.model_id,
{"images": batch_tensor},
synchronize=True
)
outputs = self.model_repository.infer(
self.model_id,
{"images": batch_tensor},
synchronize=True
)
# Postprocess results (split batch back to individual results)
@ -472,14 +467,14 @@ class ModelController:
return results
async def _process_remaining_buffers(self):
def _process_remaining_buffers(self):
"""Process any remaining frames in buffers during shutdown"""
if len(self.buffer_a) > 0:
logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
await self._process_buffer("A")
self._process_buffer("A")
if len(self.buffer_b) > 0:
logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
await self._process_buffer("B")
self._process_buffer("B")
def get_stats(self) -> Dict[str, Any]:
"""Get current buffer statistics"""


@ -1,16 +1,17 @@
"""
StreamConnectionManager - Async orchestration for stream processing with batched inference.
StreamConnectionManager - Event-driven orchestration for stream processing with batched inference.
This module provides high-level connection management for multiple RTSP streams,
coordinating decoders, batched inference, and tracking with an event-driven API.
coordinating decoders, batched inference, and tracking with callbacks and threading.
"""
import asyncio
import threading
import time
from typing import Dict, Optional, Callable, AsyncIterator, Tuple, Any, List
from typing import Dict, Optional, Callable, Tuple, Any, List
from dataclasses import dataclass
from enum import Enum
import logging
import queue
import torch
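Taken together with the StreamConnection changes below, a hedged end-to-end sketch of the threaded API in this module; the manager's constructor arguments, the model path, and the RTSP URL are placeholders:

manager = StreamConnectionManager(...)                      # constructor arguments not shown in this diff
manager.initialize(model_path="models/detector.trt",        # placeholder path
                   model_id="detector")

connection = manager.connect_stream("rtsp://camera.example/stream", stream_id="cam-1")

# Blocking generator style: iterate tracking results on the caller's thread.
for result in connection.tracking_results():
    print(result.tracked_objects)
    break                                                   # stop after one result in this sketch

manager.disconnect_stream("cam-1")
manager.shutdown()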
@ -44,7 +45,7 @@ class StreamConnection:
"""
Represents a single stream connection with event emission.
This class wraps a StreamDecoder, polls frames asynchronously, submits them
This class wraps a StreamDecoder, receives frames via decoder callbacks, submits them
to the ModelController for batched inference, runs tracking, and emits results
via queues or callbacks.
@ -75,15 +76,19 @@ class StreamConnection:
self.last_frame_time = 0.0
# Event emission
self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue()
self.error_queue: asyncio.Queue[Exception] = asyncio.Queue()
self.result_queue: queue.Queue[TrackingResult] = queue.Queue()
self.error_queue: queue.Queue[Exception] = queue.Queue()
# Tasks
self.poller_task: Optional[asyncio.Task] = None
# Event-driven state
self.running = False
async def start(self):
"""Start the connection (decoder and frame polling)"""
def start(self):
"""Start the connection (decoder with frame callback)"""
self.running = True
# Register callback for frame events from decoder
self.decoder.register_frame_callback(self._on_frame_decoded)
# Start decoder (runs in background thread)
self.decoder.start()
@ -93,7 +98,7 @@ class StreamConnection:
elapsed = 0.0
while elapsed < max_wait:
await asyncio.sleep(wait_interval)
time.sleep(wait_interval)
elapsed += wait_interval
if self.decoder.is_connected():
@ -105,21 +110,13 @@ class StreamConnection:
logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...")
self.status = ConnectionStatus.CONNECTING
# Start frame polling task
self.running = True
self.poller_task = asyncio.create_task(self._frame_poller())
async def stop(self):
def stop(self):
"""Stop the connection and cleanup"""
logger.info(f"Stopping stream {self.stream_id}...")
self.running = False
if self.poller_task:
self.poller_task.cancel()
try:
await self.poller_task
except asyncio.CancelledError:
pass
# Unregister frame callback
self.decoder.unregister_frame_callback(self._on_frame_decoded)
# Stop decoder
self.decoder.stop()
@ -130,55 +127,45 @@ class StreamConnection:
self.status = ConnectionStatus.DISCONNECTED
logger.info(f"Stream {self.stream_id} stopped")
async def _frame_poller(self):
"""Poll frames from threaded decoder and submit to model controller"""
last_decoder_frame_count = -1
def _on_frame_decoded(self, frame: torch.Tensor):
"""
Event handler called by decoder when a new frame is decoded.
This is the event-driven replacement for polling.
while self.running:
try:
# Get current decoder frame count (no data transfer, just counter)
decoder_frame_count = self.decoder.get_frame_count()
Args:
frame: RGB frame tensor on GPU (3, H, W)
"""
if not self.running:
return
# Check if decoder has a new frame (avoid reprocessing same frame)
if decoder_frame_count > last_decoder_frame_count:
# Poll frame from decoder (zero-copy - stays in VRAM)
frame = self.decoder.get_latest_frame(rgb=True)
try:
self.last_frame_time = time.time()
self.frame_count += 1
if frame is not None:
last_decoder_frame_count = decoder_frame_count
self.last_frame_time = time.time()
self.frame_count += 1
# Submit to model controller for batched inference
self.model_controller.submit_frame(
stream_id=self.stream_id,
frame=frame,
metadata={
"frame_number": self.frame_count,
"shape": tuple(frame.shape),
}
)
# Submit to model controller for batched inference
await self.model_controller.submit_frame(
stream_id=self.stream_id,
frame=frame,
metadata={
"frame_number": self.frame_count,
"shape": tuple(frame.shape),
}
)
# Update connection status based on decoder status
if self.decoder.is_connected() and self.status != ConnectionStatus.CONNECTED:
logger.info(f"Stream {self.stream_id} reconnected")
self.status = ConnectionStatus.CONNECTED
elif not self.decoder.is_connected() and self.status == ConnectionStatus.CONNECTED:
logger.warning(f"Stream {self.stream_id} disconnected")
self.status = ConnectionStatus.DISCONNECTED
# Check decoder status
if not self.decoder.is_connected():
if self.status == ConnectionStatus.CONNECTED:
logger.warning(f"Stream {self.stream_id} disconnected")
self.status = ConnectionStatus.DISCONNECTED
# Decoder will auto-reconnect, just update status
await asyncio.sleep(1.0)
if self.decoder.is_connected():
logger.info(f"Stream {self.stream_id} reconnected")
self.status = ConnectionStatus.CONNECTED
except Exception as e:
logger.error(f"Error processing frame for {self.stream_id}: {e}", exc_info=True)
self.error_queue.put(e)
self.status = ConnectionStatus.ERROR
except Exception as e:
logger.error(f"Error in frame poller for {self.stream_id}: {e}", exc_info=True)
await self.error_queue.put(e)
self.status = ConnectionStatus.ERROR
# Sleep until next poll
await asyncio.sleep(self.poll_interval)
async def _handle_inference_result(self, result: Dict[str, Any]):
def _handle_inference_result(self, result: Dict[str, Any]):
"""
Callback invoked by ModelController when inference is done.
Runs tracking and emits final result.
@ -190,12 +177,8 @@ class StreamConnection:
# Extract detections
detections = result["detections"]
# Run tracking (this is sync, so run in executor)
loop = asyncio.get_event_loop()
tracked_objects = await loop.run_in_executor(
None,
lambda: self._run_tracking_sync(detections)
)
# Run tracking (synchronous)
tracked_objects = self._run_tracking_sync(detections)
# Create tracking result
tracking_result = TrackingResult(
@ -208,11 +191,11 @@ class StreamConnection:
)
# Emit to result queue
await self.result_queue.put(tracking_result)
self.result_queue.put(tracking_result)
except Exception as e:
logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
await self.error_queue.put(e)
self.error_queue.put(e)
def _run_tracking_sync(self, detections, min_confidence=0.7):
"""
@ -246,12 +229,12 @@ class StreamConnection:
# Update tracker with detections (lightweight, no model dependency!)
return self.tracking_controller.update(detection_list)
async def tracking_results(self) -> AsyncIterator[TrackingResult]:
def tracking_results(self):
"""
Async generator for tracking results.
Generator for tracking results (blocking iterator).
Usage:
async for result in connection.tracking_results():
for result in connection.tracking_results():
print(result.tracked_objects)
Yields:
@ -259,23 +242,23 @@ class StreamConnection:
"""
while self.running or not self.result_queue.empty():
try:
result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0)
result = self.result_queue.get(timeout=1.0)
yield result
except asyncio.TimeoutError:
except queue.Empty:
continue
async def errors(self) -> AsyncIterator[Exception]:
def errors(self):
"""
Async generator for errors.
Generator for errors (blocking iterator).
Yields:
Exception objects as they occur
"""
while self.running or not self.error_queue.empty():
try:
error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0)
error = self.error_queue.get(timeout=1.0)
yield error
except asyncio.TimeoutError:
except queue.Empty:
continue
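A hedged sketch of consuming this generator from a watcher thread so errors are surfaced instead of accumulating in the queue; connection is assumed to be an existing StreamConnection:

import threading

def watch_errors(conn):
    for err in conn.errors():
        logger.warning(f"stream {conn.stream_id} reported an error: {err}")

threading.Thread(target=watch_errors, args=(connection,), daemon=True).start()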
def get_stats(self) -> Dict[str, Any]:
@ -342,7 +325,7 @@ class StreamConnectionManager:
# State
self.initialized = False
async def initialize(
def initialize(
self,
model_path: str,
model_id: str = "detector",
@ -368,18 +351,14 @@ class StreamConnectionManager:
"""
logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
# Load model
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
lambda: self.model_repository.load_model(
model_id,
model_path,
num_contexts=num_contexts,
pt_input_shapes=pt_input_shapes,
pt_precision=pt_precision,
**pt_conversion_kwargs
)
# Load model (synchronous)
self.model_repository.load_model(
model_id,
model_path,
num_contexts=num_contexts,
pt_input_shapes=pt_input_shapes,
pt_precision=pt_precision,
**pt_conversion_kwargs
)
logger.info(f"Loaded model {model_id} from {model_path}")
@ -392,7 +371,7 @@ class StreamConnectionManager:
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
await self.model_controller.start()
self.model_controller.start()
# Don't create a shared tracking controller here
# Each stream will get its own tracking controller to avoid track accumulation
@ -402,7 +381,7 @@ class StreamConnectionManager:
self.initialized = True
logger.info("StreamConnectionManager initialized successfully")
async def connect_stream(
def connect_stream(
self,
rtsp_url: str,
stream_id: Optional[str] = None,
@ -416,8 +395,8 @@ class StreamConnectionManager:
Args:
rtsp_url: RTSP stream URL
stream_id: Optional stream identifier (auto-generated if not provided)
on_tracking_result: Optional callback for tracking results (sync or async)
on_error: Optional callback for errors (sync or async)
on_tracking_result: Optional callback for tracking results (synchronous)
on_error: Optional callback for errors (synchronous)
buffer_size: Decoder buffer size (default: 30)
Returns:
@ -466,22 +445,30 @@ class StreamConnectionManager:
)
# Start connection
await connection.start()
connection.start()
# Store connection
self.connections[stream_id] = connection
# Set up user callbacks if provided
# Set up user callbacks if provided (run in separate threads)
if on_tracking_result:
asyncio.create_task(self._forward_results(connection, on_tracking_result))
threading.Thread(
target=self._forward_results,
args=(connection, on_tracking_result),
daemon=True
).start()
if on_error:
asyncio.create_task(self._forward_errors(connection, on_error))
threading.Thread(
target=self._forward_errors,
args=(connection, on_error),
daemon=True
).start()
logger.info(f"Stream {stream_id} connected successfully")
return connection
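For the callback style, a hedged sketch; both callbacks run on the daemon forwarding threads started above, so they should stay short and non-blocking:

def on_tracking_result(result):
    print(len(result.tracked_objects), "tracked objects")

def on_error(error):
    print(f"stream error: {error}")

connection = manager.connect_stream(
    "rtsp://camera.example/stream",          # placeholder URL
    on_tracking_result=on_tracking_result,
    on_error=on_error,
)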
async def disconnect_stream(self, stream_id: str):
def disconnect_stream(self, stream_id: str):
"""
Disconnect and cleanup a stream.
@ -490,27 +477,27 @@ class StreamConnectionManager:
"""
connection = self.connections.get(stream_id)
if connection:
await connection.stop()
connection.stop()
del self.connections[stream_id]
logger.info(f"Stream {stream_id} disconnected")
async def disconnect_all(self):
def disconnect_all(self):
"""Disconnect all streams"""
logger.info("Disconnecting all streams...")
stream_ids = list(self.connections.keys())
for stream_id in stream_ids:
await self.disconnect_stream(stream_id)
self.disconnect_stream(stream_id)
async def shutdown(self):
def shutdown(self):
"""Shutdown the manager and cleanup all resources"""
logger.info("Shutting down StreamConnectionManager...")
# Disconnect all streams
await self.disconnect_all()
self.disconnect_all()
# Stop model controller
if self.model_controller:
await self.model_controller.stop()
self.model_controller.stop()
# Note: Model repository cleanup is sync and may cause segfaults
# Leaving cleanup to garbage collection for now
@ -518,37 +505,31 @@ class StreamConnectionManager:
self.initialized = False
logger.info("StreamConnectionManager shutdown complete")
async def _forward_results(self, connection: StreamConnection, callback: Callable):
def _forward_results(self, connection: StreamConnection, callback: Callable):
"""
Forward results from connection to user callback.
Args:
connection: StreamConnection to listen to
callback: User callback (sync or async)
callback: User callback (synchronous)
"""
try:
async for result in connection.tracking_results():
if asyncio.iscoroutinefunction(callback):
await callback(result)
else:
callback(result)
for result in connection.tracking_results():
callback(result)
except Exception as e:
logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)
async def _forward_errors(self, connection: StreamConnection, callback: Callable):
def _forward_errors(self, connection: StreamConnection, callback: Callable):
"""
Forward errors from connection to user callback.
Args:
connection: StreamConnection to listen to
callback: User callback (sync or async)
callback: User callback (synchronous)
"""
try:
async for error in connection.errors():
if asyncio.iscoroutinefunction(callback):
await callback(error)
else:
callback(error)
for error in connection.errors():
callback(error)
except Exception as e:
logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)


@ -1,5 +1,5 @@
import threading
from typing import Optional
from typing import Optional, Callable
from collections import deque
from enum import Enum
import torch
@ -10,6 +10,35 @@ from cuda.bindings import driver as cuda_driver
from .jpeg_encoder import encode_frame_to_jpeg
class FrameReference:
"""
CPU-side reference object for a GPU frame.
This object holds a cloned RGB tensor that is independent of PyNvVideoCodec's
DecodedFrame lifecycle. We don't keep the DecodedFrame to avoid conflicts
with PyNvVideoCodec's internal frame pool management.
"""
def __init__(self, rgb_tensor: torch.Tensor, buffer_index: int, decoder):
self.rgb_tensor = rgb_tensor # Cloned RGB tensor (independent copy)
self.buffer_index = buffer_index
self.decoder = decoder # Reference to decoder for marking as free
self._freed = False
def free(self):
"""Mark this frame as no longer in use"""
if not self._freed:
self._freed = True
self.decoder._mark_frame_free(self.buffer_index)
def is_freed(self) -> bool:
"""Check if this frame has been freed"""
return self._freed
def __del__(self):
"""Auto-free on garbage collection"""
self.free()
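A hedged sketch of the consumer-side lifecycle this class is designed for; frame_ref is assumed to have been obtained from the decoder's ring buffer:

rgb = frame_ref.rgb_tensor      # (3, H, W) uint8 tensor on the GPU, independent of the decoder's frame pool
run_inference(rgb)              # placeholder for whatever work the consumer does with the frame
frame_ref.free()                # release the buffer slot explicitly
assert frame_ref.is_freed()     # __del__ would also free it on garbage collection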
def nv12_to_rgb_gpu(nv12_tensor: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
Convert NV12 format to RGB on GPU using PyTorch operations.
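The body of nv12_to_rgb_gpu lies outside this hunk; a minimal reference sketch of the conversion it documents, assuming an (H*3/2, W) uint8 NV12 tensor, BT.601 full-range coefficients, and nearest-neighbour chroma upsampling (illustrative only, not the repository's implementation):

def nv12_to_rgb_reference(nv12: torch.Tensor, height: int, width: int) -> torch.Tensor:
    y = nv12[:height, :].float()                                  # luma plane, (H, W)
    uv = nv12[height:, :].reshape(height // 2, width // 2, 2).float()
    u = (uv[..., 0] - 128.0).repeat_interleave(2, dim=0).repeat_interleave(2, dim=1)
    v = (uv[..., 1] - 128.0).repeat_interleave(2, dim=0).repeat_interleave(2, dim=1)
    r = y + 1.402 * v                                             # BT.601 conversion
    g = y - 0.344136 * u - 0.714136 * v
    b = y + 1.772 * u
    return torch.stack([r, g, b], dim=0).clamp(0, 255).to(torch.uint8)  # (3, H, W), same device as input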
@ -183,10 +212,13 @@ class StreamDecoder:
self.status = ConnectionStatus.DISCONNECTED
self._status_lock = threading.Lock()
# Frame buffer (ring buffer) - stores CUDA device pointers
# Frame buffer (ring buffer) - stores FrameReference objects
self.frame_buffer = deque(maxlen=buffer_size)
self._buffer_lock = threading.RLock()
# Track which buffer slots are in use (list of FrameReference objects)
self._in_use_frames = [] # List of FrameReference objects currently held by callbacks
# Decoder and container instances
self.decoder = None
self.container = None
@ -200,6 +232,45 @@ class StreamDecoder:
self.frame_height: Optional[int] = None
self.frame_count: int = 0
# Frame callbacks - event-driven notification
self._frame_callbacks = []
self._callback_lock = threading.Lock()
def register_frame_callback(self, callback: Callable):
"""
Register a callback to be called when a new frame is decoded.
The callback will be called with the decoded frame tensor (GPU) as argument.
Callback signature: callback(frame: torch.Tensor) -> None
Args:
callback: Function to call when new frame arrives
"""
with self._callback_lock:
self._frame_callbacks.append(callback)
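A hedged sketch of registering a consumer callback with the signature documented above; decoder is assumed to be an existing StreamDecoder instance:

import queue

frame_queue: queue.Queue = queue.Queue(maxsize=8)

def on_frame(frame: torch.Tensor) -> None:
    # Called from the decode loop's thread; keep it cheap and never block the decoder.
    try:
        frame_queue.put_nowait(frame)       # frame: RGB uint8 (3, H, W) tensor on the GPU
    except queue.Full:
        pass                                # drop the frame rather than stall decoding

decoder.register_frame_callback(on_frame)
# ... later, during shutdown:
decoder.unregister_frame_callback(on_frame)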
def unregister_frame_callback(self, callback: Callable):
"""
Unregister a frame callback.
Args:
callback: The callback function to remove
"""
with self._callback_lock:
if callback in self._frame_callbacks:
self._frame_callbacks.remove(callback)
def _mark_frame_free(self, buffer_index: int):
"""
Mark a frame as freed (called by FrameReference when it's no longer in use).
Args:
buffer_index: Index in the buffer for tracking purposes
"""
with self._buffer_lock:
# Remove from in-use tracking
self._in_use_frames = [f for f in self._in_use_frames if f.buffer_index != buffer_index]
def start(self):
"""Start the RTSP stream decoding in background thread"""
if self._decode_thread is not None and self._decode_thread.is_alive():
@ -278,6 +349,9 @@ class StreamDecoder:
def _decode_loop(self):
"""Main decode loop running in background thread"""
# Set the CUDA device for this thread
torch.cuda.set_device(self.gpu_id)
retry_count = 0
max_retries = 5
@ -319,11 +393,60 @@ class StreamDecoder:
if not decoded_frames:
continue
# Add frames to ring buffer (thread-safe)
# Add frames to ring buffer and fire callbacks
with self._buffer_lock:
for frame in decoded_frames:
self.frame_buffer.append(frame)
self.frame_count += 1
# Check for buffer overflow - discard oldest if needed
if len(self.frame_buffer) >= self.buffer_size:
# Check if oldest frame is still in use
if len(self._in_use_frames) > 0:
oldest_ref = self.frame_buffer[0] if len(self.frame_buffer) > 0 else None
if oldest_ref and not oldest_ref.is_freed():
# Force free the oldest frame to prevent overflow
print(f"[WARNING] Buffer overflow, force-freeing oldest frame (buffer_index={oldest_ref.buffer_index})")
oldest_ref.free()
# Deque will automatically remove oldest when at maxlen
# Convert to tensor
try:
# Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
nv12_tensor = torch.from_dlpack(frame)
# Convert NV12 to RGB on GPU
if self.frame_height is not None and self.frame_width is not None:
rgb_tensor = nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width)
# CRITICAL: Clone the RGB tensor to break CUDA memory dependency
# The nv12_to_rgb_gpu creates a new tensor, but it still references
# the same CUDA context/stream. We need an independent copy.
rgb_tensor_cloned = rgb_tensor.clone()
# Create FrameReference object for C++-style memory management
# We don't keep the DecodedFrame to avoid conflicts with PyNvVideoCodec's
# internal frame pool - the clone is fully independent
buffer_index = self.frame_count
frame_ref = FrameReference(
rgb_tensor=rgb_tensor_cloned, # Independent cloned tensor
buffer_index=buffer_index,
decoder=self
)
# Add to buffer and in-use tracking
self.frame_buffer.append(frame_ref)
self._in_use_frames.append(frame_ref)
self.frame_count += 1
# Fire callbacks with the cloned RGB tensor from FrameReference
# The tensor is now independent of the DecodedFrame lifecycle
with self._callback_lock:
for callback in self._frame_callbacks:
try:
callback(frame_ref.rgb_tensor)
except Exception as e:
print(f"Error in frame callback: {e}")
except Exception as e:
print(f"Error converting frame for callback: {e}")
except Exception as e:
print(f"Error in decode loop for {self.rtsp_url}: {e}")
@ -351,35 +474,25 @@ class StreamDecoder:
Args:
index: Frame index in buffer (-1 for latest, -2 for second latest, etc.)
rgb: If True, convert NV12 to RGB. If False, return raw NV12 format.
rgb: If True, return RGB tensor. If False, not supported (returns None).
Returns:
torch.Tensor in CUDA memory (device tensor) or None if buffer empty
- If rgb=True: Shape (3, H, W) in RGB format, dtype uint8
- If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8
- If rgb=False: Not supported with FrameReference (returns None)
"""
with self._buffer_lock:
if len(self.frame_buffer) == 0:
return None
if not rgb:
print("Warning: NV12 format not supported with FrameReference, only RGB")
return None
try:
decoded_frame = self.frame_buffer[index]
# Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
# This keeps the data in GPU memory
nv12_tensor = torch.from_dlpack(decoded_frame)
if not rgb:
# Return raw NV12 format
return nv12_tensor
# Convert NV12 to RGB on GPU
if self.frame_height is None or self.frame_width is None:
print("Frame dimensions not available")
return None
rgb_tensor = nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width)
return rgb_tensor
frame_ref = self.frame_buffer[index]
# Return the RGB tensor from FrameReference (cloned, independent)
return frame_ref.rgb_tensor
except Exception as e:
print(f"Error getting frame: {e}")
@ -448,6 +561,39 @@ class StreamDecoder:
with self._buffer_lock:
return len(self.frame_buffer)
def get_all_frames(self, rgb: bool = True) -> list:
"""
Get all frames currently in the buffer as CUDA tensors.
This drains the buffer and returns all frames.
Args:
rgb: If True, return RGB tensors. If False, not supported (returns empty list).
Returns:
List of torch.Tensor objects in CUDA memory
"""
if not rgb:
print("Warning: NV12 format not supported with FrameReference, only RGB")
return []
frames = []
with self._buffer_lock:
# Get all frames from buffer
for frame_ref in self.frame_buffer:
try:
# Get RGB tensor from FrameReference
frames.append(frame_ref.rgb_tensor)
except Exception as e:
print(f"Error getting frame: {e}")
continue
# Clear the buffer after reading all frames and free all references
for frame_ref in self.frame_buffer:
frame_ref.free()
self.frame_buffer.clear()
return frames
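A hedged usage sketch: drain whatever is buffered and stack it into a batch, assuming all frames come from this one stream and share the same resolution:

frames = decoder.get_all_frames(rgb=True)    # list of (3, H, W) uint8 GPU tensors; the buffer is cleared
if frames:
    batch = torch.stack(frames)              # (N, 3, H, W), still on the GPU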
def get_frame_count(self) -> int:
"""Get total number of frames decoded since start"""
return self.frame_count