Batch processing / event-driven

Siwat Sirichai 2025-11-09 18:43:56 +07:00
parent e71316ef3d
commit dd57b5a246
7 changed files with 2673 additions and 2 deletions

EVENT_DRIVEN_DESIGN.md (new file, 1108 lines)

Diff suppressed because it is too large.

services/__init__.py

@@ -8,6 +8,8 @@ from .model_repository import TensorRTModelRepository, ModelMetadata, ExecutionC
from .tracking_controller import TrackingController, TrackedObject
from .tracking_factory import TrackingFactory
from .yolo import YOLOv8Utils, COCO_CLASSES
from .model_controller import ModelController, BatchFrame, BufferState
from .stream_connection_manager import StreamConnectionManager, StreamConnection, TrackingResult

__all__ = [
    'StreamDecoderFactory',
@@ -24,4 +26,10 @@ __all__ = [
    'TrackingFactory',
    'YOLOv8Utils',
    'COCO_CLASSES',
    'ModelController',
    'BatchFrame',
    'BufferState',
    'StreamConnectionManager',
    'StreamConnection',
    'TrackingResult',
]

model_controller.py (new file, 499 lines)

@@ -0,0 +1,499 @@
"""
ModelController - Async batching layer with ping-pong buffers for inference.
This module provides batched inference coordination using ping-pong circular buffers
with force-switch timeout mechanism.
"""
import asyncio
import torch
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass, field
from enum import Enum
import time
import logging
logger = logging.getLogger(__name__)
@dataclass
class BatchFrame:
"""Represents a frame in the batch buffer"""
stream_id: str
frame: torch.Tensor # GPU tensor (3, H, W)
timestamp: float
metadata: Dict = field(default_factory=dict)
class BufferState(Enum):
"""State of a ping-pong buffer"""
IDLE = "idle"
FILLING = "filling"
PROCESSING = "processing"
class ModelController:
"""
Manages batched inference with ping-pong buffers and force-switch timeout.
This controller accumulates frames from multiple streams into batches,
processes them through a model repository, and routes results back to
stream-specific callbacks.
Features:
- Ping-pong circular buffers (BufferA/BufferB)
- Force-switch timeout to prevent batch starvation
- Async event-driven processing
- Thread-safe frame submission
Args:
model_repository: TensorRT model repository for inference
model_id: Model identifier in the repository
batch_size: Maximum frames per batch (default: 16)
force_timeout: Max wait time before forcing buffer switch in seconds (default: 0.05)
preprocess_fn: Optional preprocessing function for frames
postprocess_fn: Optional postprocessing function for model outputs
"""
def __init__(
self,
model_repository,
model_id: str,
batch_size: int = 16,
force_timeout: float = 0.05,
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
):
self.model_repository = model_repository
self.model_id = model_id
self.batch_size = batch_size
self.force_timeout = force_timeout
self.preprocess_fn = preprocess_fn
self.postprocess_fn = postprocess_fn
# Detect model's actual batch size from input shape
self.model_batch_size = self._detect_model_batch_size()
if self.model_batch_size == 1:
logger.warning(
f"Model '{model_id}' has fixed batch_size=1. "
f"Will process frames sequentially. Consider rebuilding model with dynamic batching."
)
else:
logger.info(f"Model '{model_id}' supports batch_size={self.model_batch_size}")
# Ping-pong buffers
self.buffer_a: List[BatchFrame] = []
self.buffer_b: List[BatchFrame] = []
# Buffer states
self.active_buffer = "A" # Which buffer is currently active (filling)
self.buffer_a_state = BufferState.IDLE
self.buffer_b_state = BufferState.IDLE
# Async coordination
self.buffer_lock = asyncio.Lock()
self.last_submit_time = time.time()
# Tasks
self.timeout_task: Optional[asyncio.Task] = None
self.processor_task: Optional[asyncio.Task] = None
self.running = False
# Result callbacks (stream_id -> callback)
self.result_callbacks: Dict[str, Callable] = {}
# Statistics
self.total_frames_processed = 0
self.total_batches_processed = 0
def _detect_model_batch_size(self) -> int:
"""
Detect the model's batch size from its input shape.
Returns:
Maximum batch size supported by the model (1 for fixed batch size models)
"""
try:
metadata = self.model_repository.get_metadata(self.model_id)
# Get first input tensor shape
first_input = list(metadata.inputs.values())[0]
batch_dim = first_input["shape"][0]
# batch_dim can be -1 (dynamic), 1 (fixed), or N (fixed batch size)
if batch_dim == -1:
# Dynamic batch size - use user-specified batch_size
return self.batch_size
else:
# Fixed batch size
return batch_dim
except Exception as e:
logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
return 1
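# Illustrative assumption for _detect_model_batch_size above: metadata.inputs
# is expected to map tensor names to dicts carrying a "shape" entry, e.g.
# {"images": {"shape": (-1, 3, 640, 640)}}, where a leading -1 marks a dynamic
# batch dimension and any positive value is treated as a fixed batch size.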
async def start(self):
"""Start the controller background tasks"""
if self.running:
logger.warning("ModelController already running")
return
self.running = True
self.timeout_task = asyncio.create_task(self._timeout_monitor())
self.processor_task = asyncio.create_task(self._batch_processor())
logger.info("ModelController started")
async def stop(self):
"""Stop the controller and cleanup"""
if not self.running:
return
logger.info("Stopping ModelController...")
self.running = False
# Cancel tasks
if self.timeout_task:
self.timeout_task.cancel()
try:
await self.timeout_task
except asyncio.CancelledError:
pass
if self.processor_task:
self.processor_task.cancel()
try:
await self.processor_task
except asyncio.CancelledError:
pass
# Process any remaining frames
await self._process_remaining_buffers()
logger.info("ModelController stopped")
def register_callback(self, stream_id: str, callback: Callable):
"""
Register a callback for inference results from a stream.
Args:
stream_id: Unique stream identifier
callback: Callback function to receive results (can be sync or async)
"""
self.result_callbacks[stream_id] = callback
logger.debug(f"Registered callback for stream: {stream_id}")
def unregister_callback(self, stream_id: str):
"""
Unregister a stream callback.
Args:
stream_id: Stream identifier to unregister
"""
self.result_callbacks.pop(stream_id, None)
logger.debug(f"Unregistered callback for stream: {stream_id}")
async def submit_frame(
self,
stream_id: str,
frame: torch.Tensor,
metadata: Optional[Dict] = None
):
"""
Submit a frame for batched inference.
Args:
stream_id: Unique stream identifier
frame: GPU tensor (3, H, W) or (C, H, W)
metadata: Optional metadata to attach to the frame
"""
async with self.buffer_lock:
batch_frame = BatchFrame(
stream_id=stream_id,
frame=frame,
timestamp=time.time(),
metadata=metadata or {}
)
# Add to active buffer
if self.active_buffer == "A":
self.buffer_a.append(batch_frame)
self.buffer_a_state = BufferState.FILLING
buffer_size = len(self.buffer_a)
else:
self.buffer_b.append(batch_frame)
self.buffer_b_state = BufferState.FILLING
buffer_size = len(self.buffer_b)
self.last_submit_time = time.time()
# Check if we should immediately swap (batch full)
if buffer_size >= self.batch_size:
await self._try_swap_buffers()
async def _timeout_monitor(self):
"""Monitor force-switch timeout"""
while self.running:
await asyncio.sleep(0.01) # Check every 10ms
async with self.buffer_lock:
time_since_submit = time.time() - self.last_submit_time
# Check if timeout expired and we have frames waiting
if time_since_submit >= self.force_timeout:
active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
if len(active_buffer) > 0:
await self._try_swap_buffers()
async def _try_swap_buffers(self):
"""
Attempt to swap ping-pong buffers.
Only swaps if the inactive buffer is not currently processing.
This method should be called with buffer_lock held.
"""
# Check if inactive buffer is available
inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
if inactive_state != BufferState.PROCESSING:
# Swap active buffer
old_active = self.active_buffer
self.active_buffer = "B" if old_active == "A" else "A"
# Mark old active buffer as ready for processing
if old_active == "A":
self.buffer_a_state = BufferState.PROCESSING
buffer_size = len(self.buffer_a)
else:
self.buffer_b_state = BufferState.PROCESSING
buffer_size = len(self.buffer_b)
logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")
async def _batch_processor(self):
"""Background task that processes batches when available"""
while self.running:
await asyncio.sleep(0.001) # Check every 1ms
# Check if buffer A needs processing
if self.buffer_a_state == BufferState.PROCESSING:
await self._process_buffer("A")
# Check if buffer B needs processing
if self.buffer_b_state == BufferState.PROCESSING:
await self._process_buffer("B")
async def _process_buffer(self, buffer_name: str):
"""
Process a buffer through inference.
Args:
buffer_name: "A" or "B"
"""
# Extract buffer to process
async with self.buffer_lock:
if buffer_name == "A":
batch = self.buffer_a.copy()
self.buffer_a.clear()
else:
batch = self.buffer_b.copy()
self.buffer_b.clear()
if len(batch) == 0:
# Mark as idle and return
async with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
return
# Process batch (outside lock to allow concurrent submissions)
try:
start_time = time.time()
results = await self._run_batch_inference(batch)
inference_time = time.time() - start_time
# Update statistics
self.total_frames_processed += len(batch)
self.total_batches_processed += 1
logger.debug(
f"Processed batch of {len(batch)} frames in {inference_time*1000:.2f}ms "
f"({inference_time*1000/len(batch):.2f}ms per frame)"
)
# Emit results to callbacks
for batch_frame, result in zip(batch, results):
callback = self.result_callbacks.get(batch_frame.stream_id)
if callback:
# Schedule callback asynchronously
if asyncio.iscoroutinefunction(callback):
asyncio.create_task(callback(result))
else:
# Run sync callback in the default executor so it doesn't block the event loop
loop = asyncio.get_event_loop()
loop.run_in_executor(None, callback, result)
except Exception as e:
logger.error(f"Error processing batch: {e}", exc_info=True)
# TODO: Emit error events to streams
finally:
# Mark buffer as idle
async with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""
Run inference on a batch of frames.
Args:
batch: List of BatchFrame objects
Returns:
List of detection results (one per frame)
"""
loop = asyncio.get_event_loop()
# Check if model supports batching
if self.model_batch_size == 1:
# Process frames one at a time for batch_size=1 models
return await self._run_sequential_inference(batch, loop)
else:
# Use true batching for models that support it
return await self._run_batched_inference(batch, loop)
async def _run_sequential_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
"""Run inference sequentially for batch_size=1 models"""
results = []
for batch_frame in batch:
# Preprocess frame
if self.preprocess_fn:
processed = self.preprocess_fn(batch_frame.frame)
else:
# Ensure we have batch dimension
processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame
# Run inference for this frame
outputs = await loop.run_in_executor(
None,
lambda p=processed: self.model_repository.infer(
self.model_id,
{"images": p},
synchronize=True
)
)
# Postprocess
if self.postprocess_fn:
try:
detections = self.postprocess_fn(outputs)
except Exception as e:
logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
# Return empty detections on error
detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
else:
detections = outputs
result = {
"stream_id": batch_frame.stream_id,
"timestamp": batch_frame.timestamp,
"detections": detections,
"metadata": batch_frame.metadata,
}
results.append(result)
return results
async def _run_batched_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
"""Run true batched inference for models that support it"""
# Preprocess frames (on GPU)
preprocessed = []
for batch_frame in batch:
if self.preprocess_fn:
processed = self.preprocess_fn(batch_frame.frame)
# Preprocess may return (1, C, H, W), squeeze to (C, H, W)
if processed.dim() == 4 and processed.shape[0] == 1:
processed = processed.squeeze(0)
else:
processed = batch_frame.frame
preprocessed.append(processed)
# Stack into batch tensor: (N, C, H, W)
batch_tensor = torch.stack(preprocessed, dim=0)
# Limit batch size to model's max batch size
if batch_tensor.shape[0] > self.model_batch_size:
logger.warning(
f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}; "
f"truncating batch (frames beyond the limit are dropped)"
)
# TODO: Split into sub-batches instead of truncating
batch_tensor = batch_tensor[:self.model_batch_size]
batch = batch[:self.model_batch_size]
# Run inference
outputs = await loop.run_in_executor(
None,
lambda: self.model_repository.infer(
self.model_id,
{"images": batch_tensor},
synchronize=True
)
)
# Postprocess results (split batch back to individual results)
results = []
for i, batch_frame in enumerate(batch):
# Extract single frame output from batch
frame_output = {}
for k, v in outputs.items():
# v has shape (N, ...), extract index i and keep batch dimension
frame_output[k] = v[i:i+1] # Shape: (1, ...)
if self.postprocess_fn:
try:
detections = self.postprocess_fn(frame_output)
except Exception as e:
logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
# Return empty detections on error
detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
else:
detections = frame_output
result = {
"stream_id": batch_frame.stream_id,
"timestamp": batch_frame.timestamp,
"detections": detections,
"metadata": batch_frame.metadata,
}
results.append(result)
return results
async def _process_remaining_buffers(self):
"""Process any remaining frames in buffers during shutdown"""
if len(self.buffer_a) > 0:
logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
await self._process_buffer("A")
if len(self.buffer_b) > 0:
logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
await self._process_buffer("B")
def get_stats(self) -> Dict[str, Any]:
"""Get current buffer statistics"""
return {
"active_buffer": self.active_buffer,
"buffer_a_size": len(self.buffer_a),
"buffer_b_size": len(self.buffer_b),
"buffer_a_state": self.buffer_a_state.value,
"buffer_b_state": self.buffer_b_state.value,
"registered_streams": len(self.result_callbacks),
"total_frames_processed": self.total_frames_processed,
"total_batches_processed": self.total_batches_processed,
"avg_batch_size": (
self.total_frames_processed / self.total_batches_processed
if self.total_batches_processed > 0 else 0
),
}
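
For orientation, the controller can also be driven without the StreamConnectionManager. The sketch below is illustrative only: it assumes a prebuilt TensorRT engine at a placeholder path, that TensorRTModelRepository and YOLOv8Utils behave as they are used elsewhere in this commit, and it feeds random GPU tensors in place of decoded frames.

import asyncio
import torch

from services import ModelController, YOLOv8Utils
from services.model_repository import TensorRTModelRepository  # import path assumed


async def demo():
    repo = TensorRTModelRepository(gpu_id=0)
    repo.load_model("detector", "models/yolov8n.trt", num_contexts=1)  # placeholder engine path
    controller = ModelController(
        model_repository=repo,
        model_id="detector",
        batch_size=16,
        force_timeout=0.05,
        preprocess_fn=YOLOv8Utils.preprocess,
        postprocess_fn=YOLOv8Utils.postprocess,
    )
    await controller.start()
    # Results arrive through the per-stream callback registered here
    controller.register_callback("cam0", lambda r: print(r["stream_id"], r["detections"].shape))
    for _ in range(32):
        frame = torch.rand(3, 640, 640, device="cuda")  # stand-in for a decoded frame
        await controller.submit_frame("cam0", frame)
        await asyncio.sleep(0.01)
    await controller.stop()


if __name__ == "__main__":
    asyncio.run(demo())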

stream_connection_manager.py (new file, 566 lines)

@@ -0,0 +1,566 @@
"""
StreamConnectionManager - Async orchestration for stream processing with batched inference.
This module provides high-level connection management for multiple RTSP streams,
coordinating decoders, batched inference, and tracking with an event-driven API.
"""
import asyncio
import time
from typing import Dict, Optional, Callable, AsyncIterator, Tuple, Any, List
from dataclasses import dataclass
from enum import Enum
import logging
import torch
from .model_controller import ModelController
from .stream_decoder import StreamDecoderFactory
from .tracking_factory import TrackingFactory
from .model_repository import TensorRTModelRepository
logger = logging.getLogger(__name__)
class ConnectionStatus(Enum):
"""Status of a stream connection"""
CONNECTING = "connecting"
CONNECTED = "connected"
DISCONNECTED = "disconnected"
ERROR = "error"
@dataclass
class TrackingResult:
"""Result emitted to user callbacks"""
stream_id: str
timestamp: float
tracked_objects: List # List of TrackedObject from TrackingController
detections: List # Raw detections
frame_shape: Tuple[int, int, int]
metadata: Dict
class StreamConnection:
"""
Represents a single stream connection with event emission.
This class wraps a StreamDecoder, polls frames asynchronously, submits them
to the ModelController for batched inference, runs tracking, and emits results
via queues or callbacks.
Args:
stream_id: Unique identifier for this stream
decoder: StreamDecoder instance
model_controller: ModelController for batched inference
tracking_controller: TrackingController for object tracking
poll_interval: Frame polling interval in seconds (default: 0.01)
"""
def __init__(
self,
stream_id: str,
decoder,
model_controller: ModelController,
tracking_controller,
poll_interval: float = 0.01,
):
self.stream_id = stream_id
self.decoder = decoder
self.model_controller = model_controller
self.tracking_controller = tracking_controller
self.poll_interval = poll_interval
self.status = ConnectionStatus.CONNECTING
self.frame_count = 0
self.last_frame_time = 0.0
# Event emission
self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue()
self.error_queue: asyncio.Queue[Exception] = asyncio.Queue()
# Tasks
self.poller_task: Optional[asyncio.Task] = None
self.running = False
async def start(self):
"""Start the connection (decoder and frame polling)"""
# Start decoder (runs in background thread)
self.decoder.start()
# Wait for initial connection (try for up to 10 seconds)
max_wait = 10.0
wait_interval = 0.5
elapsed = 0.0
while elapsed < max_wait:
await asyncio.sleep(wait_interval)
elapsed += wait_interval
if self.decoder.is_connected():
self.status = ConnectionStatus.CONNECTED
logger.info(f"Stream {self.stream_id} connected after {elapsed:.1f}s")
break
else:
# Timeout - but don't fail hard, let it try to connect in background
logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...")
self.status = ConnectionStatus.CONNECTING
# Start frame polling task
self.running = True
self.poller_task = asyncio.create_task(self._frame_poller())
async def stop(self):
"""Stop the connection and cleanup"""
logger.info(f"Stopping stream {self.stream_id}...")
self.running = False
if self.poller_task:
self.poller_task.cancel()
try:
await self.poller_task
except asyncio.CancelledError:
pass
# Stop decoder
self.decoder.stop()
# Unregister from model controller
self.model_controller.unregister_callback(self.stream_id)
self.status = ConnectionStatus.DISCONNECTED
logger.info(f"Stream {self.stream_id} stopped")
async def _frame_poller(self):
"""Poll frames from threaded decoder and submit to model controller"""
last_frame_ptr = None
while self.running:
try:
# Poll frame from decoder (runs in thread)
frame = self.decoder.get_latest_frame(rgb=True)
# Check if we got a new frame (avoid reprocessing same frame)
if frame is not None and frame.data_ptr() != last_frame_ptr:
last_frame_ptr = frame.data_ptr()
self.last_frame_time = time.time()
self.frame_count += 1
# Submit to model controller for batched inference
await self.model_controller.submit_frame(
stream_id=self.stream_id,
frame=frame,
metadata={
"frame_number": self.frame_count,
"shape": tuple(frame.shape),
}
)
# Check decoder status
if not self.decoder.is_connected():
if self.status == ConnectionStatus.CONNECTED:
logger.warning(f"Stream {self.stream_id} disconnected")
self.status = ConnectionStatus.DISCONNECTED
# Decoder will auto-reconnect, just update status
await asyncio.sleep(1.0)
if self.decoder.is_connected():
logger.info(f"Stream {self.stream_id} reconnected")
self.status = ConnectionStatus.CONNECTED
except Exception as e:
logger.error(f"Error in frame poller for {self.stream_id}: {e}", exc_info=True)
await self.error_queue.put(e)
self.status = ConnectionStatus.ERROR
# Sleep until next poll
await asyncio.sleep(self.poll_interval)
async def _handle_inference_result(self, result: Dict[str, Any]):
"""
Callback invoked by ModelController when inference is done.
Runs tracking and emits final result.
Args:
result: Inference result dictionary
"""
try:
# Extract detections
detections = result["detections"]
# Run tracking (this is sync, so run in executor)
loop = asyncio.get_event_loop()
tracked_objects = await loop.run_in_executor(
None,
lambda: self._run_tracking_sync(detections)
)
# Create tracking result
tracking_result = TrackingResult(
stream_id=self.stream_id,
timestamp=result["timestamp"],
tracked_objects=tracked_objects,
detections=detections,
frame_shape=result["metadata"].get("shape"),
metadata=result["metadata"],
)
# Emit to result queue
await self.result_queue.put(tracking_result)
except Exception as e:
logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
await self.error_queue.put(e)
def _run_tracking_sync(self, detections):
"""
Run tracking synchronously (called from executor).
Args:
detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id]
Returns:
List of TrackedObject instances
"""
# Use the TrackingController's internal tracking with detections
# We need to manually update tracks since we already have detections
with self.tracking_controller._lock:
self.tracking_controller._frame_count += 1
# If no detections, just cleanup and return current tracks
if len(detections) == 0:
self.tracking_controller._cleanup_stale_tracks()
return list(self.tracking_controller._tracks.values())
# Run IoU tracking to associate detections with existing tracks
associations = self.tracking_controller._iou_tracking(detections)
# Update or create tracks
for (det_idx, track_id), detection in zip(associations, detections):
bbox = detection[:4].cpu().tolist()
confidence = float(detection[4])
class_id = int(detection[5]) if detection.shape[0] > 5 else 0
if track_id == -1:
# Create new track
new_track = self.tracking_controller._create_track(
bbox, confidence, class_id, self.tracking_controller._frame_count
)
self.tracking_controller._tracks[new_track.track_id] = new_track
else:
# Update existing track
self.tracking_controller._tracks[track_id].update(
bbox, confidence, self.tracking_controller._frame_count
)
# Cleanup stale tracks
self.tracking_controller._cleanup_stale_tracks()
return list(self.tracking_controller._tracks.values())
async def tracking_results(self) -> AsyncIterator[TrackingResult]:
"""
Async generator for tracking results.
Usage:
async for result in connection.tracking_results():
print(result.tracked_objects)
Yields:
TrackingResult objects as they become available
"""
while self.running or not self.result_queue.empty():
try:
result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0)
yield result
except asyncio.TimeoutError:
continue
async def errors(self) -> AsyncIterator[Exception]:
"""
Async generator for errors.
Yields:
Exception objects as they occur
"""
while self.running or not self.error_queue.empty():
try:
error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0)
yield error
except asyncio.TimeoutError:
continue
def get_stats(self) -> Dict[str, Any]:
"""Get connection statistics"""
return {
"stream_id": self.stream_id,
"status": self.status.value,
"frame_count": self.frame_count,
"last_frame_time": self.last_frame_time,
"decoder_connected": self.decoder.is_connected(),
"decoder_buffer_size": self.decoder.get_buffer_size(),
"result_queue_size": self.result_queue.qsize(),
"error_queue_size": self.error_queue.qsize(),
}
class StreamConnectionManager:
"""
High-level manager for stream connections with batched inference.
This manager coordinates multiple RTSP streams, batched model inference,
and object tracking through an async event-driven API.
Args:
gpu_id: GPU device ID (default: 0)
batch_size: Maximum batch size for inference (default: 16)
force_timeout: Force buffer switch timeout in seconds (default: 0.05)
poll_interval: Frame polling interval in seconds (default: 0.01)
Example:
manager = StreamConnectionManager(gpu_id=0, batch_size=16)
await manager.initialize(model_path="yolov8n.trt", ...)
connection = await manager.connect_stream(rtsp_url, on_tracking_result=callback)
await asyncio.sleep(60)
await manager.shutdown()
"""
def __init__(
self,
gpu_id: int = 0,
batch_size: int = 16,
force_timeout: float = 0.05,
poll_interval: float = 0.01,
):
self.gpu_id = gpu_id
self.batch_size = batch_size
self.force_timeout = force_timeout
self.poll_interval = poll_interval
# Factories
self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
self.tracking_factory = TrackingFactory(gpu_id=gpu_id)
self.model_repository = TensorRTModelRepository(gpu_id=gpu_id)
# Controllers
self.model_controller: Optional[ModelController] = None
self.tracking_controller = None
# Connections
self.connections: Dict[str, StreamConnection] = {}
# State
self.initialized = False
async def initialize(
self,
model_path: str,
model_id: str = "detector",
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
num_contexts: int = 4,
):
"""
Initialize the manager with a model.
Args:
model_path: Path to TensorRT model file
model_id: Model identifier (default: "detector")
preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess)
postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess)
num_contexts: Number of TensorRT execution contexts (default: 4)
"""
logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
# Load model
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
lambda: self.model_repository.load_model(model_id, model_path, num_contexts=num_contexts)
)
logger.info(f"Loaded model {model_id} from {model_path}")
# Create model controller
self.model_controller = ModelController(
model_repository=self.model_repository,
model_id=model_id,
batch_size=self.batch_size,
force_timeout=self.force_timeout,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
await self.model_controller.start()
# Create tracking controller
self.tracking_controller = self.tracking_factory.create_controller(
model_repository=self.model_repository,
model_id=model_id,
tracker_type="iou",
)
logger.info("TrackingController created")
self.initialized = True
logger.info("StreamConnectionManager initialized successfully")
async def connect_stream(
self,
rtsp_url: str,
stream_id: Optional[str] = None,
on_tracking_result: Optional[Callable] = None,
on_error: Optional[Callable] = None,
buffer_size: int = 30,
) -> StreamConnection:
"""
Connect to a stream and start processing.
Args:
rtsp_url: RTSP stream URL
stream_id: Optional stream identifier (auto-generated if not provided)
on_tracking_result: Optional callback for tracking results (sync or async)
on_error: Optional callback for errors (sync or async)
buffer_size: Decoder buffer size (default: 30)
Returns:
StreamConnection object for this stream
Raises:
RuntimeError: If manager is not initialized
ConnectionError: If stream connection fails
"""
if not self.initialized:
raise RuntimeError("Manager not initialized. Call initialize() first.")
# Generate stream ID if not provided
if stream_id is None:
stream_id = f"stream_{len(self.connections)}"
logger.info(f"Connecting to stream {stream_id}: {rtsp_url}")
# Create decoder
decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size)
# Create connection
connection = StreamConnection(
stream_id=stream_id,
decoder=decoder,
model_controller=self.model_controller,
tracking_controller=self.tracking_controller,
poll_interval=self.poll_interval,
)
# Register callback with model controller
self.model_controller.register_callback(
stream_id,
connection._handle_inference_result
)
# Start connection
await connection.start()
# Store connection
self.connections[stream_id] = connection
# Set up user callbacks if provided
if on_tracking_result:
asyncio.create_task(self._forward_results(connection, on_tracking_result))
if on_error:
asyncio.create_task(self._forward_errors(connection, on_error))
logger.info(f"Stream {stream_id} connected successfully")
return connection
async def disconnect_stream(self, stream_id: str):
"""
Disconnect and cleanup a stream.
Args:
stream_id: Stream identifier to disconnect
"""
connection = self.connections.get(stream_id)
if connection:
await connection.stop()
del self.connections[stream_id]
logger.info(f"Stream {stream_id} disconnected")
async def disconnect_all(self):
"""Disconnect all streams"""
logger.info("Disconnecting all streams...")
stream_ids = list(self.connections.keys())
for stream_id in stream_ids:
await self.disconnect_stream(stream_id)
async def shutdown(self):
"""Shutdown the manager and cleanup all resources"""
logger.info("Shutting down StreamConnectionManager...")
# Disconnect all streams
await self.disconnect_all()
# Stop model controller
if self.model_controller:
await self.model_controller.stop()
# Note: Model repository cleanup is sync and may cause segfaults
# Leaving cleanup to garbage collection for now
self.initialized = False
logger.info("StreamConnectionManager shutdown complete")
async def _forward_results(self, connection: StreamConnection, callback: Callable):
"""
Forward results from connection to user callback.
Args:
connection: StreamConnection to listen to
callback: User callback (sync or async)
"""
try:
async for result in connection.tracking_results():
if asyncio.iscoroutinefunction(callback):
await callback(result)
else:
callback(result)
except Exception as e:
logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)
async def _forward_errors(self, connection: StreamConnection, callback: Callable):
"""
Forward errors from connection to user callback.
Args:
connection: StreamConnection to listen to
callback: User callback (sync or async)
"""
try:
async for error in connection.errors():
if asyncio.iscoroutinefunction(callback):
await callback(error)
else:
callback(error)
except Exception as e:
logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)
def get_stats(self) -> Dict[str, Any]:
"""
Get statistics for all connections.
Returns:
Dictionary with manager and connection statistics
"""
return {
"manager": {
"initialized": self.initialized,
"gpu_id": self.gpu_id,
"num_connections": len(self.connections),
"batch_size": self.batch_size,
"force_timeout": self.force_timeout,
"poll_interval": self.poll_interval,
},
"model_controller": self.model_controller.get_stats() if self.model_controller else {},
"connections": {
stream_id: conn.get_stats()
for stream_id, conn in self.connections.items()
},
}
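
The preprocess_fn / postprocess_fn hooks passed to initialize() and ModelController are plain callables. Judging from how they are called in this commit, a custom pair would look roughly like the following sketch; the 640x640 resize and the single-output assumption are placeholders, not the committed YOLOv8Utils behavior.

from typing import Dict

import torch
import torch.nn.functional as F


def my_preprocess(frame: torch.Tensor) -> torch.Tensor:
    """Map a (3, H, W) GPU frame to a model-ready (1, C, H, W) float tensor."""
    x = frame.unsqueeze(0).float() / 255.0
    return F.interpolate(x, size=(640, 640), mode="bilinear", align_corners=False)


def my_postprocess(outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Map the raw output dict from infer() to an (N, 6) detection tensor
    [x1, y1, x2, y2, confidence, class_id], as the tracking step expects."""
    raw = list(outputs.values())[0]  # assume a single output tensor
    # Real decoding / NMS would go here; this stub returns no detections.
    return torch.zeros((0, 6), device=raw.device)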

test_event_driven.py (new file, 373 lines)

@@ -0,0 +1,373 @@
#!/usr/bin/env python3
"""
Test script for event-driven stream processing with batched inference.
This demonstrates the new AsyncIO-based API for connecting to RTSP streams,
processing frames through batched inference, and receiving tracking results
via callbacks and async generators.
"""
import asyncio
import os
import time
import logging
from dotenv import load_dotenv
from services import StreamConnectionManager, YOLOv8Utils, COCO_CLASSES
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Example 1: Simple callback pattern
async def example_callback_pattern():
"""Demonstrates the simple callback pattern for a single stream"""
logger.info("=== Example 1: Callback Pattern ===")
# Load environment variables
load_dotenv()
camera_url = os.getenv('CAMERA_URL_1')
if not camera_url:
logger.error("CAMERA_URL_1 not found in .env file")
return
# Create manager
manager = StreamConnectionManager(
gpu_id=0,
batch_size=16,
force_timeout=0.05, # 50ms
poll_interval=0.01, # 100 FPS
)
# Initialize with YOLOv8 model
model_path = "models/yolov8n.trt" # Adjust path as needed
if not os.path.exists(model_path):
logger.error(f"Model file not found: {model_path}")
return
await manager.initialize(
model_path=model_path,
model_id="yolo",
preprocess_fn=YOLOv8Utils.preprocess,
postprocess_fn=YOLOv8Utils.postprocess,
)
# Define callback for tracking results
def on_tracking_result(result):
logger.info(f"[{result.stream_id}] Frame {result.metadata.get('frame_number', 0)}")
logger.info(f" Timestamp: {result.timestamp:.3f}")
logger.info(f" Tracked objects: {len(result.tracked_objects)}")
for obj in result.tracked_objects[:5]: # Show first 5
class_name = COCO_CLASSES.get(obj.class_id, f"Class {obj.class_id}")
logger.info(
f" Track ID {obj.track_id}: {class_name}, "
f"conf={obj.confidence:.2f}, bbox={obj.bbox}"
)
def on_error(error):
logger.error(f"Stream error: {error}")
# Connect to stream
connection = await manager.connect_stream(
rtsp_url=camera_url,
stream_id="camera1",
on_tracking_result=on_tracking_result,
on_error=on_error,
)
# Let it run for 30 seconds
logger.info("Processing stream for 30 seconds...")
await asyncio.sleep(30)
# Get statistics
stats = manager.get_stats()
logger.info("=== Statistics ===")
logger.info(f"Manager: {stats['manager']}")
logger.info(f"Model Controller: {stats['model_controller']}")
logger.info(f"Connection: {stats['connections']['camera1']}")
# Cleanup
await manager.shutdown()
logger.info("Example 1 complete\n")
# Example 2: Async generator pattern with multiple streams
async def example_async_generator_pattern():
"""Demonstrates async generator pattern for multiple streams"""
logger.info("=== Example 2: Async Generator Pattern (Multiple Streams) ===")
# Load environment variables
load_dotenv()
camera_urls = []
for i in range(1, 5): # Try to load 4 cameras
url = os.getenv(f'CAMERA_URL_{i}')
if url:
camera_urls.append((url, f"camera{i}"))
if not camera_urls:
logger.error("No camera URLs found in .env file")
return
logger.info(f"Found {len(camera_urls)} camera(s)")
# Create manager with larger batch for multiple streams
manager = StreamConnectionManager(
gpu_id=0,
batch_size=32, # Larger batch for multiple streams
force_timeout=0.05,
)
# Initialize
model_path = "models/yolov8n.trt"
if not os.path.exists(model_path):
logger.error(f"Model file not found: {model_path}")
return
await manager.initialize(
model_path=model_path,
preprocess_fn=YOLOv8Utils.preprocess,
postprocess_fn=YOLOv8Utils.postprocess,
)
# Connect to all streams
connections = []
for url, stream_id in camera_urls:
try:
connection = await manager.connect_stream(
rtsp_url=url,
stream_id=stream_id,
)
connections.append((connection, stream_id))
logger.info(f"Connected to {stream_id}")
except Exception as e:
logger.error(f"Failed to connect to {stream_id}: {e}")
# Process each stream with async generator
async def process_stream(connection, stream_name):
"""Process results from a single stream"""
frame_count = 0
person_detections = 0
async for result in connection.tracking_results():
frame_count += 1
# Count person detections (class_id 0 in COCO)
for obj in result.tracked_objects:
if obj.class_id == 0:
person_detections += 1
# Log every 10th frame
if frame_count % 10 == 0:
logger.info(
f"[{stream_name}] Processed {frame_count} frames, "
f"{person_detections} person detections"
)
# Stop after 100 frames
if frame_count >= 100:
break
# Run all streams concurrently
tasks = [
asyncio.create_task(process_stream(conn, name))
for conn, name in connections
]
# Wait for all tasks to complete
await asyncio.gather(*tasks)
# Get final statistics
stats = manager.get_stats()
logger.info("\n=== Final Statistics ===")
logger.info(f"Total connections: {stats['manager']['num_connections']}")
logger.info(f"Frames processed: {stats['model_controller']['total_frames_processed']}")
logger.info(f"Batches processed: {stats['model_controller']['total_batches_processed']}")
logger.info(f"Avg batch size: {stats['model_controller']['avg_batch_size']:.2f}")
# Cleanup
await manager.shutdown()
logger.info("Example 2 complete\n")
# Example 3: Queue-based pattern
async def example_queue_pattern():
"""Demonstrates direct queue access for custom processing"""
logger.info("=== Example 3: Queue-Based Pattern ===")
# Load environment
load_dotenv()
camera_url = os.getenv('CAMERA_URL_1')
if not camera_url:
logger.error("CAMERA_URL_1 not found in .env file")
return
# Create manager
manager = StreamConnectionManager(gpu_id=0, batch_size=16)
# Initialize
model_path = "models/yolov8n.trt"
if not os.path.exists(model_path):
logger.error(f"Model file not found: {model_path}")
return
await manager.initialize(
model_path=model_path,
preprocess_fn=YOLOv8Utils.preprocess,
postprocess_fn=YOLOv8Utils.postprocess,
)
# Connect to stream (no callback)
connection = await manager.connect_stream(
rtsp_url=camera_url,
stream_id="main_camera",
)
# Use the built-in queue directly
result_queue = connection.result_queue
# Process results from queue
processed_count = 0
while processed_count < 50: # Process 50 frames
try:
result = await asyncio.wait_for(result_queue.get(), timeout=5.0)
processed_count += 1
# Custom processing
has_person = any(obj.class_id == 0 for obj in result.tracked_objects)
has_car = any(obj.class_id == 2 for obj in result.tracked_objects)
if has_person or has_car:
logger.info(
f"Frame {processed_count}: "
f"Person={'Yes' if has_person else 'No'}, "
f"Car={'Yes' if has_car else 'No'}"
)
except asyncio.TimeoutError:
logger.warning("Timeout waiting for result")
break
# Cleanup
await manager.shutdown()
logger.info("Example 3 complete\n")
# Example 4: Performance monitoring
async def example_performance_monitoring():
"""Demonstrates real-time performance monitoring"""
logger.info("=== Example 4: Performance Monitoring ===")
# Load environment
load_dotenv()
camera_url = os.getenv('CAMERA_URL_1')
if not camera_url:
logger.error("CAMERA_URL_1 not found in .env file")
return
# Create manager
manager = StreamConnectionManager(
gpu_id=0,
batch_size=16,
force_timeout=0.05,
)
# Initialize
model_path = "models/yolov8n.trt"
if not os.path.exists(model_path):
logger.error(f"Model file not found: {model_path}")
return
await manager.initialize(
model_path=model_path,
preprocess_fn=YOLOv8Utils.preprocess,
postprocess_fn=YOLOv8Utils.postprocess,
)
# Track performance metrics
frame_times = []
last_frame_time = None
def on_tracking_result(result):
nonlocal last_frame_time
current_time = time.time()
if last_frame_time is not None:
frame_interval = current_time - last_frame_time
frame_times.append(frame_interval)
last_frame_time = current_time
# Connect
connection = await manager.connect_stream(
rtsp_url=camera_url,
on_tracking_result=on_tracking_result,
)
# Monitor stats periodically
for i in range(6): # Monitor for 60 seconds
await asyncio.sleep(10)
stats = manager.get_stats()
model_stats = stats['model_controller']
conn_stats = stats['connections'].get('stream_0', {})
logger.info(f"\n=== Stats Update {i+1} ===")
logger.info(f"Buffer A: {model_stats['buffer_a_size']} ({model_stats['buffer_a_state']})")
logger.info(f"Buffer B: {model_stats['buffer_b_size']} ({model_stats['buffer_b_state']})")
logger.info(f"Active buffer: {model_stats['active_buffer']}")
logger.info(f"Total frames: {model_stats['total_frames_processed']}")
logger.info(f"Total batches: {model_stats['total_batches_processed']}")
logger.info(f"Avg batch size: {model_stats['avg_batch_size']:.2f}")
logger.info(f"Decoder frames: {conn_stats.get('frame_count', 0)}")
if frame_times:
avg_fps = 1.0 / (sum(frame_times) / len(frame_times))
logger.info(f"Processing FPS: {avg_fps:.2f}")
# Cleanup
await manager.shutdown()
logger.info("Example 4 complete\n")
async def main():
"""Run all examples"""
logger.info("Starting event-driven stream processing tests\n")
# Choose which example to run
choice = os.getenv('EXAMPLE', '1')
if choice == '1':
await example_callback_pattern()
elif choice == '2':
await example_async_generator_pattern()
elif choice == '3':
await example_queue_pattern()
elif choice == '4':
await example_performance_monitoring()
elif choice == 'all':
await example_callback_pattern()
await asyncio.sleep(2)
await example_async_generator_pattern()
await asyncio.sleep(2)
await example_queue_pattern()
await asyncio.sleep(2)
await example_performance_monitoring()
else:
logger.error(f"Invalid choice: {choice}")
logger.info("Set EXAMPLE env var to 1, 2, 3, 4, or 'all'")
logger.info("All tests complete!")
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("\nInterrupted by user")
except Exception as e:
logger.error(f"Error: {e}", exc_info=True)

test_event_driven_quick.py (new executable file, 117 lines)

@@ -0,0 +1,117 @@
#!/usr/bin/env python3
"""
Quick test for event-driven stream processing - runs for 20 seconds.
"""
import asyncio
import os
import logging
from dotenv import load_dotenv
from services import StreamConnectionManager, YOLOv8Utils, COCO_CLASSES
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
async def main():
"""Quick test with callback pattern"""
logger.info("=== Quick Event-Driven Test (20 seconds) ===")
# Load environment variables
load_dotenv()
camera_url = os.getenv('CAMERA_URL_1')
if not camera_url:
logger.error("CAMERA_URL_1 not found in .env file")
return
# Create manager
manager = StreamConnectionManager(
gpu_id=0,
batch_size=16,
force_timeout=0.05, # 50ms
poll_interval=0.01, # 100 FPS
)
# Initialize with YOLOv8 model
model_path = "models/yolov8n.trt"
logger.info(f"Initializing with model: {model_path}")
await manager.initialize(
model_path=model_path,
model_id="yolo",
preprocess_fn=YOLOv8Utils.preprocess,
postprocess_fn=YOLOv8Utils.postprocess,
)
result_count = 0
# Define callback for tracking results
def on_tracking_result(result):
nonlocal result_count
result_count += 1
if result_count % 5 == 0: # Log every 5th result
logger.info(f"[{result.stream_id}] Frame {result.metadata.get('frame_number', 0)}")
logger.info(f" Tracked objects: {len(result.tracked_objects)}")
for obj in result.tracked_objects[:3]: # Show first 3
class_name = COCO_CLASSES.get(obj.class_id, f"Class {obj.class_id}")
logger.info(
f" Track ID {obj.track_id}: {class_name}, "
f"conf={obj.confidence:.2f}"
)
def on_error(error):
logger.error(f"Stream error: {error}")
# Connect to stream
logger.info(f"Connecting to stream...")
connection = await manager.connect_stream(
rtsp_url=camera_url,
stream_id="test_camera",
on_tracking_result=on_tracking_result,
on_error=on_error,
)
# Monitor for 20 seconds with stats updates
for i in range(4): # 4 x 5 seconds = 20 seconds
await asyncio.sleep(5)
stats = manager.get_stats()
model_stats = stats['model_controller']
logger.info(f"\n=== Stats Update {i+1}/4 ===")
logger.info(f"Results received: {result_count}")
logger.info(f"Buffer A: {model_stats['buffer_a_size']} ({model_stats['buffer_a_state']})")
logger.info(f"Buffer B: {model_stats['buffer_b_size']} ({model_stats['buffer_b_state']})")
logger.info(f"Active buffer: {model_stats['active_buffer']}")
logger.info(f"Total frames processed: {model_stats['total_frames_processed']}")
logger.info(f"Total batches: {model_stats['total_batches_processed']}")
logger.info(f"Avg batch size: {model_stats['avg_batch_size']:.2f}")
# Final statistics
stats = manager.get_stats()
logger.info("\n=== Final Statistics ===")
logger.info(f"Total results received: {result_count}")
logger.info(f"Manager: {stats['manager']}")
logger.info(f"Model Controller: {stats['model_controller']}")
logger.info(f"Connection: {stats['connections']['test_camera']}")
# Cleanup
logger.info("\nShutting down...")
await manager.shutdown()
logger.info("Test complete!")
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("\nInterrupted by user")
except Exception as e:
logger.error(f"Error: {e}", exc_info=True)


@@ -509,7 +509,7 @@ def main_multi_window():
if __name__ == "__main__":
    # Run single camera visualization
    # main()
    # Uncomment to run multi-window visualization
    main_multi_window()