event driven system

parent 0c5f56c8a6
commit 3a47920186

10 changed files with 782 additions and 253 deletions
@@ -1,17 +1,18 @@
 """
-ModelController - Async batching layer with ping-pong buffers for inference.
+ModelController - Event-driven batching layer with ping-pong buffers for inference.
 
 This module provides batched inference coordination using ping-pong circular buffers
-with force-switch timeout mechanism.
+with force-switch timeout mechanism using threading and callbacks.
 """
 
-import asyncio
+import threading
 import torch
 from typing import Dict, List, Optional, Callable, Any
 from dataclasses import dataclass, field
 from enum import Enum
 import time
 import logging
+import queue
 
 logger = logging.getLogger(__name__)
 
@@ -43,7 +44,7 @@ class ModelController:
     Features:
     - Ping-pong circular buffers (BufferA/BufferB)
     - Force-switch timeout to prevent batch starvation
-    - Async event-driven processing
+    - Event-driven processing with callbacks
    - Thread-safe frame submission
 
     Args:
@@ -90,14 +91,15 @@ class ModelController:
         self.buffer_a_state = BufferState.IDLE
         self.buffer_b_state = BufferState.IDLE
 
-        # Async coordination
-        self.buffer_lock = asyncio.Lock()
+        # Threading coordination
+        self.buffer_lock = threading.RLock()
         self.last_submit_time = time.time()
 
-        # Tasks
-        self.timeout_task: Optional[asyncio.Task] = None
-        self.processor_task: Optional[asyncio.Task] = None
+        # Threads
+        self.timeout_thread: Optional[threading.Thread] = None
+        self.processor_threads: Dict[str, threading.Thread] = {}
         self.running = False
+        self.stop_event = threading.Event()
 
         # Result callbacks (stream_id -> callback)
         self.result_callbacks: Dict[str, Callable] = {}
@@ -130,42 +132,46 @@ class ModelController:
             logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
             return 1
 
-    async def start(self):
-        """Start the controller background tasks"""
+    def start(self):
+        """Start the controller background threads"""
         if self.running:
             logger.warning("ModelController already running")
             return
 
         self.running = True
-        self.timeout_task = asyncio.create_task(self._timeout_monitor())
-        self.processor_task = asyncio.create_task(self._batch_processor())
+        self.stop_event.clear()
+
+        # Start timeout monitor thread
+        self.timeout_thread = threading.Thread(target=self._timeout_monitor, daemon=True)
+        self.timeout_thread.start()
+
+        # Start processor threads for each buffer
+        self.processor_threads['A'] = threading.Thread(target=self._batch_processor, args=('A',), daemon=True)
+        self.processor_threads['B'] = threading.Thread(target=self._batch_processor, args=('B',), daemon=True)
+        self.processor_threads['A'].start()
+        self.processor_threads['B'].start()
+
         logger.info("ModelController started")
 
-    async def stop(self):
+    def stop(self):
         """Stop the controller and cleanup"""
         if not self.running:
             return
 
         logger.info("Stopping ModelController...")
         self.running = False
+        self.stop_event.set()
 
-        # Cancel tasks
-        if self.timeout_task:
-            self.timeout_task.cancel()
-            try:
-                await self.timeout_task
-            except asyncio.CancelledError:
-                pass
+        # Wait for threads to finish
+        if self.timeout_thread and self.timeout_thread.is_alive():
+            self.timeout_thread.join(timeout=2.0)
 
-        if self.processor_task:
-            self.processor_task.cancel()
-            try:
-                await self.processor_task
-            except asyncio.CancelledError:
-                pass
+        for thread in self.processor_threads.values():
+            if thread and thread.is_alive():
+                thread.join(timeout=2.0)
 
         # Process any remaining frames
-        await self._process_remaining_buffers()
+        self._process_remaining_buffers()
         logger.info("ModelController stopped")
 
     def register_callback(self, stream_id: str, callback: Callable):
@@ -189,7 +195,7 @@ class ModelController:
         self.result_callbacks.pop(stream_id, None)
         logger.debug(f"Unregistered callback for stream: {stream_id}")
 
-    async def submit_frame(
+    def submit_frame(
         self,
         stream_id: str,
         frame: torch.Tensor,
@@ -203,7 +209,7 @@ class ModelController:
             frame: GPU tensor (3, H, W) or (C, H, W)
             metadata: Optional metadata to attach to the frame
         """
-        async with self.buffer_lock:
+        with self.buffer_lock:
             batch_frame = BatchFrame(
                 stream_id=stream_id,
                 frame=frame,
@@ -225,23 +231,21 @@ class ModelController:
 
             # Check if we should immediately swap (batch full)
             if buffer_size >= self.batch_size:
-                await self._try_swap_buffers()
+                self._try_swap_buffers()
 
-    async def _timeout_monitor(self):
+    def _timeout_monitor(self):
         """Monitor force-switch timeout"""
-        while self.running:
-            await asyncio.sleep(0.01)  # Check every 10ms
-
-            async with self.buffer_lock:
+        while self.running and not self.stop_event.wait(0.01):  # Check every 10ms
+            with self.buffer_lock:
                 time_since_submit = time.time() - self.last_submit_time
 
                 # Check if timeout expired and we have frames waiting
                 if time_since_submit >= self.force_timeout:
                     active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
                     if len(active_buffer) > 0:
-                        await self._try_swap_buffers()
+                        self._try_swap_buffers()
 
-    async def _try_swap_buffers(self):
+    def _try_swap_buffers(self):
         """
         Attempt to swap ping-pong buffers.
         Only swaps if the inactive buffer is not currently processing.
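Note: swapping `await asyncio.sleep(0.01)` for `stop_event.wait(0.01)` does double duty: the wait paces the loop, and it returns immediately once the event is set, so stop() no longer has to cancel a task to wake the monitor. A minimal sketch of the pattern (illustrative names, not code from this commit):

    import threading

    stop_event = threading.Event()

    def monitor(period: float = 0.01):
        # Event.wait(period) sleeps up to `period` seconds, but returns True
        # right away once stop_event.set() is called, ending the loop.
        while not stop_event.wait(period):
            pass  # periodic check goes here

    t = threading.Thread(target=monitor, daemon=True)
    t.start()
    stop_event.set()  # wakes the waiting thread immediately
    t.join()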
@@ -266,20 +270,22 @@ class ModelController:
 
             logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")
 
-    async def _batch_processor(self):
-        """Background task that processes batches when available"""
-        while self.running:
-            await asyncio.sleep(0.001)  # Check every 1ms
+    def _batch_processor(self, buffer_name: str):
+        """Background thread that processes a specific buffer when available"""
+        while self.running and not self.stop_event.is_set():
+            time.sleep(0.001)  # Check every 1ms
 
-            # Check if buffer A needs processing
-            if self.buffer_a_state == BufferState.PROCESSING:
-                await self._process_buffer("A")
+            # Check if this buffer needs processing
+            with self.buffer_lock:
+                if buffer_name == "A":
+                    should_process = self.buffer_a_state == BufferState.PROCESSING
+                else:
+                    should_process = self.buffer_b_state == BufferState.PROCESSING
 
-            # Check if buffer B needs processing
-            if self.buffer_b_state == BufferState.PROCESSING:
-                await self._process_buffer("B")
+            if should_process:
+                self._process_buffer(buffer_name)
 
-    async def _process_buffer(self, buffer_name: str):
+    def _process_buffer(self, buffer_name: str):
         """
         Process a buffer through inference.
 
@@ -287,7 +293,7 @@ class ModelController:
             buffer_name: "A" or "B"
         """
         # Extract buffer to process
-        async with self.buffer_lock:
+        with self.buffer_lock:
             if buffer_name == "A":
                 batch = self.buffer_a.copy()
                 self.buffer_a.clear()
@@ -297,7 +303,7 @@ class ModelController:
 
         if len(batch) == 0:
             # Mark as idle and return
-            async with self.buffer_lock:
+            with self.buffer_lock:
                 if buffer_name == "A":
                     self.buffer_a_state = BufferState.IDLE
                 else:
@@ -307,7 +313,7 @@ class ModelController:
         # Process batch (outside lock to allow concurrent submissions)
         try:
             start_time = time.time()
-            results = await self._run_batch_inference(batch)
+            results = self._run_batch_inference(batch)
             inference_time = time.time() - start_time
 
             # Update statistics
@@ -323,27 +329,24 @@ class ModelController:
             for batch_frame, result in zip(batch, results):
                 callback = self.result_callbacks.get(batch_frame.stream_id)
                 if callback:
-                    # Schedule callback asynchronously
-                    if asyncio.iscoroutinefunction(callback):
-                        asyncio.create_task(callback(result))
-                    else:
-                        # Run sync callback in executor to avoid blocking
-                        loop = asyncio.get_event_loop()
-                        loop.call_soon(lambda cb=callback, r=result: cb(r))
+                    # Call callback directly (synchronous)
+                    try:
+                        callback(result)
+                    except Exception as e:
+                        logger.error(f"Error in callback for {batch_frame.stream_id}: {e}", exc_info=True)
 
         except Exception as e:
             logger.error(f"Error processing batch: {e}", exc_info=True)
             # TODO: Emit error events to streams
 
         finally:
             # Mark buffer as idle
-            async with self.buffer_lock:
+            with self.buffer_lock:
                 if buffer_name == "A":
                     self.buffer_a_state = BufferState.IDLE
                 else:
                     self.buffer_b_state = BufferState.IDLE
 
-    async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
+    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
         """
         Run inference on a batch of frames.
 
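Note: result callbacks are now invoked synchronously on the buffer's processor thread, so a slow callback holds up the next batch. One way to keep them cheap (an illustrative pattern, not part of this commit) is to hand results off to a queue and do the heavy work on a separate consumer thread:

    import queue
    import threading

    results_q = queue.Queue()

    def on_result(result):
        # Runs on the processor thread: enqueue and return immediately.
        results_q.put(result)

    def consumer():
        while True:
            result = results_q.get()
            if result is None:  # sentinel used to stop the consumer
                break
            # heavy post-processing happens here, off the inference path

    threading.Thread(target=consumer, daemon=True).start()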
@@ -353,17 +356,15 @@ class ModelController:
         Returns:
             List of detection results (one per frame)
         """
-        loop = asyncio.get_event_loop()
-
         # Check if model supports batching
         if self.model_batch_size == 1:
             # Process frames one at a time for batch_size=1 models
-            return await self._run_sequential_inference(batch, loop)
+            return self._run_sequential_inference(batch)
         else:
             # Use true batching for models that support it
-            return await self._run_batched_inference(batch, loop)
+            return self._run_batched_inference(batch)
 
-    async def _run_sequential_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
+    def _run_sequential_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
         """Run inference sequentially for batch_size=1 models"""
         results = []
 
@@ -376,13 +377,10 @@ class ModelController:
             processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame
 
             # Run inference for this frame
-            outputs = await loop.run_in_executor(
-                None,
-                lambda p=processed: self.model_repository.infer(
-                    self.model_id,
-                    {"images": p},
-                    synchronize=True
-                )
-            )
+            outputs = self.model_repository.infer(
+                self.model_id,
+                {"images": processed},
+                synchronize=True
+            )
 
             # Postprocess
@@ -406,7 +404,7 @@ class ModelController:
 
         return results
 
-    async def _run_batched_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
+    def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
         """Run true batched inference for models that support it"""
         # Preprocess frames (on GPU)
         preprocessed = []
@@ -434,13 +432,10 @@ class ModelController:
             batch = batch[:self.model_batch_size]
 
         # Run inference
-        outputs = await loop.run_in_executor(
-            None,
-            lambda: self.model_repository.infer(
-                self.model_id,
-                {"images": batch_tensor},
-                synchronize=True
-            )
-        )
+        outputs = self.model_repository.infer(
+            self.model_id,
+            {"images": batch_tensor},
+            synchronize=True
+        )
 
         # Postprocess results (split batch back to individual results)
@@ -472,14 +467,14 @@ class ModelController:
 
         return results
 
-    async def _process_remaining_buffers(self):
+    def _process_remaining_buffers(self):
         """Process any remaining frames in buffers during shutdown"""
         if len(self.buffer_a) > 0:
             logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
-            await self._process_buffer("A")
+            self._process_buffer("A")
         if len(self.buffer_b) > 0:
             logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
-            await self._process_buffer("B")
+            self._process_buffer("B")
 
     def get_stats(self) -> Dict[str, Any]:
         """Get current buffer statistics"""
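Note: with these changes the controller is driven by plain method calls and synchronous callbacks. A rough usage sketch (constructor arguments are abbreviated and partly assumed; see the class definition for the real signature):

    import torch

    controller = ModelController(
        model_repository=repo,        # assumed: an already-loaded model repository
        model_id="detector",
        preprocess_fn=my_preprocess,  # assumed helper functions
        postprocess_fn=my_postprocess,
    )
    controller.start()  # spawns the timeout monitor and per-buffer processor threads

    def on_result(result):
        # Called on a processor thread; keep this lightweight.
        print(result["detections"])

    controller.register_callback("cam-1", on_result)

    frame = torch.zeros(3, 720, 1280, dtype=torch.uint8, device="cuda")
    controller.submit_frame(stream_id="cam-1", frame=frame, metadata={"frame_number": 1})

    controller.stop()  # joins threads and flushes any remaining buffered frames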
@@ -1,16 +1,17 @@
 """
-StreamConnectionManager - Async orchestration for stream processing with batched inference.
+StreamConnectionManager - Event-driven orchestration for stream processing with batched inference.
 
 This module provides high-level connection management for multiple RTSP streams,
-coordinating decoders, batched inference, and tracking with an event-driven API.
+coordinating decoders, batched inference, and tracking with callbacks and threading.
 """
 
-import asyncio
+import threading
 import time
-from typing import Dict, Optional, Callable, AsyncIterator, Tuple, Any, List
+from typing import Dict, Optional, Callable, Tuple, Any, List
 from dataclasses import dataclass
 from enum import Enum
 import logging
+import queue
 
 import torch
 
@@ -44,7 +45,7 @@ class StreamConnection:
     """
     Represents a single stream connection with event emission.
 
-    This class wraps a StreamDecoder, polls frames asynchronously, submits them
+    This class wraps a StreamDecoder, polls frames in a thread, submits them
     to the ModelController for batched inference, runs tracking, and emits results
     via queues or callbacks.
 
@@ -75,15 +76,19 @@ class StreamConnection:
         self.last_frame_time = 0.0
 
         # Event emission
-        self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue()
-        self.error_queue: asyncio.Queue[Exception] = asyncio.Queue()
+        self.result_queue: queue.Queue[TrackingResult] = queue.Queue()
+        self.error_queue: queue.Queue[Exception] = queue.Queue()
 
-        # Tasks
-        self.poller_task: Optional[asyncio.Task] = None
+        # Event-driven state
+        self.running = False
 
-    async def start(self):
-        """Start the connection (decoder and frame polling)"""
+    def start(self):
+        """Start the connection (decoder with frame callback)"""
+        self.running = True
+
+        # Register callback for frame events from decoder
+        self.decoder.register_frame_callback(self._on_frame_decoded)
 
         # Start decoder (runs in background thread)
         self.decoder.start()
 
@@ -93,7 +98,7 @@ class StreamConnection:
         elapsed = 0.0
 
         while elapsed < max_wait:
-            await asyncio.sleep(wait_interval)
+            time.sleep(wait_interval)
             elapsed += wait_interval
 
             if self.decoder.is_connected():
@@ -105,21 +110,13 @@ class StreamConnection:
             logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...")
             self.status = ConnectionStatus.CONNECTING
 
-        # Start frame polling task
-        self.running = True
-        self.poller_task = asyncio.create_task(self._frame_poller())
-
-    async def stop(self):
+    def stop(self):
         """Stop the connection and cleanup"""
         logger.info(f"Stopping stream {self.stream_id}...")
         self.running = False
 
-        if self.poller_task:
-            self.poller_task.cancel()
-            try:
-                await self.poller_task
-            except asyncio.CancelledError:
-                pass
+        # Unregister frame callback
+        self.decoder.unregister_frame_callback(self._on_frame_decoded)
 
         # Stop decoder
         self.decoder.stop()
@@ -130,55 +127,45 @@ class StreamConnection:
         self.status = ConnectionStatus.DISCONNECTED
         logger.info(f"Stream {self.stream_id} stopped")
 
-    async def _frame_poller(self):
-        """Poll frames from threaded decoder and submit to model controller"""
-        last_decoder_frame_count = -1
-
-        while self.running:
-            try:
-                # Get current decoder frame count (no data transfer, just counter)
-                decoder_frame_count = self.decoder.get_frame_count()
-
-                # Check if decoder has a new frame (avoid reprocessing same frame)
-                if decoder_frame_count > last_decoder_frame_count:
-                    # Poll frame from decoder (zero-copy - stays in VRAM)
-                    frame = self.decoder.get_latest_frame(rgb=True)
-
-                    if frame is not None:
-                        last_decoder_frame_count = decoder_frame_count
-                        self.last_frame_time = time.time()
-                        self.frame_count += 1
-
-                        # Submit to model controller for batched inference
-                        await self.model_controller.submit_frame(
-                            stream_id=self.stream_id,
-                            frame=frame,
-                            metadata={
-                                "frame_number": self.frame_count,
-                                "shape": tuple(frame.shape),
-                            }
-                        )
-
-                # Check decoder status
-                if not self.decoder.is_connected():
-                    if self.status == ConnectionStatus.CONNECTED:
-                        logger.warning(f"Stream {self.stream_id} disconnected")
-                        self.status = ConnectionStatus.DISCONNECTED
-                    # Decoder will auto-reconnect, just update status
-                    await asyncio.sleep(1.0)
-                    if self.decoder.is_connected():
-                        logger.info(f"Stream {self.stream_id} reconnected")
-                        self.status = ConnectionStatus.CONNECTED
-
-            except Exception as e:
-                logger.error(f"Error in frame poller for {self.stream_id}: {e}", exc_info=True)
-                await self.error_queue.put(e)
-                self.status = ConnectionStatus.ERROR
-
-            # Sleep until next poll
-            await asyncio.sleep(self.poll_interval)
+    def _on_frame_decoded(self, frame: torch.Tensor):
+        """
+        Event handler called by decoder when a new frame is decoded.
+        This is the event-driven replacement for polling.
+
+        Args:
+            frame: RGB frame tensor on GPU (3, H, W)
+        """
+        if not self.running:
+            return
+
+        try:
+            self.last_frame_time = time.time()
+            self.frame_count += 1
+
+            # Submit to model controller for batched inference
+            self.model_controller.submit_frame(
+                stream_id=self.stream_id,
+                frame=frame,
+                metadata={
+                    "frame_number": self.frame_count,
+                    "shape": tuple(frame.shape),
+                }
+            )
+
+            # Update connection status based on decoder status
+            if self.decoder.is_connected() and self.status != ConnectionStatus.CONNECTED:
+                logger.info(f"Stream {self.stream_id} reconnected")
+                self.status = ConnectionStatus.CONNECTED
+            elif not self.decoder.is_connected() and self.status == ConnectionStatus.CONNECTED:
+                logger.warning(f"Stream {self.stream_id} disconnected")
+                self.status = ConnectionStatus.DISCONNECTED
+
+        except Exception as e:
+            logger.error(f"Error processing frame for {self.stream_id}: {e}", exc_info=True)
+            self.error_queue.put(e)
+            self.status = ConnectionStatus.ERROR
 
-    async def _handle_inference_result(self, result: Dict[str, Any]):
+    def _handle_inference_result(self, result: Dict[str, Any]):
         """
         Callback invoked by ModelController when inference is done.
         Runs tracking and emits final result.
@@ -190,12 +177,8 @@ class StreamConnection:
             # Extract detections
             detections = result["detections"]
 
-            # Run tracking (this is sync, so run in executor)
-            loop = asyncio.get_event_loop()
-            tracked_objects = await loop.run_in_executor(
-                None,
-                lambda: self._run_tracking_sync(detections)
-            )
+            # Run tracking (synchronous)
+            tracked_objects = self._run_tracking_sync(detections)
 
             # Create tracking result
             tracking_result = TrackingResult(
@@ -208,11 +191,11 @@ class StreamConnection:
             )
 
             # Emit to result queue
-            await self.result_queue.put(tracking_result)
+            self.result_queue.put(tracking_result)
 
         except Exception as e:
             logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
-            await self.error_queue.put(e)
+            self.error_queue.put(e)
 
     def _run_tracking_sync(self, detections, min_confidence=0.7):
         """
@@ -246,12 +229,12 @@ class StreamConnection:
         # Update tracker with detections (lightweight, no model dependency!)
         return self.tracking_controller.update(detection_list)
 
-    async def tracking_results(self) -> AsyncIterator[TrackingResult]:
+    def tracking_results(self):
         """
-        Async generator for tracking results.
+        Generator for tracking results (blocking iterator).
 
         Usage:
-            async for result in connection.tracking_results():
+            for result in connection.tracking_results():
                 print(result.tracked_objects)
 
         Yields:
@@ -259,23 +242,23 @@ class StreamConnection:
         """
         while self.running or not self.result_queue.empty():
             try:
-                result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0)
+                result = self.result_queue.get(timeout=1.0)
                 yield result
-            except asyncio.TimeoutError:
+            except queue.Empty:
                 continue
 
-    async def errors(self) -> AsyncIterator[Exception]:
+    def errors(self):
         """
-        Async generator for errors.
+        Generator for errors (blocking iterator).
 
         Yields:
             Exception objects as they occur
         """
         while self.running or not self.error_queue.empty():
             try:
-                error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0)
+                error = self.error_queue.get(timeout=1.0)
                 yield error
-            except asyncio.TimeoutError:
+            except queue.Empty:
                 continue
 
     def get_stats(self) -> Dict[str, Any]:
@@ -342,7 +325,7 @@ class StreamConnectionManager:
         # State
         self.initialized = False
 
-    async def initialize(
+    def initialize(
         self,
         model_path: str,
         model_id: str = "detector",
@@ -368,18 +351,14 @@ class StreamConnectionManager:
         """
         logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
 
-        # Load model
-        loop = asyncio.get_event_loop()
-        await loop.run_in_executor(
-            None,
-            lambda: self.model_repository.load_model(
-                model_id,
-                model_path,
-                num_contexts=num_contexts,
-                pt_input_shapes=pt_input_shapes,
-                pt_precision=pt_precision,
-                **pt_conversion_kwargs
-            )
-        )
+        # Load model (synchronous)
+        self.model_repository.load_model(
+            model_id,
+            model_path,
+            num_contexts=num_contexts,
+            pt_input_shapes=pt_input_shapes,
+            pt_precision=pt_precision,
+            **pt_conversion_kwargs
+        )
         logger.info(f"Loaded model {model_id} from {model_path}")
 
@@ -392,7 +371,7 @@ class StreamConnectionManager:
             preprocess_fn=preprocess_fn,
             postprocess_fn=postprocess_fn,
         )
-        await self.model_controller.start()
+        self.model_controller.start()
 
         # Don't create a shared tracking controller here
         # Each stream will get its own tracking controller to avoid track accumulation
@@ -402,7 +381,7 @@ class StreamConnectionManager:
         self.initialized = True
         logger.info("StreamConnectionManager initialized successfully")
 
-    async def connect_stream(
+    def connect_stream(
         self,
         rtsp_url: str,
         stream_id: Optional[str] = None,
@@ -416,8 +395,8 @@ class StreamConnectionManager:
         Args:
             rtsp_url: RTSP stream URL
             stream_id: Optional stream identifier (auto-generated if not provided)
-            on_tracking_result: Optional callback for tracking results (sync or async)
-            on_error: Optional callback for errors (sync or async)
+            on_tracking_result: Optional callback for tracking results (synchronous)
+            on_error: Optional callback for errors (synchronous)
             buffer_size: Decoder buffer size (default: 30)
 
         Returns:
@@ -466,22 +445,30 @@ class StreamConnectionManager:
         )
 
         # Start connection
-        await connection.start()
+        connection.start()
 
         # Store connection
         self.connections[stream_id] = connection
 
-        # Set up user callbacks if provided
+        # Set up user callbacks if provided (run in separate threads)
         if on_tracking_result:
-            asyncio.create_task(self._forward_results(connection, on_tracking_result))
+            threading.Thread(
+                target=self._forward_results,
+                args=(connection, on_tracking_result),
+                daemon=True
+            ).start()
 
         if on_error:
-            asyncio.create_task(self._forward_errors(connection, on_error))
+            threading.Thread(
+                target=self._forward_errors,
+                args=(connection, on_error),
+                daemon=True
+            ).start()
 
         logger.info(f"Stream {stream_id} connected successfully")
         return connection
 
-    async def disconnect_stream(self, stream_id: str):
+    def disconnect_stream(self, stream_id: str):
         """
         Disconnect and cleanup a stream.
 
@@ -490,27 +477,27 @@ class StreamConnectionManager:
         """
         connection = self.connections.get(stream_id)
         if connection:
-            await connection.stop()
+            connection.stop()
             del self.connections[stream_id]
             logger.info(f"Stream {stream_id} disconnected")
 
-    async def disconnect_all(self):
+    def disconnect_all(self):
         """Disconnect all streams"""
         logger.info("Disconnecting all streams...")
         stream_ids = list(self.connections.keys())
         for stream_id in stream_ids:
-            await self.disconnect_stream(stream_id)
+            self.disconnect_stream(stream_id)
 
-    async def shutdown(self):
+    def shutdown(self):
         """Shutdown the manager and cleanup all resources"""
         logger.info("Shutting down StreamConnectionManager...")
 
         # Disconnect all streams
-        await self.disconnect_all()
+        self.disconnect_all()
 
         # Stop model controller
         if self.model_controller:
-            await self.model_controller.stop()
+            self.model_controller.stop()
 
         # Note: Model repository cleanup is sync and may cause segfaults
         # Leaving cleanup to garbage collection for now
@@ -518,37 +505,31 @@ class StreamConnectionManager:
         self.initialized = False
         logger.info("StreamConnectionManager shutdown complete")
 
-    async def _forward_results(self, connection: StreamConnection, callback: Callable):
+    def _forward_results(self, connection: StreamConnection, callback: Callable):
         """
         Forward results from connection to user callback.
 
         Args:
             connection: StreamConnection to listen to
-            callback: User callback (sync or async)
+            callback: User callback (synchronous)
         """
         try:
-            async for result in connection.tracking_results():
-                if asyncio.iscoroutinefunction(callback):
-                    await callback(result)
-                else:
-                    callback(result)
+            for result in connection.tracking_results():
+                callback(result)
         except Exception as e:
             logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)
 
-    async def _forward_errors(self, connection: StreamConnection, callback: Callable):
+    def _forward_errors(self, connection: StreamConnection, callback: Callable):
         """
         Forward errors from connection to user callback.
 
         Args:
             connection: StreamConnection to listen to
-            callback: User callback (sync or async)
+            callback: User callback (synchronous)
         """
         try:
-            async for error in connection.errors():
-                if asyncio.iscoroutinefunction(callback):
-                    await callback(error)
-                else:
-                    callback(error)
+            for error in connection.errors():
+                callback(error)
         except Exception as e:
             logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)
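Note: end to end, the manager can now be used from ordinary synchronous code. A hedged sketch of the new calling convention (the constructor and some arguments below are assumptions based on the parameters visible in this diff):

    manager = StreamConnectionManager(gpu_id=0)  # gpu_id assumed from self.gpu_id

    manager.initialize(
        model_path="model.pt",        # placeholder path
        model_id="detector",
        preprocess_fn=my_preprocess,  # assumed helpers
        postprocess_fn=my_postprocess,
    )

    def on_tracking_result(result):
        print(result.tracked_objects)

    connection = manager.connect_stream(
        rtsp_url="rtsp://example/stream",  # placeholder URL
        on_tracking_result=on_tracking_result,
    )

    # ...or consume results directly with the blocking generator:
    # for result in connection.tracking_results():
    #     print(result.tracked_objects)

    manager.shutdown()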
@@ -1,5 +1,5 @@
 import threading
-from typing import Optional
+from typing import Optional, Callable
 from collections import deque
 from enum import Enum
 import torch
@@ -10,6 +10,35 @@ from cuda.bindings import driver as cuda_driver
 from .jpeg_encoder import encode_frame_to_jpeg
 
 
+class FrameReference:
+    """
+    CPU-side reference object for a GPU frame.
+
+    This object holds a cloned RGB tensor that is independent of PyNvVideoCodec's
+    DecodedFrame lifecycle. We don't keep the DecodedFrame to avoid conflicts
+    with PyNvVideoCodec's internal frame pool management.
+    """
+    def __init__(self, rgb_tensor: torch.Tensor, buffer_index: int, decoder):
+        self.rgb_tensor = rgb_tensor  # Cloned RGB tensor (independent copy)
+        self.buffer_index = buffer_index
+        self.decoder = decoder  # Reference to decoder for marking as free
+        self._freed = False
+
+    def free(self):
+        """Mark this frame as no longer in use"""
+        if not self._freed:
+            self._freed = True
+            self.decoder._mark_frame_free(self.buffer_index)
+
+    def is_freed(self) -> bool:
+        """Check if this frame has been freed"""
+        return self._freed
+
+    def __del__(self):
+        """Auto-free on garbage collection"""
+        self.free()
+
+
 def nv12_to_rgb_gpu(nv12_tensor: torch.Tensor, height: int, width: int) -> torch.Tensor:
     """
     Convert NV12 format to RGB on GPU using PyTorch operations.
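Note: FrameReference acts as a manual handle around the cloned GPU tensor: consumers read rgb_tensor and call free() (or rely on __del__) so the decoder can drop the slot from its in-use list. An illustrative lifecycle, assuming frame_ref came from the decoder's buffer:

    tensor = frame_ref.rgb_tensor  # (3, H, W) uint8 CUDA tensor, independent clone
    run_inference(tensor)          # hypothetical downstream work

    frame_ref.free()               # tell the decoder this slot is no longer in use
    assert frame_ref.is_freed()
    frame_ref.free()               # safe: the _freed flag makes a second call a no-op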
@@ -183,10 +212,13 @@ class StreamDecoder:
         self.status = ConnectionStatus.DISCONNECTED
         self._status_lock = threading.Lock()
 
-        # Frame buffer (ring buffer) - stores CUDA device pointers
+        # Frame buffer (ring buffer) - stores FrameReference objects
         self.frame_buffer = deque(maxlen=buffer_size)
         self._buffer_lock = threading.RLock()
 
+        # Track which buffer slots are in use (list of FrameReference objects)
+        self._in_use_frames = []  # List of FrameReference objects currently held by callbacks
+
         # Decoder and container instances
         self.decoder = None
         self.container = None
@@ -200,6 +232,45 @@ class StreamDecoder:
         self.frame_height: Optional[int] = None
         self.frame_count: int = 0
 
+        # Frame callbacks - event-driven notification
+        self._frame_callbacks = []
+        self._callback_lock = threading.Lock()
+
+    def register_frame_callback(self, callback: Callable):
+        """
+        Register a callback to be called when a new frame is decoded.
+
+        The callback will be called with the decoded frame tensor (GPU) as argument.
+        Callback signature: callback(frame: torch.Tensor) -> None
+
+        Args:
+            callback: Function to call when new frame arrives
+        """
+        with self._callback_lock:
+            self._frame_callbacks.append(callback)
+
+    def unregister_frame_callback(self, callback: Callable):
+        """
+        Unregister a frame callback.
+
+        Args:
+            callback: The callback function to remove
+        """
+        with self._callback_lock:
+            if callback in self._frame_callbacks:
+                self._frame_callbacks.remove(callback)
+
+    def _mark_frame_free(self, buffer_index: int):
+        """
+        Mark a frame as freed (called by FrameReference when it's no longer in use).
+
+        Args:
+            buffer_index: Index in the buffer for tracking purposes
+        """
+        with self._buffer_lock:
+            # Remove from in-use tracking
+            self._in_use_frames = [f for f in self._in_use_frames if f.buffer_index != buffer_index]
+
     def start(self):
         """Start the RTSP stream decoding in background thread"""
         if self._decode_thread is not None and self._decode_thread.is_alive():
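Note: with these hooks a consumer subscribes to frames instead of polling. A short sketch (the StreamDecoder constructor arguments are assumed; the callback signature matches the docstring above):

    import time
    import torch

    def on_frame(frame: torch.Tensor):
        # Called from the decode thread with an RGB (3, H, W) CUDA tensor.
        print("got frame", tuple(frame.shape))

    decoder = StreamDecoder("rtsp://example/stream", buffer_size=30)  # args assumed
    decoder.register_frame_callback(on_frame)
    decoder.start()

    time.sleep(30)  # frames arrive via on_frame while the decode thread runs

    decoder.unregister_frame_callback(on_frame)
    decoder.stop()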
@@ -278,6 +349,9 @@ class StreamDecoder:
 
     def _decode_loop(self):
         """Main decode loop running in background thread"""
+        # Set the CUDA device for this thread
+        torch.cuda.set_device(self.gpu_id)
+
         retry_count = 0
         max_retries = 5
 
|
@ -319,11 +393,60 @@ class StreamDecoder:
|
|||
if not decoded_frames:
|
||||
continue
|
||||
|
||||
# Add frames to ring buffer (thread-safe)
|
||||
# Add frames to ring buffer and fire callbacks
|
||||
with self._buffer_lock:
|
||||
for frame in decoded_frames:
|
||||
self.frame_buffer.append(frame)
|
||||
self.frame_count += 1
|
||||
# Check for buffer overflow - discard oldest if needed
|
||||
if len(self.frame_buffer) >= self.buffer_size:
|
||||
# Check if oldest frame is still in use
|
||||
if len(self._in_use_frames) > 0:
|
||||
oldest_ref = self.frame_buffer[0] if len(self.frame_buffer) > 0 else None
|
||||
if oldest_ref and not oldest_ref.is_freed():
|
||||
# Force free the oldest frame to prevent overflow
|
||||
print(f"[WARNING] Buffer overflow, force-freeing oldest frame (buffer_index={oldest_ref.buffer_index})")
|
||||
oldest_ref.free()
|
||||
|
||||
# Deque will automatically remove oldest when at maxlen
|
||||
|
||||
# Convert to tensor
|
||||
try:
|
||||
# Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
|
||||
nv12_tensor = torch.from_dlpack(frame)
|
||||
|
||||
# Convert NV12 to RGB on GPU
|
||||
if self.frame_height is not None and self.frame_width is not None:
|
||||
rgb_tensor = nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width)
|
||||
|
||||
# CRITICAL: Clone the RGB tensor to break CUDA memory dependency
|
||||
# The nv12_to_rgb_gpu creates a new tensor, but it still references
|
||||
# the same CUDA context/stream. We need an independent copy.
|
||||
rgb_tensor_cloned = rgb_tensor.clone()
|
||||
|
||||
# Create FrameReference object for C++-style memory management
|
||||
# We don't keep the DecodedFrame to avoid conflicts with PyNvVideoCodec's
|
||||
# internal frame pool - the clone is fully independent
|
||||
buffer_index = self.frame_count
|
||||
frame_ref = FrameReference(
|
||||
rgb_tensor=rgb_tensor_cloned, # Independent cloned tensor
|
||||
buffer_index=buffer_index,
|
||||
decoder=self
|
||||
)
|
||||
|
||||
# Add to buffer and in-use tracking
|
||||
self.frame_buffer.append(frame_ref)
|
||||
self._in_use_frames.append(frame_ref)
|
||||
self.frame_count += 1
|
||||
|
||||
# Fire callbacks with the cloned RGB tensor from FrameReference
|
||||
# The tensor is now independent of the DecodedFrame lifecycle
|
||||
with self._callback_lock:
|
||||
for callback in self._frame_callbacks:
|
||||
try:
|
||||
callback(frame_ref.rgb_tensor)
|
||||
except Exception as e:
|
||||
print(f"Error in frame callback: {e}")
|
||||
except Exception as e:
|
||||
print(f"Error converting frame for callback: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in decode loop for {self.rtsp_url}: {e}")
|
||||
|
|
@@ -351,35 +474,25 @@ class StreamDecoder:
 
         Args:
             index: Frame index in buffer (-1 for latest, -2 for second latest, etc.)
-            rgb: If True, convert NV12 to RGB. If False, return raw NV12 format.
+            rgb: If True, return RGB tensor. If False, not supported (returns None).
 
         Returns:
             torch.Tensor in CUDA memory (device tensor) or None if buffer empty
             - If rgb=True: Shape (3, H, W) in RGB format, dtype uint8
-            - If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8
+            - If rgb=False: Not supported with FrameReference (returns None)
         """
         with self._buffer_lock:
             if len(self.frame_buffer) == 0:
                 return None
 
+            if not rgb:
+                print("Warning: NV12 format not supported with FrameReference, only RGB")
+                return None
+
             try:
-                decoded_frame = self.frame_buffer[index]
-
-                # Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
-                # This keeps the data in GPU memory
-                nv12_tensor = torch.from_dlpack(decoded_frame)
-
-                if not rgb:
-                    # Return raw NV12 format
-                    return nv12_tensor
-
-                # Convert NV12 to RGB on GPU
-                if self.frame_height is None or self.frame_width is None:
-                    print("Frame dimensions not available")
-                    return None
-
-                rgb_tensor = nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width)
-                return rgb_tensor
+                frame_ref = self.frame_buffer[index]
+                # Return the RGB tensor from FrameReference (cloned, independent)
+                return frame_ref.rgb_tensor
 
             except (IndexError, Exception) as e:
                 print(f"Error getting frame: {e}")
@@ -448,6 +561,39 @@ class StreamDecoder:
         with self._buffer_lock:
             return len(self.frame_buffer)
 
+    def get_all_frames(self, rgb: bool = True) -> list:
+        """
+        Get all frames currently in the buffer as CUDA tensors.
+        This drains the buffer and returns all frames.
+
+        Args:
+            rgb: If True, return RGB tensors. If False, not supported (returns empty list).
+
+        Returns:
+            List of torch.Tensor objects in CUDA memory
+        """
+        if not rgb:
+            print("Warning: NV12 format not supported with FrameReference, only RGB")
+            return []
+
+        frames = []
+        with self._buffer_lock:
+            # Get all frames from buffer
+            for frame_ref in self.frame_buffer:
+                try:
+                    # Get RGB tensor from FrameReference
+                    frames.append(frame_ref.rgb_tensor)
+                except Exception as e:
+                    print(f"Error getting frame: {e}")
+                    continue
+
+            # Clear the buffer after reading all frames and free all references
+            for frame_ref in self.frame_buffer:
+                frame_ref.free()
+            self.frame_buffer.clear()
+
+        return frames
+
     def get_frame_count(self) -> int:
         """Get total number of frames decoded since start"""
         return self.frame_count