# Event-Driven Stream Processing Architecture with Batching

## Overview

This document describes the AsyncIO-based event-driven architecture for connecting stream decoders to models and tracking, with support for batched inference using ping-pong circular buffers.

## Architecture Diagram
```
┌─────────────────────────────────────────────────────────────────┐
│                     StreamConnectionManager                      │
│  - Manages multiple stream connections                           │
│  - Routes events to user callbacks/generators                    │
│  - Coordinates ModelController and TrackingController            │
└─────────────────────────────────────────────────────────────────┘
              │
              ├──────────────────┬──────────────────┐
              ▼                  ▼                  ▼
     ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
     │ StreamConnection│ │ StreamConnection│ │ StreamConnection│
     │   (Stream 1)    │ │   (Stream 2)    │ │   (Stream N)    │
     │                 │ │                 │ │                 │
     │ - StreamDecoder │ │ - StreamDecoder │ │ - StreamDecoder │
     │ - Frame Poller  │ │ - Frame Poller  │ │ - Frame Poller  │
     │ - Event Emitter │ │ - Event Emitter │ │ - Event Emitter │
     └─────────────────┘ └─────────────────┘ └─────────────────┘
              │                  │                  │
              └──────────────────┴──────────────────┘
                                 │
                                 ▼
              ┌─────────────────────────────────────┐
              │           ModelController           │
              │  ┌────────────┐   ┌────────────┐    │
              │  │  Buffer A  │   │  Buffer B  │    │
              │  │  (Active)  │   │(Processing)│    │
              │  │  [frame1]  │   │  [frame9]  │    │
              │  │  [frame2]  │   │ [frame10]  │    │
              │  │  [frame3]  │   │  [...]     │    │
              │  └────────────┘   └────────────┘    │
              │                                     │
              │  - Batch accumulation               │
              │  - Force timeout monitor            │
              │  - Ping-pong switching              │
              └─────────────────────────────────────┘
                                 │
                    ┌────────────┴────────────┐
                    ▼                         ▼
         ┌─────────────────────┐   ┌─────────────────────┐
         │  TensorRTModelRepo  │   │ TrackingController  │
         │ - Batched inference │   │ - Track association │
         │ - Context pooling   │   │ - Track management  │
         └─────────────────────┘   └─────────────────────┘
                                              │
                                              ▼
                                  ┌─────────────────────────┐
                                  │  User Callbacks/Queues  │
                                  │  - on_tracking_result   │
                                  │  - on_detections        │
                                  │  - on_error             │
                                  └─────────────────────────┘
```
## Component Details

### 1. ModelController (Async Batching Layer)

**Responsibilities:**
- Accumulate frames from multiple streams into batches
- Manage ping-pong buffers (BufferA/BufferB)
- Monitor force-switch timeout
- Execute batched inference
- Route results back to streams

**Ping-Pong Buffer Logic:**
- **BufferA (Active)**: Accumulates incoming frames
- **BufferB (Processing)**: Being processed through inference
- **Switch Triggers** (see the sketch after this list):
  1. Active buffer reaches `batch_size` → immediate swap
  2. `force_timeout` expires AND processing buffer is idle → force swap
  3. Never switch if processing buffer is busy
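The three triggers reduce to a single predicate. The following is a minimal sketch of that decision only, using the buffer/state names from the pseudocode later in this document; it is illustrative, not the full controller.

```python
import time
from enum import Enum


class BufferState(Enum):
    IDLE = "idle"
    FILLING = "filling"
    PROCESSING = "processing"


def should_swap(active_len: int, batch_size: int,
                last_submit_time: float, force_timeout: float,
                processing_state: BufferState) -> bool:
    """Decide whether the ping-pong buffers may swap right now."""
    # Trigger 3: never swap while the other buffer is still being processed.
    if processing_state == BufferState.PROCESSING:
        return False
    # Trigger 1: the active buffer is full.
    if active_len >= batch_size:
        return True
    # Trigger 2: the force timeout expired and frames are waiting.
    idle_for = time.time() - last_submit_time
    return active_len > 0 and idle_for >= force_timeout
```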
### 2. StreamConnectionManager

**Responsibilities:**
- Create and manage stream connections
- Coordinate ModelController and TrackingController
- Route tracking results to user callbacks/generators
- Handle stream lifecycle (connect, disconnect, errors)

### 3. StreamConnection

**Responsibilities:**
- Wrap a single StreamDecoder
- Poll frames from threaded decoder (bridge to async)
- Submit frames to ModelController
- Emit events to user code

---

## Pseudo Code Implementation

### 1. ModelController with Ping-Pong Buffers
```python
import asyncio
import time
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Dict, List, Optional

import torch


@dataclass
class BatchFrame:
    """Represents a frame in the batch buffer"""
    stream_id: str
    frame: torch.Tensor  # GPU tensor (3, H, W)
    timestamp: float
    metadata: Dict = None


class BufferState(Enum):
    IDLE = "idle"
    FILLING = "filling"
    PROCESSING = "processing"


class ModelController:
    """
    Manages batched inference with ping-pong buffers and a force-switch timeout.
    """

    def __init__(
        self,
        model_repository,
        model_id: str,
        batch_size: int = 16,
        force_timeout: float = 0.05,  # 50 ms
        preprocess_fn: Callable = None,
        postprocess_fn: Callable = None,
    ):
        self.model_repository = model_repository
        self.model_id = model_id
        self.batch_size = batch_size
        self.force_timeout = force_timeout
        self.preprocess_fn = preprocess_fn
        self.postprocess_fn = postprocess_fn

        # Ping-pong buffers
        self.buffer_a: List[BatchFrame] = []
        self.buffer_b: List[BatchFrame] = []

        # Buffer states
        self.active_buffer = "A"  # Which buffer is currently active (filling)
        self.buffer_a_state = BufferState.IDLE
        self.buffer_b_state = BufferState.IDLE

        # Async coordination
        self.buffer_lock = asyncio.Lock()
        self.last_submit_time = time.time()

        # Tasks
        self.timeout_task: Optional[asyncio.Task] = None
        self.processor_task: Optional[asyncio.Task] = None
        self.running = False

        # Result callbacks (stream_id -> callback)
        self.result_callbacks: Dict[str, Callable] = {}

    async def start(self):
        """Start the controller background tasks"""
        self.running = True
        self.timeout_task = asyncio.create_task(self._timeout_monitor())
        self.processor_task = asyncio.create_task(self._batch_processor())

    async def stop(self):
        """Stop the controller and clean up"""
        self.running = False
        if self.timeout_task:
            self.timeout_task.cancel()
        if self.processor_task:
            self.processor_task.cancel()

        # Process any remaining frames
        await self._process_remaining_buffers()

    def register_callback(self, stream_id: str, callback: Callable):
        """Register a callback for inference results from a stream"""
        self.result_callbacks[stream_id] = callback

    def unregister_callback(self, stream_id: str):
        """Unregister a stream callback"""
        self.result_callbacks.pop(stream_id, None)

    async def submit_frame(self, stream_id: str, frame: torch.Tensor, metadata: Dict = None):
        """
        Submit a frame for batched inference.

        Args:
            stream_id: Unique stream identifier
            frame: GPU tensor (3, H, W)
            metadata: Optional metadata to attach
        """
        async with self.buffer_lock:
            batch_frame = BatchFrame(
                stream_id=stream_id,
                frame=frame,
                timestamp=time.time(),
                metadata=metadata or {},
            )

            # Add to the active buffer
            if self.active_buffer == "A":
                self.buffer_a.append(batch_frame)
                self.buffer_a_state = BufferState.FILLING
                buffer_size = len(self.buffer_a)
            else:
                self.buffer_b.append(batch_frame)
                self.buffer_b_state = BufferState.FILLING
                buffer_size = len(self.buffer_b)

            self.last_submit_time = time.time()

            # Check if we should immediately swap (batch full)
            if buffer_size >= self.batch_size:
                await self._try_swap_buffers()

    async def _timeout_monitor(self):
        """Monitor the force-switch timeout"""
        while self.running:
            await asyncio.sleep(0.01)  # Check every 10 ms

            async with self.buffer_lock:
                time_since_submit = time.time() - self.last_submit_time

                # Check if the timeout expired and we have frames waiting
                if time_since_submit >= self.force_timeout:
                    active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
                    if len(active_buffer) > 0:
                        await self._try_swap_buffers()

    async def _try_swap_buffers(self):
        """
        Attempt to swap the ping-pong buffers.
        Only swaps if the inactive buffer is not currently processing.

        This method must be called with buffer_lock held.
        """
        # Check if the inactive buffer is available
        inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state

        if inactive_state != BufferState.PROCESSING:
            # Swap the active buffer
            old_active = self.active_buffer
            self.active_buffer = "B" if old_active == "A" else "A"

            # Mark the old active buffer as ready for processing
            if old_active == "A":
                self.buffer_a_state = BufferState.PROCESSING
            else:
                self.buffer_b_state = BufferState.PROCESSING

            # The processor task is already running and will pick up the work.

    async def _batch_processor(self):
        """Background task that processes batches when available"""
        while self.running:
            await asyncio.sleep(0.001)  # Check every 1 ms

            # Check if buffer A needs processing
            if self.buffer_a_state == BufferState.PROCESSING:
                await self._process_buffer("A")

            # Check if buffer B needs processing
            if self.buffer_b_state == BufferState.PROCESSING:
                await self._process_buffer("B")

    async def _process_buffer(self, buffer_name: str):
        """
        Process a buffer through inference.

        Args:
            buffer_name: "A" or "B"
        """
        async with self.buffer_lock:
            # Take ownership of the buffer contents
            if buffer_name == "A":
                batch = self.buffer_a.copy()
                self.buffer_a.clear()
            else:
                batch = self.buffer_b.copy()
                self.buffer_b.clear()

        if len(batch) == 0:
            # Nothing to do: mark the buffer idle and return
            async with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE
            return

        # Process the batch outside the lock to allow concurrent submissions
        try:
            results = await self._run_batch_inference(batch)

            # Emit results to the registered callbacks
            for batch_frame, result in zip(batch, results):
                callback = self.result_callbacks.get(batch_frame.stream_id)
                if callback:
                    # Schedule the callback asynchronously
                    if asyncio.iscoroutinefunction(callback):
                        asyncio.create_task(callback(result))
                    else:
                        callback(result)

        except Exception as e:
            print(f"Error processing batch: {e}")
            # TODO: Emit error events

        finally:
            # Mark the buffer as idle
            async with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE

    async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict]:
        """
        Run inference on a batch of frames.

        Args:
            batch: List of BatchFrame objects

        Returns:
            List of detection results (one per frame)
        """
        # Preprocess frames (on GPU)
        preprocessed = []
        for batch_frame in batch:
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
            else:
                processed = batch_frame.frame
            preprocessed.append(processed)

        # Stack into a batch tensor: (N, C, H, W)
        batch_tensor = torch.stack(preprocessed, dim=0)

        # Run inference (the TensorRT model repository is sync, so run in an executor)
        loop = asyncio.get_running_loop()
        outputs = await loop.run_in_executor(
            None,
            lambda: self.model_repository.infer(
                self.model_id,
                {"images": batch_tensor},
                synchronize=True,
            ),
        )

        # Postprocess results (split the batch back into per-frame results)
        results = []
        for i, batch_frame in enumerate(batch):
            # Extract the single-frame output from the batch
            frame_output = {k: v[i:i + 1] for k, v in outputs.items()}

            if self.postprocess_fn:
                detections = self.postprocess_fn(frame_output)
            else:
                detections = frame_output

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results

    async def _process_remaining_buffers(self):
        """Process any remaining frames in the buffers during shutdown"""
        if len(self.buffer_a) > 0:
            await self._process_buffer("A")
        if len(self.buffer_b) > 0:
            await self._process_buffer("B")

    def get_stats(self) -> Dict:
        """Get current buffer statistics"""
        return {
            "active_buffer": self.active_buffer,
            "buffer_a_size": len(self.buffer_a),
            "buffer_b_size": len(self.buffer_b),
            "buffer_a_state": self.buffer_a_state.value,
            "buffer_b_state": self.buffer_b_state.value,
            "registered_streams": len(self.result_callbacks),
        }
```
### 2. StreamConnectionManager

```python
import asyncio
import time
from dataclasses import dataclass
from enum import Enum
from typing import AsyncIterator, Callable, Dict, List, Optional, Tuple

# ModelController lives in services/model_controller.py per the "Next Steps" plan below
from services.model_controller import ModelController


class ConnectionStatus(Enum):
    CONNECTING = "connecting"
    CONNECTED = "connected"
    DISCONNECTED = "disconnected"
    ERROR = "error"


@dataclass
class TrackingResult:
    """Result emitted to user callbacks"""
    stream_id: str
    timestamp: float
    tracked_objects: List  # List of TrackedObject from TrackingController
    detections: List  # Raw detections
    frame_shape: Tuple[int, int, int]
    metadata: Dict


class StreamConnection:
    """Represents a single stream connection with event emission"""

    def __init__(
        self,
        stream_id: str,
        decoder,
        model_controller: ModelController,
        tracking_controller,
        poll_interval: float = 0.01,
    ):
        self.stream_id = stream_id
        self.decoder = decoder
        self.model_controller = model_controller
        self.tracking_controller = tracking_controller
        self.poll_interval = poll_interval

        self.status = ConnectionStatus.CONNECTING
        self.frame_count = 0
        self.last_frame_time = 0.0

        # Event emission
        self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue()
        self.error_queue: asyncio.Queue[Exception] = asyncio.Queue()

        # Tasks
        self.poller_task: Optional[asyncio.Task] = None
        self.running = False

    async def start(self):
        """Start the connection (decoder and frame polling)"""
        # Start decoder (runs in a background thread)
        self.decoder.start()

        # Wait for the initial connection
        await asyncio.sleep(2.0)

        if self.decoder.is_connected():
            self.status = ConnectionStatus.CONNECTED
        else:
            self.status = ConnectionStatus.ERROR
            raise ConnectionError(f"Failed to connect to stream {self.stream_id}")

        # Start the frame polling task
        self.running = True
        self.poller_task = asyncio.create_task(self._frame_poller())

    async def stop(self):
        """Stop the connection and clean up"""
        self.running = False

        if self.poller_task:
            self.poller_task.cancel()
            try:
                await self.poller_task
            except asyncio.CancelledError:
                pass

        # Stop decoder
        self.decoder.stop()

        # Unregister from the model controller
        self.model_controller.unregister_callback(self.stream_id)

        self.status = ConnectionStatus.DISCONNECTED

    async def _frame_poller(self):
        """Poll frames from the threaded decoder and submit them to the model controller"""
        last_frame_ptr = None

        while self.running:
            try:
                # Poll frame from decoder (runs in a thread)
                frame = self.decoder.get_latest_frame(rgb=True)

                # Check if we got a new frame (avoid reprocessing the same frame)
                if frame is not None and frame.data_ptr() != last_frame_ptr:
                    last_frame_ptr = frame.data_ptr()
                    self.last_frame_time = time.time()
                    self.frame_count += 1

                    # Submit to the model controller for batched inference
                    await self.model_controller.submit_frame(
                        stream_id=self.stream_id,
                        frame=frame,
                        metadata={
                            "frame_number": self.frame_count,
                            "shape": frame.shape,
                        },
                    )

                # Check decoder status
                if not self.decoder.is_connected():
                    self.status = ConnectionStatus.DISCONNECTED
                    # The decoder auto-reconnects; just update the status
                    await asyncio.sleep(1.0)
                    if self.decoder.is_connected():
                        self.status = ConnectionStatus.CONNECTED

            except Exception as e:
                await self.error_queue.put(e)
                self.status = ConnectionStatus.ERROR

            # Sleep until the next poll
            await asyncio.sleep(self.poll_interval)

    async def _handle_inference_result(self, result: Dict):
        """
        Callback invoked by ModelController when inference is done.
        Runs tracking and emits the final result.
        """
        try:
            # Extract detections
            detections = result["detections"]

            # Run tracking (sync, so run in an executor)
            loop = asyncio.get_running_loop()
            tracked_objects = await loop.run_in_executor(
                None,
                lambda: self.tracking_controller.update_tracks(detections)
            )

            # Create the tracking result
            tracking_result = TrackingResult(
                stream_id=self.stream_id,
                timestamp=result["timestamp"],
                tracked_objects=tracked_objects,
                detections=detections,
                frame_shape=result["metadata"].get("shape"),
                metadata=result["metadata"],
            )

            # Emit to the result queue
            await self.result_queue.put(tracking_result)

        except Exception as e:
            await self.error_queue.put(e)

    async def tracking_results(self) -> AsyncIterator[TrackingResult]:
        """
        Async generator for tracking results.

        Usage:
            async for result in connection.tracking_results():
                print(result.tracked_objects)
        """
        while self.running or not self.result_queue.empty():
            try:
                result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0)
                yield result
            except asyncio.TimeoutError:
                continue

    async def errors(self) -> AsyncIterator[Exception]:
        """Async generator for errors"""
        while self.running or not self.error_queue.empty():
            try:
                error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0)
                yield error
            except asyncio.TimeoutError:
                continue

    def get_stats(self) -> Dict:
        """Get connection statistics"""
        return {
            "stream_id": self.stream_id,
            "status": self.status.value,
            "frame_count": self.frame_count,
            "last_frame_time": self.last_frame_time,
            "decoder_connected": self.decoder.is_connected(),
            "decoder_buffer_size": self.decoder.get_buffer_size(),
        }


class StreamConnectionManager:
    """
    High-level manager for stream connections with batched inference.
    """

    def __init__(
        self,
        gpu_id: int = 0,
        batch_size: int = 16,
        force_timeout: float = 0.05,
        poll_interval: float = 0.01,
    ):
        self.gpu_id = gpu_id
        self.batch_size = batch_size
        self.force_timeout = force_timeout
        self.poll_interval = poll_interval

        # Factories
        from services import StreamDecoderFactory, TrackingFactory
        from services.model_repository import TensorRTModelRepository

        self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
        self.tracking_factory = TrackingFactory(gpu_id=gpu_id)
        self.model_repository = TensorRTModelRepository(gpu_id=gpu_id)

        # Controllers
        self.model_controller: Optional[ModelController] = None
        self.tracking_controller = None

        # Connections
        self.connections: Dict[str, StreamConnection] = {}

        # State
        self.initialized = False

    async def initialize(
        self,
        model_path: str,
        model_id: str = "detector",
        preprocess_fn: Callable = None,
        postprocess_fn: Callable = None,
    ):
        """
        Initialize the manager with a model.

        Args:
            model_path: Path to TensorRT model file
            model_id: Model identifier
            preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess)
            postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess)
        """
        # Load model
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            lambda: self.model_repository.load_model(model_id, model_path, num_contexts=4)
        )

        # Create the model controller
        self.model_controller = ModelController(
            model_repository=self.model_repository,
            model_id=model_id,
            batch_size=self.batch_size,
            force_timeout=self.force_timeout,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn,
        )
        await self.model_controller.start()

        # Create the tracking controller
        self.tracking_controller = self.tracking_factory.create_controller(
            model_repository=self.model_repository,
            model_id=model_id,
            tracker_type="iou",
        )

        self.initialized = True

    async def connect_stream(
        self,
        rtsp_url: str,
        stream_id: Optional[str] = None,
        on_tracking_result: Optional[Callable] = None,
        on_error: Optional[Callable] = None,
    ) -> StreamConnection:
        """
        Connect to a stream and start processing.

        Args:
            rtsp_url: RTSP stream URL
            stream_id: Optional stream identifier (auto-generated if not provided)
            on_tracking_result: Optional callback for tracking results
            on_error: Optional callback for errors

        Returns:
            StreamConnection object for this stream
        """
        if not self.initialized:
            raise RuntimeError("Manager not initialized. Call initialize() first.")

        # Generate a stream ID if not provided
        if stream_id is None:
            stream_id = f"stream_{len(self.connections)}"

        # Create decoder
        decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=30)

        # Create connection
        connection = StreamConnection(
            stream_id=stream_id,
            decoder=decoder,
            model_controller=self.model_controller,
            tracking_controller=self.tracking_controller,
            poll_interval=self.poll_interval,
        )

        # Register the callback with the model controller
        self.model_controller.register_callback(
            stream_id,
            connection._handle_inference_result
        )

        # Start connection
        await connection.start()

        # Store connection
        self.connections[stream_id] = connection

        # Set up user callbacks if provided
        if on_tracking_result:
            asyncio.create_task(self._forward_results(connection, on_tracking_result))

        if on_error:
            asyncio.create_task(self._forward_errors(connection, on_error))

        return connection

    async def disconnect_stream(self, stream_id: str):
        """Disconnect and clean up a stream"""
        connection = self.connections.get(stream_id)
        if connection:
            await connection.stop()
            del self.connections[stream_id]

    async def disconnect_all(self):
        """Disconnect all streams"""
        stream_ids = list(self.connections.keys())
        for stream_id in stream_ids:
            await self.disconnect_stream(stream_id)

    async def shutdown(self):
        """Shutdown the manager and clean up all resources"""
        # Disconnect all streams
        await self.disconnect_all()

        # Stop the model controller
        if self.model_controller:
            await self.model_controller.stop()

        # Cleanup (model repository cleanup is sync)
        # Note: may need to handle cleanup carefully to avoid segfaults

    async def _forward_results(self, connection: StreamConnection, callback: Callable):
        """Forward results from a connection to the user callback"""
        async for result in connection.tracking_results():
            if asyncio.iscoroutinefunction(callback):
                await callback(result)
            else:
                callback(result)

    async def _forward_errors(self, connection: StreamConnection, callback: Callable):
        """Forward errors from a connection to the user callback"""
        async for error in connection.errors():
            if asyncio.iscoroutinefunction(callback):
                await callback(error)
            else:
                callback(error)

    def get_stats(self) -> Dict:
        """Get statistics for all connections"""
        return {
            "manager": {
                "initialized": self.initialized,
                "num_connections": len(self.connections),
                "batch_size": self.batch_size,
                "force_timeout": self.force_timeout,
            },
            "model_controller": self.model_controller.get_stats() if self.model_controller else {},
            "connections": {
                stream_id: conn.get_stats()
                for stream_id, conn in self.connections.items()
            },
        }
```
### 3. User API Examples

#### Example 1: Simple Callback Pattern

```python
import asyncio
from services import StreamConnectionManager
from services.yolo import YOLOv8Utils


async def main():
    # Create manager
    manager = StreamConnectionManager(
        gpu_id=0,
        batch_size=16,
        force_timeout=0.05,  # 50 ms
    )

    # Initialize with model
    await manager.initialize(
        model_path="models/yolov8n.trt",
        model_id="yolo",
        preprocess_fn=YOLOv8Utils.preprocess,
        postprocess_fn=YOLOv8Utils.postprocess,
    )

    # Define callback for tracking results
    def on_tracking_result(result):
        print(f"Stream: {result.stream_id}")
        print(f"Timestamp: {result.timestamp}")
        print(f"Tracked objects: {len(result.tracked_objects)}")
        for obj in result.tracked_objects:
            print(f"  - Track ID {obj.track_id}: class={obj.class_id}, conf={obj.confidence:.2f}")

    def on_error(error):
        print(f"Error: {error}")

    # Connect to stream
    connection = await manager.connect_stream(
        rtsp_url="rtsp://camera1.example.com/stream",
        stream_id="camera1",
        on_tracking_result=on_tracking_result,
        on_error=on_error,
    )

    # Let it run for 60 seconds
    await asyncio.sleep(60)

    # Get statistics
    stats = manager.get_stats()
    print(f"Stats: {stats}")

    # Cleanup
    await manager.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```

#### Example 2: Async Generator Pattern (Multiple Streams)

```python
import asyncio
from services import StreamConnectionManager
from services.yolo import YOLOv8Utils


async def process_stream(connection, stream_name):
    """Process results from a single stream"""
    async for result in connection.tracking_results():
        print(f"[{stream_name}] Frame {result.metadata['frame_number']}: {len(result.tracked_objects)} objects")

        # Do something with tracked objects
        for obj in result.tracked_objects:
            if obj.class_id == 0:  # Person class
                print(f"  Person detected! Track ID: {obj.track_id}, Conf: {obj.confidence:.2f}")


async def main():
    manager = StreamConnectionManager(
        gpu_id=0,
        batch_size=32,  # Larger batch for multiple streams
        force_timeout=0.05,
    )

    await manager.initialize(
        model_path="models/yolov8n.trt",
        preprocess_fn=YOLOv8Utils.preprocess,
        postprocess_fn=YOLOv8Utils.postprocess,
    )

    # Connect to multiple streams
    camera_urls = [
        ("rtsp://camera1.example.com/stream", "Front Door"),
        ("rtsp://camera2.example.com/stream", "Parking Lot"),
        ("rtsp://camera3.example.com/stream", "Warehouse"),
        ("rtsp://camera4.example.com/stream", "Loading Bay"),
    ]

    tasks = []
    for url, name in camera_urls:
        connection = await manager.connect_stream(
            rtsp_url=url,
            stream_id=name.lower().replace(" ", "_"),
        )

        # Create a task to process this stream
        task = asyncio.create_task(process_stream(connection, name))
        tasks.append(task)

    # Run all streams concurrently
    try:
        await asyncio.gather(*tasks)
    except KeyboardInterrupt:
        print("Shutting down...")

    await manager.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```

#### Example 3: Queue-Based Pattern (for integration with other systems)

```python
import asyncio
from services import StreamConnectionManager
from services.yolo import YOLOv8Utils


async def main():
    manager = StreamConnectionManager(gpu_id=0, batch_size=16)

    await manager.initialize(
        model_path="models/yolov8n.trt",
        preprocess_fn=YOLOv8Utils.preprocess,
        postprocess_fn=YOLOv8Utils.postprocess,
    )

    # Connect to stream (no callback)
    connection = await manager.connect_stream(
        rtsp_url="rtsp://camera.example.com/stream",
        stream_id="main_camera",
    )

    # Use the built-in queue
    result_queue = connection.result_queue

    # Process results from the queue
    while True:
        result = await result_queue.get()

        # Send to external systems (e.g., message queue, database, API)
        await send_to_kafka(result)
        await save_to_database(result)

        # Or do real-time processing
        if has_person_alert(result.tracked_objects):
            await send_alert("Person detected in restricted area!")


async def send_to_kafka(result):
    # Your Kafka producer code
    pass


async def save_to_database(result):
    # Your database code
    pass


def has_person_alert(tracked_objects):
    # Your alert logic
    return any(obj.class_id == 0 for obj in tracked_objects)


async def send_alert(message):
    print(f"ALERT: {message}")


if __name__ == "__main__":
    asyncio.run(main())
```

#### Example 4: Async Callback with Error Handling

```python
import asyncio
from services import StreamConnectionManager
from services.yolo import YOLOv8Utils


async def main():
    manager = StreamConnectionManager(gpu_id=0, batch_size=16)

    await manager.initialize(
        model_path="models/yolov8n.trt",
        preprocess_fn=YOLOv8Utils.preprocess,
        postprocess_fn=YOLOv8Utils.postprocess,
    )

    # Async callback (can do I/O operations)
    async def on_tracking_result(result):
        # Can use async operations in the callback
        await save_to_database(result)

        # Check for alerts
        for obj in result.tracked_objects:
            if obj.class_id == 0 and obj.confidence > 0.8:
                await send_notification(f"High confidence person detection: {obj.track_id}")

    async def on_error(error):
        await log_error_to_monitoring_system(error)

    # Connect with async callbacks
    connection = await manager.connect_stream(
        rtsp_url="rtsp://camera.example.com/stream",
        on_tracking_result=on_tracking_result,
        on_error=on_error,
    )

    # Monitor stats periodically
    while True:
        await asyncio.sleep(10)
        stats = manager.get_stats()
        print(f"Buffer stats: {stats['model_controller']}")
        print(f"Connection stats: {stats['connections']}")


async def save_to_database(result):
    # Simulate an async database operation
    await asyncio.sleep(0.01)


async def send_notification(message):
    print(f"NOTIFICATION: {message}")


async def log_error_to_monitoring_system(error):
    print(f"ERROR: {error}")


if __name__ == "__main__":
    asyncio.run(main())
```

---

## Configuration Examples

### Performance Tuning

```python
# Low latency (small batches, quick timeout)
manager = StreamConnectionManager(
    gpu_id=0,
    batch_size=4,
    force_timeout=0.02,   # 20 ms
    poll_interval=0.005,  # poll at 200 Hz
)

# High throughput (large batches, longer timeout)
manager = StreamConnectionManager(
    gpu_id=0,
    batch_size=32,
    force_timeout=0.1,    # 100 ms
    poll_interval=0.02,   # poll at 50 Hz
)

# Balanced (default)
manager = StreamConnectionManager(
    gpu_id=0,
    batch_size=16,
    force_timeout=0.05,   # 50 ms
    poll_interval=0.01,   # poll at 100 Hz
)
```

### Multiple GPUs

```python
# Create one manager per GPU
manager_gpu0 = StreamConnectionManager(gpu_id=0, batch_size=16)
manager_gpu1 = StreamConnectionManager(gpu_id=1, batch_size=16)

# Initialize both
await manager_gpu0.initialize(model_path="models/yolov8n.trt", ...)
await manager_gpu1.initialize(model_path="models/yolov8n.trt", ...)

# Distribute streams across GPUs
await manager_gpu0.connect_stream(url1, ...)
await manager_gpu0.connect_stream(url2, ...)
await manager_gpu1.connect_stream(url3, ...)
await manager_gpu1.connect_stream(url4, ...)
```
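For more than a handful of cameras, the same idea can be written as a small round-robin loop. The snippet below is illustrative only; the URL pattern and stream IDs are assumptions.

```python
# Round-robin assignment of streams to per-GPU managers (illustrative URLs)
managers = [manager_gpu0, manager_gpu1]
camera_urls = [f"rtsp://camera{i}.example.com/stream" for i in range(8)]

for i, url in enumerate(camera_urls):
    manager = managers[i % len(managers)]
    await manager.connect_stream(rtsp_url=url, stream_id=f"camera_{i}")
```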

---

## Key Features Summary

1. **Ping-Pong Buffers**: Efficient batching with minimal latency
2. **Force Timeout**: Prevents starvation of small batches
3. **AsyncIO**: Clean event-driven architecture
4. **Multiple Patterns**: Callbacks, generators, queues
5. **Thread-Async Bridge**: Integrates with existing threaded decoders
6. **Zero-Copy**: All processing stays on the GPU
7. **Auto-Reconnection**: Inherited from the underlying StreamDecoder
8. **Statistics**: Real-time monitoring of buffers and connections
---

## Performance Characteristics

- **Latency**: worst case ≈ `force_timeout` + batched inference time (a rough worked estimate follows below)
- **Throughput**: maximized by batching frames from all connected streams into single inference calls
- **VRAM**: ~60 MB per stream, plus batch buffer overhead
- **CPU**: minimal (async event loop + thread polling)
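As a rough, illustrative calculation only (the stream count, frame rate, and the ~15 ms batched-inference figure below are assumptions, not measurements):

```python
# Illustrative numbers only: replace with measured values for your model/GPU.
force_timeout = 0.05          # s, from the StreamConnectionManager config
batch_inference_time = 0.015  # s per batch of 16 (assumed, not measured)
num_streams = 4
stream_fps = 25               # assumed per-camera frame rate
batch_size = 16

# Worst case: a frame waits out the force timeout, then one inference pass.
worst_case_latency = force_timeout + batch_inference_time
print(f"worst-case added latency ≈ {worst_case_latency * 1000:.0f} ms")

# Aggregate input rate vs. the inference budget per batch.
frames_per_second = num_streams * stream_fps
batches_per_second = frames_per_second / batch_size
print(f"{frames_per_second} frames/s ≈ {batches_per_second:.1f} batches/s "
      f"→ inference budget ≈ {1000 / batches_per_second:.0f} ms per batch")
```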
---

## Next Steps

To implement this design:

1. Create `services/model_controller.py` with the `ModelController` class
2. Create `services/stream_connection_manager.py` with the `StreamConnectionManager` and `StreamConnection` classes
3. Update `services/__init__.py` to export the new classes
4. Create `test_event_driven.py` to test the system (a minimal skeleton is sketched below)
5. Add monitoring/logging throughout
6. Handle edge cases (reconnection, cleanup, errors)
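A minimal smoke-test skeleton for step 4 might look like the following. The import paths follow the module layout proposed above, and the local RTSP URL and model path are assumptions to be replaced with real test fixtures.

```python
# test_event_driven.py — minimal smoke test (sketch; paths and URLs are assumptions)
import asyncio

from services import StreamConnectionManager
from services.yolo import YOLOv8Utils


async def main():
    manager = StreamConnectionManager(gpu_id=0, batch_size=4, force_timeout=0.05)
    await manager.initialize(
        model_path="models/yolov8n.trt",
        preprocess_fn=YOLOv8Utils.preprocess,
        postprocess_fn=YOLOv8Utils.postprocess,
    )

    results = []
    connection = await manager.connect_stream(
        rtsp_url="rtsp://127.0.0.1:8554/test",  # point at a local test stream
        stream_id="test_stream",
        on_tracking_result=results.append,      # sync callback: just collect results
        on_error=lambda e: print(f"stream error: {e}"),
    )

    await asyncio.sleep(30)  # let the pipeline run briefly
    print(manager.get_stats())
    await manager.shutdown()

    # Basic sanity checks
    assert connection.frame_count > 0, "no frames were decoded"
    assert len(results) > 0, "no tracking results were emitted"


if __name__ == "__main__":
    asyncio.run(main())
```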