batch processing/event driven
This commit is contained in:
parent
e71316ef3d
commit
dd57b5a246
7 changed files with 2673 additions and 2 deletions
1108
EVENT_DRIVEN_DESIGN.md
Normal file
1108
EVENT_DRIVEN_DESIGN.md
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -8,6 +8,8 @@ from .model_repository import TensorRTModelRepository, ModelMetadata, ExecutionC
|
||||||
from .tracking_controller import TrackingController, TrackedObject
|
from .tracking_controller import TrackingController, TrackedObject
|
||||||
from .tracking_factory import TrackingFactory
|
from .tracking_factory import TrackingFactory
|
||||||
from .yolo import YOLOv8Utils, COCO_CLASSES
|
from .yolo import YOLOv8Utils, COCO_CLASSES
|
||||||
|
from .model_controller import ModelController, BatchFrame, BufferState
|
||||||
|
from .stream_connection_manager import StreamConnectionManager, StreamConnection, TrackingResult
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'StreamDecoderFactory',
|
'StreamDecoderFactory',
|
||||||
|
|
@ -24,4 +26,10 @@ __all__ = [
|
||||||
'TrackingFactory',
|
'TrackingFactory',
|
||||||
'YOLOv8Utils',
|
'YOLOv8Utils',
|
||||||
'COCO_CLASSES',
|
'COCO_CLASSES',
|
||||||
|
'ModelController',
|
||||||
|
'BatchFrame',
|
||||||
|
'BufferState',
|
||||||
|
'StreamConnectionManager',
|
||||||
|
'StreamConnection',
|
||||||
|
'TrackingResult',
|
||||||
]
|
]
|
||||||
|
|
|
||||||
499
services/model_controller.py
Normal file
499
services/model_controller.py
Normal file
|
|
@ -0,0 +1,499 @@
|
||||||
|
"""
|
||||||
|
ModelController - Async batching layer with ping-pong buffers for inference.
|
||||||
|
|
||||||
|
This module provides batched inference coordination using ping-pong circular buffers
|
||||||
|
with force-switch timeout mechanism.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import torch
|
||||||
|
from typing import Dict, List, Optional, Callable, Any
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BatchFrame:
|
||||||
|
"""Represents a frame in the batch buffer"""
|
||||||
|
stream_id: str
|
||||||
|
frame: torch.Tensor # GPU tensor (3, H, W)
|
||||||
|
timestamp: float
|
||||||
|
metadata: Dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class BufferState(Enum):
|
||||||
|
"""State of a ping-pong buffer"""
|
||||||
|
IDLE = "idle"
|
||||||
|
FILLING = "filling"
|
||||||
|
PROCESSING = "processing"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelController:
|
||||||
|
"""
|
||||||
|
Manages batched inference with ping-pong buffers and force-switch timeout.
|
||||||
|
|
||||||
|
This controller accumulates frames from multiple streams into batches,
|
||||||
|
processes them through a model repository, and routes results back to
|
||||||
|
stream-specific callbacks.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Ping-pong circular buffers (BufferA/BufferB)
|
||||||
|
- Force-switch timeout to prevent batch starvation
|
||||||
|
- Async event-driven processing
|
||||||
|
- Thread-safe frame submission
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_repository: TensorRT model repository for inference
|
||||||
|
model_id: Model identifier in the repository
|
||||||
|
batch_size: Maximum frames per batch (default: 16)
|
||||||
|
force_timeout: Max wait time before forcing buffer switch in seconds (default: 0.05)
|
||||||
|
preprocess_fn: Optional preprocessing function for frames
|
||||||
|
postprocess_fn: Optional postprocessing function for model outputs
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_repository,
|
||||||
|
model_id: str,
|
||||||
|
batch_size: int = 16,
|
||||||
|
force_timeout: float = 0.05,
|
||||||
|
preprocess_fn: Optional[Callable] = None,
|
||||||
|
postprocess_fn: Optional[Callable] = None,
|
||||||
|
):
|
||||||
|
self.model_repository = model_repository
|
||||||
|
self.model_id = model_id
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.force_timeout = force_timeout
|
||||||
|
self.preprocess_fn = preprocess_fn
|
||||||
|
self.postprocess_fn = postprocess_fn
|
||||||
|
|
||||||
|
# Detect model's actual batch size from input shape
|
||||||
|
self.model_batch_size = self._detect_model_batch_size()
|
||||||
|
if self.model_batch_size == 1:
|
||||||
|
logger.warning(
|
||||||
|
f"Model '{model_id}' has fixed batch_size=1. "
|
||||||
|
f"Will process frames sequentially. Consider rebuilding model with dynamic batching."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(f"Model '{model_id}' supports batch_size={self.model_batch_size}")
|
||||||
|
|
||||||
|
# Ping-pong buffers
|
||||||
|
self.buffer_a: List[BatchFrame] = []
|
||||||
|
self.buffer_b: List[BatchFrame] = []
|
||||||
|
|
||||||
|
# Buffer states
|
||||||
|
self.active_buffer = "A" # Which buffer is currently active (filling)
|
||||||
|
self.buffer_a_state = BufferState.IDLE
|
||||||
|
self.buffer_b_state = BufferState.IDLE
|
||||||
|
|
||||||
|
# Async coordination
|
||||||
|
self.buffer_lock = asyncio.Lock()
|
||||||
|
self.last_submit_time = time.time()
|
||||||
|
|
||||||
|
# Tasks
|
||||||
|
self.timeout_task: Optional[asyncio.Task] = None
|
||||||
|
self.processor_task: Optional[asyncio.Task] = None
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
# Result callbacks (stream_id -> callback)
|
||||||
|
self.result_callbacks: Dict[str, Callable] = {}
|
||||||
|
|
||||||
|
# Statistics
|
||||||
|
self.total_frames_processed = 0
|
||||||
|
self.total_batches_processed = 0
|
||||||
|
|
||||||
|
def _detect_model_batch_size(self) -> int:
|
||||||
|
"""
|
||||||
|
Detect the model's batch size from its input shape.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Maximum batch size supported by the model (1 for fixed batch size models)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
metadata = self.model_repository.get_metadata(self.model_id)
|
||||||
|
# Get first input tensor shape
|
||||||
|
first_input = list(metadata.inputs.values())[0]
|
||||||
|
batch_dim = first_input["shape"][0]
|
||||||
|
|
||||||
|
# batch_dim can be -1 (dynamic), 1 (fixed), or N (fixed batch size)
|
||||||
|
if batch_dim == -1:
|
||||||
|
# Dynamic batch size - use user-specified batch_size
|
||||||
|
return self.batch_size
|
||||||
|
else:
|
||||||
|
# Fixed batch size
|
||||||
|
return batch_dim
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start the controller background tasks"""
|
||||||
|
if self.running:
|
||||||
|
logger.warning("ModelController already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.running = True
|
||||||
|
self.timeout_task = asyncio.create_task(self._timeout_monitor())
|
||||||
|
self.processor_task = asyncio.create_task(self._batch_processor())
|
||||||
|
logger.info("ModelController started")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stop the controller and cleanup"""
|
||||||
|
if not self.running:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Stopping ModelController...")
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
# Cancel tasks
|
||||||
|
if self.timeout_task:
|
||||||
|
self.timeout_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.timeout_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if self.processor_task:
|
||||||
|
self.processor_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.processor_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Process any remaining frames
|
||||||
|
await self._process_remaining_buffers()
|
||||||
|
logger.info("ModelController stopped")
|
||||||
|
|
||||||
|
def register_callback(self, stream_id: str, callback: Callable):
|
||||||
|
"""
|
||||||
|
Register a callback for inference results from a stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: Unique stream identifier
|
||||||
|
callback: Callback function to receive results (can be sync or async)
|
||||||
|
"""
|
||||||
|
self.result_callbacks[stream_id] = callback
|
||||||
|
logger.debug(f"Registered callback for stream: {stream_id}")
|
||||||
|
|
||||||
|
def unregister_callback(self, stream_id: str):
|
||||||
|
"""
|
||||||
|
Unregister a stream callback.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: Stream identifier to unregister
|
||||||
|
"""
|
||||||
|
self.result_callbacks.pop(stream_id, None)
|
||||||
|
logger.debug(f"Unregistered callback for stream: {stream_id}")
|
||||||
|
|
||||||
|
async def submit_frame(
|
||||||
|
self,
|
||||||
|
stream_id: str,
|
||||||
|
frame: torch.Tensor,
|
||||||
|
metadata: Optional[Dict] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Submit a frame for batched inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: Unique stream identifier
|
||||||
|
frame: GPU tensor (3, H, W) or (C, H, W)
|
||||||
|
metadata: Optional metadata to attach to the frame
|
||||||
|
"""
|
||||||
|
async with self.buffer_lock:
|
||||||
|
batch_frame = BatchFrame(
|
||||||
|
stream_id=stream_id,
|
||||||
|
frame=frame,
|
||||||
|
timestamp=time.time(),
|
||||||
|
metadata=metadata or {}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to active buffer
|
||||||
|
if self.active_buffer == "A":
|
||||||
|
self.buffer_a.append(batch_frame)
|
||||||
|
self.buffer_a_state = BufferState.FILLING
|
||||||
|
buffer_size = len(self.buffer_a)
|
||||||
|
else:
|
||||||
|
self.buffer_b.append(batch_frame)
|
||||||
|
self.buffer_b_state = BufferState.FILLING
|
||||||
|
buffer_size = len(self.buffer_b)
|
||||||
|
|
||||||
|
self.last_submit_time = time.time()
|
||||||
|
|
||||||
|
# Check if we should immediately swap (batch full)
|
||||||
|
if buffer_size >= self.batch_size:
|
||||||
|
await self._try_swap_buffers()
|
||||||
|
|
||||||
|
async def _timeout_monitor(self):
|
||||||
|
"""Monitor force-switch timeout"""
|
||||||
|
while self.running:
|
||||||
|
await asyncio.sleep(0.01) # Check every 10ms
|
||||||
|
|
||||||
|
async with self.buffer_lock:
|
||||||
|
time_since_submit = time.time() - self.last_submit_time
|
||||||
|
|
||||||
|
# Check if timeout expired and we have frames waiting
|
||||||
|
if time_since_submit >= self.force_timeout:
|
||||||
|
active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
|
||||||
|
if len(active_buffer) > 0:
|
||||||
|
await self._try_swap_buffers()
|
||||||
|
|
||||||
|
async def _try_swap_buffers(self):
|
||||||
|
"""
|
||||||
|
Attempt to swap ping-pong buffers.
|
||||||
|
Only swaps if the inactive buffer is not currently processing.
|
||||||
|
|
||||||
|
This method should be called with buffer_lock held.
|
||||||
|
"""
|
||||||
|
# Check if inactive buffer is available
|
||||||
|
inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
|
||||||
|
|
||||||
|
if inactive_state != BufferState.PROCESSING:
|
||||||
|
# Swap active buffer
|
||||||
|
old_active = self.active_buffer
|
||||||
|
self.active_buffer = "B" if old_active == "A" else "A"
|
||||||
|
|
||||||
|
# Mark old active buffer as ready for processing
|
||||||
|
if old_active == "A":
|
||||||
|
self.buffer_a_state = BufferState.PROCESSING
|
||||||
|
buffer_size = len(self.buffer_a)
|
||||||
|
else:
|
||||||
|
self.buffer_b_state = BufferState.PROCESSING
|
||||||
|
buffer_size = len(self.buffer_b)
|
||||||
|
|
||||||
|
logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")
|
||||||
|
|
||||||
|
async def _batch_processor(self):
|
||||||
|
"""Background task that processes batches when available"""
|
||||||
|
while self.running:
|
||||||
|
await asyncio.sleep(0.001) # Check every 1ms
|
||||||
|
|
||||||
|
# Check if buffer A needs processing
|
||||||
|
if self.buffer_a_state == BufferState.PROCESSING:
|
||||||
|
await self._process_buffer("A")
|
||||||
|
|
||||||
|
# Check if buffer B needs processing
|
||||||
|
if self.buffer_b_state == BufferState.PROCESSING:
|
||||||
|
await self._process_buffer("B")
|
||||||
|
|
||||||
|
async def _process_buffer(self, buffer_name: str):
|
||||||
|
"""
|
||||||
|
Process a buffer through inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
buffer_name: "A" or "B"
|
||||||
|
"""
|
||||||
|
# Extract buffer to process
|
||||||
|
async with self.buffer_lock:
|
||||||
|
if buffer_name == "A":
|
||||||
|
batch = self.buffer_a.copy()
|
||||||
|
self.buffer_a.clear()
|
||||||
|
else:
|
||||||
|
batch = self.buffer_b.copy()
|
||||||
|
self.buffer_b.clear()
|
||||||
|
|
||||||
|
if len(batch) == 0:
|
||||||
|
# Mark as idle and return
|
||||||
|
async with self.buffer_lock:
|
||||||
|
if buffer_name == "A":
|
||||||
|
self.buffer_a_state = BufferState.IDLE
|
||||||
|
else:
|
||||||
|
self.buffer_b_state = BufferState.IDLE
|
||||||
|
return
|
||||||
|
|
||||||
|
# Process batch (outside lock to allow concurrent submissions)
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
results = await self._run_batch_inference(batch)
|
||||||
|
inference_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
self.total_frames_processed += len(batch)
|
||||||
|
self.total_batches_processed += 1
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Processed batch of {len(batch)} frames in {inference_time*1000:.2f}ms "
|
||||||
|
f"({inference_time*1000/len(batch):.2f}ms per frame)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit results to callbacks
|
||||||
|
for batch_frame, result in zip(batch, results):
|
||||||
|
callback = self.result_callbacks.get(batch_frame.stream_id)
|
||||||
|
if callback:
|
||||||
|
# Schedule callback asynchronously
|
||||||
|
if asyncio.iscoroutinefunction(callback):
|
||||||
|
asyncio.create_task(callback(result))
|
||||||
|
else:
|
||||||
|
# Run sync callback in executor to avoid blocking
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.call_soon(lambda cb=callback, r=result: cb(r))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing batch: {e}", exc_info=True)
|
||||||
|
# TODO: Emit error events to streams
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Mark buffer as idle
|
||||||
|
async with self.buffer_lock:
|
||||||
|
if buffer_name == "A":
|
||||||
|
self.buffer_a_state = BufferState.IDLE
|
||||||
|
else:
|
||||||
|
self.buffer_b_state = BufferState.IDLE
|
||||||
|
|
||||||
|
async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Run inference on a batch of frames.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: List of BatchFrame objects
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of detection results (one per frame)
|
||||||
|
"""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
# Check if model supports batching
|
||||||
|
if self.model_batch_size == 1:
|
||||||
|
# Process frames one at a time for batch_size=1 models
|
||||||
|
return await self._run_sequential_inference(batch, loop)
|
||||||
|
else:
|
||||||
|
# Use true batching for models that support it
|
||||||
|
return await self._run_batched_inference(batch, loop)
|
||||||
|
|
||||||
|
async def _run_sequential_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
|
||||||
|
"""Run inference sequentially for batch_size=1 models"""
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for batch_frame in batch:
|
||||||
|
# Preprocess frame
|
||||||
|
if self.preprocess_fn:
|
||||||
|
processed = self.preprocess_fn(batch_frame.frame)
|
||||||
|
else:
|
||||||
|
# Ensure we have batch dimension
|
||||||
|
processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame
|
||||||
|
|
||||||
|
# Run inference for this frame
|
||||||
|
outputs = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda p=processed: self.model_repository.infer(
|
||||||
|
self.model_id,
|
||||||
|
{"images": p},
|
||||||
|
synchronize=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Postprocess
|
||||||
|
if self.postprocess_fn:
|
||||||
|
try:
|
||||||
|
detections = self.postprocess_fn(outputs)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
|
||||||
|
# Return empty detections on error
|
||||||
|
detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
|
||||||
|
else:
|
||||||
|
detections = outputs
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"stream_id": batch_frame.stream_id,
|
||||||
|
"timestamp": batch_frame.timestamp,
|
||||||
|
"detections": detections,
|
||||||
|
"metadata": batch_frame.metadata,
|
||||||
|
}
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def _run_batched_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
|
||||||
|
"""Run true batched inference for models that support it"""
|
||||||
|
# Preprocess frames (on GPU)
|
||||||
|
preprocessed = []
|
||||||
|
for batch_frame in batch:
|
||||||
|
if self.preprocess_fn:
|
||||||
|
processed = self.preprocess_fn(batch_frame.frame)
|
||||||
|
# Preprocess may return (1, C, H, W), squeeze to (C, H, W)
|
||||||
|
if processed.dim() == 4 and processed.shape[0] == 1:
|
||||||
|
processed = processed.squeeze(0)
|
||||||
|
else:
|
||||||
|
processed = batch_frame.frame
|
||||||
|
preprocessed.append(processed)
|
||||||
|
|
||||||
|
# Stack into batch tensor: (N, C, H, W)
|
||||||
|
batch_tensor = torch.stack(preprocessed, dim=0)
|
||||||
|
|
||||||
|
# Limit batch size to model's max batch size
|
||||||
|
if batch_tensor.shape[0] > self.model_batch_size:
|
||||||
|
logger.warning(
|
||||||
|
f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}, "
|
||||||
|
f"will split into sub-batches"
|
||||||
|
)
|
||||||
|
# TODO: Handle splitting into sub-batches
|
||||||
|
batch_tensor = batch_tensor[:self.model_batch_size]
|
||||||
|
batch = batch[:self.model_batch_size]
|
||||||
|
|
||||||
|
# Run inference
|
||||||
|
outputs = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: self.model_repository.infer(
|
||||||
|
self.model_id,
|
||||||
|
{"images": batch_tensor},
|
||||||
|
synchronize=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Postprocess results (split batch back to individual results)
|
||||||
|
results = []
|
||||||
|
for i, batch_frame in enumerate(batch):
|
||||||
|
# Extract single frame output from batch
|
||||||
|
frame_output = {}
|
||||||
|
for k, v in outputs.items():
|
||||||
|
# v has shape (N, ...), extract index i and keep batch dimension
|
||||||
|
frame_output[k] = v[i:i+1] # Shape: (1, ...)
|
||||||
|
|
||||||
|
if self.postprocess_fn:
|
||||||
|
try:
|
||||||
|
detections = self.postprocess_fn(frame_output)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
|
||||||
|
# Return empty detections on error
|
||||||
|
detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
|
||||||
|
else:
|
||||||
|
detections = frame_output
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"stream_id": batch_frame.stream_id,
|
||||||
|
"timestamp": batch_frame.timestamp,
|
||||||
|
"detections": detections,
|
||||||
|
"metadata": batch_frame.metadata,
|
||||||
|
}
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def _process_remaining_buffers(self):
|
||||||
|
"""Process any remaining frames in buffers during shutdown"""
|
||||||
|
if len(self.buffer_a) > 0:
|
||||||
|
logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
|
||||||
|
await self._process_buffer("A")
|
||||||
|
if len(self.buffer_b) > 0:
|
||||||
|
logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
|
||||||
|
await self._process_buffer("B")
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get current buffer statistics"""
|
||||||
|
return {
|
||||||
|
"active_buffer": self.active_buffer,
|
||||||
|
"buffer_a_size": len(self.buffer_a),
|
||||||
|
"buffer_b_size": len(self.buffer_b),
|
||||||
|
"buffer_a_state": self.buffer_a_state.value,
|
||||||
|
"buffer_b_state": self.buffer_b_state.value,
|
||||||
|
"registered_streams": len(self.result_callbacks),
|
||||||
|
"total_frames_processed": self.total_frames_processed,
|
||||||
|
"total_batches_processed": self.total_batches_processed,
|
||||||
|
"avg_batch_size": (
|
||||||
|
self.total_frames_processed / self.total_batches_processed
|
||||||
|
if self.total_batches_processed > 0 else 0
|
||||||
|
),
|
||||||
|
}
|
||||||
566
services/stream_connection_manager.py
Normal file
566
services/stream_connection_manager.py
Normal file
|
|
@ -0,0 +1,566 @@
|
||||||
|
"""
|
||||||
|
StreamConnectionManager - Async orchestration for stream processing with batched inference.
|
||||||
|
|
||||||
|
This module provides high-level connection management for multiple RTSP streams,
|
||||||
|
coordinating decoders, batched inference, and tracking with an event-driven API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Dict, Optional, Callable, AsyncIterator, Tuple, Any, List
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .model_controller import ModelController
|
||||||
|
from .stream_decoder import StreamDecoderFactory
|
||||||
|
from .tracking_factory import TrackingFactory
|
||||||
|
from .model_repository import TensorRTModelRepository
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionStatus(Enum):
|
||||||
|
"""Status of a stream connection"""
|
||||||
|
CONNECTING = "connecting"
|
||||||
|
CONNECTED = "connected"
|
||||||
|
DISCONNECTED = "disconnected"
|
||||||
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrackingResult:
|
||||||
|
"""Result emitted to user callbacks"""
|
||||||
|
stream_id: str
|
||||||
|
timestamp: float
|
||||||
|
tracked_objects: List # List of TrackedObject from TrackingController
|
||||||
|
detections: List # Raw detections
|
||||||
|
frame_shape: Tuple[int, int, int]
|
||||||
|
metadata: Dict
|
||||||
|
|
||||||
|
|
||||||
|
class StreamConnection:
|
||||||
|
"""
|
||||||
|
Represents a single stream connection with event emission.
|
||||||
|
|
||||||
|
This class wraps a StreamDecoder, polls frames asynchronously, submits them
|
||||||
|
to the ModelController for batched inference, runs tracking, and emits results
|
||||||
|
via queues or callbacks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: Unique identifier for this stream
|
||||||
|
decoder: StreamDecoder instance
|
||||||
|
model_controller: ModelController for batched inference
|
||||||
|
tracking_controller: TrackingController for object tracking
|
||||||
|
poll_interval: Frame polling interval in seconds (default: 0.01)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stream_id: str,
|
||||||
|
decoder,
|
||||||
|
model_controller: ModelController,
|
||||||
|
tracking_controller,
|
||||||
|
poll_interval: float = 0.01,
|
||||||
|
):
|
||||||
|
self.stream_id = stream_id
|
||||||
|
self.decoder = decoder
|
||||||
|
self.model_controller = model_controller
|
||||||
|
self.tracking_controller = tracking_controller
|
||||||
|
self.poll_interval = poll_interval
|
||||||
|
|
||||||
|
self.status = ConnectionStatus.CONNECTING
|
||||||
|
self.frame_count = 0
|
||||||
|
self.last_frame_time = 0.0
|
||||||
|
|
||||||
|
# Event emission
|
||||||
|
self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue()
|
||||||
|
self.error_queue: asyncio.Queue[Exception] = asyncio.Queue()
|
||||||
|
|
||||||
|
# Tasks
|
||||||
|
self.poller_task: Optional[asyncio.Task] = None
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start the connection (decoder and frame polling)"""
|
||||||
|
# Start decoder (runs in background thread)
|
||||||
|
self.decoder.start()
|
||||||
|
|
||||||
|
# Wait for initial connection (try for up to 10 seconds)
|
||||||
|
max_wait = 10.0
|
||||||
|
wait_interval = 0.5
|
||||||
|
elapsed = 0.0
|
||||||
|
|
||||||
|
while elapsed < max_wait:
|
||||||
|
await asyncio.sleep(wait_interval)
|
||||||
|
elapsed += wait_interval
|
||||||
|
|
||||||
|
if self.decoder.is_connected():
|
||||||
|
self.status = ConnectionStatus.CONNECTED
|
||||||
|
logger.info(f"Stream {self.stream_id} connected after {elapsed:.1f}s")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Timeout - but don't fail hard, let it try to connect in background
|
||||||
|
logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...")
|
||||||
|
self.status = ConnectionStatus.CONNECTING
|
||||||
|
|
||||||
|
# Start frame polling task
|
||||||
|
self.running = True
|
||||||
|
self.poller_task = asyncio.create_task(self._frame_poller())
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stop the connection and cleanup"""
|
||||||
|
logger.info(f"Stopping stream {self.stream_id}...")
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
if self.poller_task:
|
||||||
|
self.poller_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.poller_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Stop decoder
|
||||||
|
self.decoder.stop()
|
||||||
|
|
||||||
|
# Unregister from model controller
|
||||||
|
self.model_controller.unregister_callback(self.stream_id)
|
||||||
|
|
||||||
|
self.status = ConnectionStatus.DISCONNECTED
|
||||||
|
logger.info(f"Stream {self.stream_id} stopped")
|
||||||
|
|
||||||
|
async def _frame_poller(self):
|
||||||
|
"""Poll frames from threaded decoder and submit to model controller"""
|
||||||
|
last_frame_ptr = None
|
||||||
|
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
# Poll frame from decoder (runs in thread)
|
||||||
|
frame = self.decoder.get_latest_frame(rgb=True)
|
||||||
|
|
||||||
|
# Check if we got a new frame (avoid reprocessing same frame)
|
||||||
|
if frame is not None and frame.data_ptr() != last_frame_ptr:
|
||||||
|
last_frame_ptr = frame.data_ptr()
|
||||||
|
self.last_frame_time = time.time()
|
||||||
|
self.frame_count += 1
|
||||||
|
|
||||||
|
# Submit to model controller for batched inference
|
||||||
|
await self.model_controller.submit_frame(
|
||||||
|
stream_id=self.stream_id,
|
||||||
|
frame=frame,
|
||||||
|
metadata={
|
||||||
|
"frame_number": self.frame_count,
|
||||||
|
"shape": tuple(frame.shape),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check decoder status
|
||||||
|
if not self.decoder.is_connected():
|
||||||
|
if self.status == ConnectionStatus.CONNECTED:
|
||||||
|
logger.warning(f"Stream {self.stream_id} disconnected")
|
||||||
|
self.status = ConnectionStatus.DISCONNECTED
|
||||||
|
# Decoder will auto-reconnect, just update status
|
||||||
|
await asyncio.sleep(1.0)
|
||||||
|
if self.decoder.is_connected():
|
||||||
|
logger.info(f"Stream {self.stream_id} reconnected")
|
||||||
|
self.status = ConnectionStatus.CONNECTED
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in frame poller for {self.stream_id}: {e}", exc_info=True)
|
||||||
|
await self.error_queue.put(e)
|
||||||
|
self.status = ConnectionStatus.ERROR
|
||||||
|
|
||||||
|
# Sleep until next poll
|
||||||
|
await asyncio.sleep(self.poll_interval)
|
||||||
|
|
||||||
|
async def _handle_inference_result(self, result: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Callback invoked by ModelController when inference is done.
|
||||||
|
Runs tracking and emits final result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Inference result dictionary
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Extract detections
|
||||||
|
detections = result["detections"]
|
||||||
|
|
||||||
|
# Run tracking (this is sync, so run in executor)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
tracked_objects = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: self._run_tracking_sync(detections)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create tracking result
|
||||||
|
tracking_result = TrackingResult(
|
||||||
|
stream_id=self.stream_id,
|
||||||
|
timestamp=result["timestamp"],
|
||||||
|
tracked_objects=tracked_objects,
|
||||||
|
detections=detections,
|
||||||
|
frame_shape=result["metadata"].get("shape"),
|
||||||
|
metadata=result["metadata"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit to result queue
|
||||||
|
await self.result_queue.put(tracking_result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
|
||||||
|
await self.error_queue.put(e)
|
||||||
|
|
||||||
|
def _run_tracking_sync(self, detections):
|
||||||
|
"""
|
||||||
|
Run tracking synchronously (called from executor).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of TrackedObject instances
|
||||||
|
"""
|
||||||
|
# Use the TrackingController's internal tracking with detections
|
||||||
|
# We need to manually update tracks since we already have detections
|
||||||
|
import torch
|
||||||
|
|
||||||
|
with self.tracking_controller._lock:
|
||||||
|
self.tracking_controller._frame_count += 1
|
||||||
|
|
||||||
|
# If no detections, just cleanup and return current tracks
|
||||||
|
if len(detections) == 0:
|
||||||
|
self.tracking_controller._cleanup_stale_tracks()
|
||||||
|
return list(self.tracking_controller._tracks.values())
|
||||||
|
|
||||||
|
# Run IoU tracking to associate detections with existing tracks
|
||||||
|
associations = self.tracking_controller._iou_tracking(detections)
|
||||||
|
|
||||||
|
# Update or create tracks
|
||||||
|
for (det_idx, track_id), detection in zip(associations, detections):
|
||||||
|
bbox = detection[:4].cpu().tolist()
|
||||||
|
confidence = float(detection[4])
|
||||||
|
class_id = int(detection[5]) if detection.shape[0] > 5 else 0
|
||||||
|
|
||||||
|
if track_id == -1:
|
||||||
|
# Create new track
|
||||||
|
new_track = self.tracking_controller._create_track(
|
||||||
|
bbox, confidence, class_id, self.tracking_controller._frame_count
|
||||||
|
)
|
||||||
|
self.tracking_controller._tracks[new_track.track_id] = new_track
|
||||||
|
else:
|
||||||
|
# Update existing track
|
||||||
|
self.tracking_controller._tracks[track_id].update(
|
||||||
|
bbox, confidence, self.tracking_controller._frame_count
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cleanup stale tracks
|
||||||
|
self.tracking_controller._cleanup_stale_tracks()
|
||||||
|
|
||||||
|
return list(self.tracking_controller._tracks.values())
|
||||||
|
|
||||||
|
async def tracking_results(self) -> AsyncIterator[TrackingResult]:
|
||||||
|
"""
|
||||||
|
Async generator for tracking results.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
async for result in connection.tracking_results():
|
||||||
|
print(result.tracked_objects)
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
TrackingResult objects as they become available
|
||||||
|
"""
|
||||||
|
while self.running or not self.result_queue.empty():
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0)
|
||||||
|
yield result
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
async def errors(self) -> AsyncIterator[Exception]:
|
||||||
|
"""
|
||||||
|
Async generator for errors.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Exception objects as they occur
|
||||||
|
"""
|
||||||
|
while self.running or not self.error_queue.empty():
|
||||||
|
try:
|
||||||
|
error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0)
|
||||||
|
yield error
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get connection statistics"""
|
||||||
|
return {
|
||||||
|
"stream_id": self.stream_id,
|
||||||
|
"status": self.status.value,
|
||||||
|
"frame_count": self.frame_count,
|
||||||
|
"last_frame_time": self.last_frame_time,
|
||||||
|
"decoder_connected": self.decoder.is_connected(),
|
||||||
|
"decoder_buffer_size": self.decoder.get_buffer_size(),
|
||||||
|
"result_queue_size": self.result_queue.qsize(),
|
||||||
|
"error_queue_size": self.error_queue.qsize(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class StreamConnectionManager:
|
||||||
|
"""
|
||||||
|
High-level manager for stream connections with batched inference.
|
||||||
|
|
||||||
|
This manager coordinates multiple RTSP streams, batched model inference,
|
||||||
|
and object tracking through an async event-driven API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gpu_id: GPU device ID (default: 0)
|
||||||
|
batch_size: Maximum batch size for inference (default: 16)
|
||||||
|
force_timeout: Force buffer switch timeout in seconds (default: 0.05)
|
||||||
|
poll_interval: Frame polling interval in seconds (default: 0.01)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
manager = StreamConnectionManager(gpu_id=0, batch_size=16)
|
||||||
|
await manager.initialize(model_path="yolov8n.trt", ...)
|
||||||
|
connection = await manager.connect_stream(rtsp_url, on_tracking_result=callback)
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
await manager.shutdown()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
gpu_id: int = 0,
|
||||||
|
batch_size: int = 16,
|
||||||
|
force_timeout: float = 0.05,
|
||||||
|
poll_interval: float = 0.01,
|
||||||
|
):
|
||||||
|
self.gpu_id = gpu_id
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.force_timeout = force_timeout
|
||||||
|
self.poll_interval = poll_interval
|
||||||
|
|
||||||
|
# Factories
|
||||||
|
self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
|
||||||
|
self.tracking_factory = TrackingFactory(gpu_id=gpu_id)
|
||||||
|
self.model_repository = TensorRTModelRepository(gpu_id=gpu_id)
|
||||||
|
|
||||||
|
# Controllers
|
||||||
|
self.model_controller: Optional[ModelController] = None
|
||||||
|
self.tracking_controller = None
|
||||||
|
|
||||||
|
# Connections
|
||||||
|
self.connections: Dict[str, StreamConnection] = {}
|
||||||
|
|
||||||
|
# State
|
||||||
|
self.initialized = False
|
||||||
|
|
||||||
|
async def initialize(
|
||||||
|
self,
|
||||||
|
model_path: str,
|
||||||
|
model_id: str = "detector",
|
||||||
|
preprocess_fn: Optional[Callable] = None,
|
||||||
|
postprocess_fn: Optional[Callable] = None,
|
||||||
|
num_contexts: int = 4,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the manager with a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to TensorRT model file
|
||||||
|
model_id: Model identifier (default: "detector")
|
||||||
|
preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess)
|
||||||
|
postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess)
|
||||||
|
num_contexts: Number of TensorRT execution contexts (default: 4)
|
||||||
|
"""
|
||||||
|
logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: self.model_repository.load_model(model_id, model_path, num_contexts=num_contexts)
|
||||||
|
)
|
||||||
|
logger.info(f"Loaded model {model_id} from {model_path}")
|
||||||
|
|
||||||
|
# Create model controller
|
||||||
|
self.model_controller = ModelController(
|
||||||
|
model_repository=self.model_repository,
|
||||||
|
model_id=model_id,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
force_timeout=self.force_timeout,
|
||||||
|
preprocess_fn=preprocess_fn,
|
||||||
|
postprocess_fn=postprocess_fn,
|
||||||
|
)
|
||||||
|
await self.model_controller.start()
|
||||||
|
|
||||||
|
# Create tracking controller
|
||||||
|
self.tracking_controller = self.tracking_factory.create_controller(
|
||||||
|
model_repository=self.model_repository,
|
||||||
|
model_id=model_id,
|
||||||
|
tracker_type="iou",
|
||||||
|
)
|
||||||
|
logger.info("TrackingController created")
|
||||||
|
|
||||||
|
self.initialized = True
|
||||||
|
logger.info("StreamConnectionManager initialized successfully")
|
||||||
|
|
||||||
|
async def connect_stream(
|
||||||
|
self,
|
||||||
|
rtsp_url: str,
|
||||||
|
stream_id: Optional[str] = None,
|
||||||
|
on_tracking_result: Optional[Callable] = None,
|
||||||
|
on_error: Optional[Callable] = None,
|
||||||
|
buffer_size: int = 30,
|
||||||
|
) -> StreamConnection:
|
||||||
|
"""
|
||||||
|
Connect to a stream and start processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rtsp_url: RTSP stream URL
|
||||||
|
stream_id: Optional stream identifier (auto-generated if not provided)
|
||||||
|
on_tracking_result: Optional callback for tracking results (sync or async)
|
||||||
|
on_error: Optional callback for errors (sync or async)
|
||||||
|
buffer_size: Decoder buffer size (default: 30)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StreamConnection object for this stream
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If manager is not initialized
|
||||||
|
ConnectionError: If stream connection fails
|
||||||
|
"""
|
||||||
|
if not self.initialized:
|
||||||
|
raise RuntimeError("Manager not initialized. Call initialize() first.")
|
||||||
|
|
||||||
|
# Generate stream ID if not provided
|
||||||
|
if stream_id is None:
|
||||||
|
stream_id = f"stream_{len(self.connections)}"
|
||||||
|
|
||||||
|
logger.info(f"Connecting to stream {stream_id}: {rtsp_url}")
|
||||||
|
|
||||||
|
# Create decoder
|
||||||
|
decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size)
|
||||||
|
|
||||||
|
# Create connection
|
||||||
|
connection = StreamConnection(
|
||||||
|
stream_id=stream_id,
|
||||||
|
decoder=decoder,
|
||||||
|
model_controller=self.model_controller,
|
||||||
|
tracking_controller=self.tracking_controller,
|
||||||
|
poll_interval=self.poll_interval,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register callback with model controller
|
||||||
|
self.model_controller.register_callback(
|
||||||
|
stream_id,
|
||||||
|
connection._handle_inference_result
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start connection
|
||||||
|
await connection.start()
|
||||||
|
|
||||||
|
# Store connection
|
||||||
|
self.connections[stream_id] = connection
|
||||||
|
|
||||||
|
# Set up user callbacks if provided
|
||||||
|
if on_tracking_result:
|
||||||
|
asyncio.create_task(self._forward_results(connection, on_tracking_result))
|
||||||
|
|
||||||
|
if on_error:
|
||||||
|
asyncio.create_task(self._forward_errors(connection, on_error))
|
||||||
|
|
||||||
|
logger.info(f"Stream {stream_id} connected successfully")
|
||||||
|
return connection
|
||||||
|
|
||||||
|
async def disconnect_stream(self, stream_id: str):
|
||||||
|
"""
|
||||||
|
Disconnect and cleanup a stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: Stream identifier to disconnect
|
||||||
|
"""
|
||||||
|
connection = self.connections.get(stream_id)
|
||||||
|
if connection:
|
||||||
|
await connection.stop()
|
||||||
|
del self.connections[stream_id]
|
||||||
|
logger.info(f"Stream {stream_id} disconnected")
|
||||||
|
|
||||||
|
async def disconnect_all(self):
|
||||||
|
"""Disconnect all streams"""
|
||||||
|
logger.info("Disconnecting all streams...")
|
||||||
|
stream_ids = list(self.connections.keys())
|
||||||
|
for stream_id in stream_ids:
|
||||||
|
await self.disconnect_stream(stream_id)
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""Shutdown the manager and cleanup all resources"""
|
||||||
|
logger.info("Shutting down StreamConnectionManager...")
|
||||||
|
|
||||||
|
# Disconnect all streams
|
||||||
|
await self.disconnect_all()
|
||||||
|
|
||||||
|
# Stop model controller
|
||||||
|
if self.model_controller:
|
||||||
|
await self.model_controller.stop()
|
||||||
|
|
||||||
|
# Note: Model repository cleanup is sync and may cause segfaults
|
||||||
|
# Leaving cleanup to garbage collection for now
|
||||||
|
|
||||||
|
self.initialized = False
|
||||||
|
logger.info("StreamConnectionManager shutdown complete")
|
||||||
|
|
||||||
|
async def _forward_results(self, connection: StreamConnection, callback: Callable):
|
||||||
|
"""
|
||||||
|
Forward results from connection to user callback.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: StreamConnection to listen to
|
||||||
|
callback: User callback (sync or async)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async for result in connection.tracking_results():
|
||||||
|
if asyncio.iscoroutinefunction(callback):
|
||||||
|
await callback(result)
|
||||||
|
else:
|
||||||
|
callback(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def _forward_errors(self, connection: StreamConnection, callback: Callable):
|
||||||
|
"""
|
||||||
|
Forward errors from connection to user callback.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: StreamConnection to listen to
|
||||||
|
callback: User callback (sync or async)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async for error in connection.errors():
|
||||||
|
if asyncio.iscoroutinefunction(callback):
|
||||||
|
await callback(error)
|
||||||
|
else:
|
||||||
|
callback(error)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get statistics for all connections.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with manager and connection statistics
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"manager": {
|
||||||
|
"initialized": self.initialized,
|
||||||
|
"gpu_id": self.gpu_id,
|
||||||
|
"num_connections": len(self.connections),
|
||||||
|
"batch_size": self.batch_size,
|
||||||
|
"force_timeout": self.force_timeout,
|
||||||
|
"poll_interval": self.poll_interval,
|
||||||
|
},
|
||||||
|
"model_controller": self.model_controller.get_stats() if self.model_controller else {},
|
||||||
|
"connections": {
|
||||||
|
stream_id: conn.get_stats()
|
||||||
|
for stream_id, conn in self.connections.items()
|
||||||
|
},
|
||||||
|
}
|
||||||
373
test_event_driven.py
Normal file
373
test_event_driven.py
Normal file
|
|
@ -0,0 +1,373 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test script for event-driven stream processing with batched inference.
|
||||||
|
|
||||||
|
This demonstrates the new AsyncIO-based API for connecting to RTSP streams,
|
||||||
|
processing frames through batched inference, and receiving tracking results
|
||||||
|
via callbacks and async generators.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from services import StreamConnectionManager, YOLOv8Utils, COCO_CLASSES
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Example 1: Simple callback pattern
|
||||||
|
async def example_callback_pattern():
|
||||||
|
"""Demonstrates the simple callback pattern for a single stream"""
|
||||||
|
logger.info("=== Example 1: Callback Pattern ===")
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
camera_url = os.getenv('CAMERA_URL_1')
|
||||||
|
if not camera_url:
|
||||||
|
logger.error("CAMERA_URL_1 not found in .env file")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create manager
|
||||||
|
manager = StreamConnectionManager(
|
||||||
|
gpu_id=0,
|
||||||
|
batch_size=16,
|
||||||
|
force_timeout=0.05, # 50ms
|
||||||
|
poll_interval=0.01, # 100 FPS
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize with YOLOv8 model
|
||||||
|
model_path = "models/yolov8n.trt" # Adjust path as needed
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
logger.error(f"Model file not found: {model_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
await manager.initialize(
|
||||||
|
model_path=model_path,
|
||||||
|
model_id="yolo",
|
||||||
|
preprocess_fn=YOLOv8Utils.preprocess,
|
||||||
|
postprocess_fn=YOLOv8Utils.postprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Define callback for tracking results
|
||||||
|
def on_tracking_result(result):
|
||||||
|
logger.info(f"[{result.stream_id}] Frame {result.metadata.get('frame_number', 0)}")
|
||||||
|
logger.info(f" Timestamp: {result.timestamp:.3f}")
|
||||||
|
logger.info(f" Tracked objects: {len(result.tracked_objects)}")
|
||||||
|
|
||||||
|
for obj in result.tracked_objects[:5]: # Show first 5
|
||||||
|
class_name = COCO_CLASSES.get(obj.class_id, f"Class {obj.class_id}")
|
||||||
|
logger.info(
|
||||||
|
f" Track ID {obj.track_id}: {class_name}, "
|
||||||
|
f"conf={obj.confidence:.2f}, bbox={obj.bbox}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_error(error):
|
||||||
|
logger.error(f"Stream error: {error}")
|
||||||
|
|
||||||
|
# Connect to stream
|
||||||
|
connection = await manager.connect_stream(
|
||||||
|
rtsp_url=camera_url,
|
||||||
|
stream_id="camera1",
|
||||||
|
on_tracking_result=on_tracking_result,
|
||||||
|
on_error=on_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Let it run for 30 seconds
|
||||||
|
logger.info("Processing stream for 30 seconds...")
|
||||||
|
await asyncio.sleep(30)
|
||||||
|
|
||||||
|
# Get statistics
|
||||||
|
stats = manager.get_stats()
|
||||||
|
logger.info("=== Statistics ===")
|
||||||
|
logger.info(f"Manager: {stats['manager']}")
|
||||||
|
logger.info(f"Model Controller: {stats['model_controller']}")
|
||||||
|
logger.info(f"Connection: {stats['connections']['camera1']}")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await manager.shutdown()
|
||||||
|
logger.info("Example 1 complete\n")
|
||||||
|
|
||||||
|
|
||||||
|
# Example 2: Async generator pattern with multiple streams
|
||||||
|
async def example_async_generator_pattern():
|
||||||
|
"""Demonstrates async generator pattern for multiple streams"""
|
||||||
|
logger.info("=== Example 2: Async Generator Pattern (Multiple Streams) ===")
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
camera_urls = []
|
||||||
|
for i in range(1, 5): # Try to load 4 cameras
|
||||||
|
url = os.getenv(f'CAMERA_URL_{i}')
|
||||||
|
if url:
|
||||||
|
camera_urls.append((url, f"camera{i}"))
|
||||||
|
|
||||||
|
if not camera_urls:
|
||||||
|
logger.error("No camera URLs found in .env file")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Found {len(camera_urls)} camera(s)")
|
||||||
|
|
||||||
|
# Create manager with larger batch for multiple streams
|
||||||
|
manager = StreamConnectionManager(
|
||||||
|
gpu_id=0,
|
||||||
|
batch_size=32, # Larger batch for multiple streams
|
||||||
|
force_timeout=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize
|
||||||
|
model_path = "models/yolov8n.trt"
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
logger.error(f"Model file not found: {model_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
await manager.initialize(
|
||||||
|
model_path=model_path,
|
||||||
|
preprocess_fn=YOLOv8Utils.preprocess,
|
||||||
|
postprocess_fn=YOLOv8Utils.postprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Connect to all streams
|
||||||
|
connections = []
|
||||||
|
for url, stream_id in camera_urls:
|
||||||
|
try:
|
||||||
|
connection = await manager.connect_stream(
|
||||||
|
rtsp_url=url,
|
||||||
|
stream_id=stream_id,
|
||||||
|
)
|
||||||
|
connections.append((connection, stream_id))
|
||||||
|
logger.info(f"Connected to {stream_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to {stream_id}: {e}")
|
||||||
|
|
||||||
|
# Process each stream with async generator
|
||||||
|
async def process_stream(connection, stream_name):
|
||||||
|
"""Process results from a single stream"""
|
||||||
|
frame_count = 0
|
||||||
|
person_detections = 0
|
||||||
|
|
||||||
|
async for result in connection.tracking_results():
|
||||||
|
frame_count += 1
|
||||||
|
|
||||||
|
# Count person detections (class_id 0 in COCO)
|
||||||
|
for obj in result.tracked_objects:
|
||||||
|
if obj.class_id == 0:
|
||||||
|
person_detections += 1
|
||||||
|
|
||||||
|
# Log every 10th frame
|
||||||
|
if frame_count % 10 == 0:
|
||||||
|
logger.info(
|
||||||
|
f"[{stream_name}] Processed {frame_count} frames, "
|
||||||
|
f"{person_detections} person detections"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop after 100 frames
|
||||||
|
if frame_count >= 100:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Run all streams concurrently
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(process_stream(conn, name))
|
||||||
|
for conn, name in connections
|
||||||
|
]
|
||||||
|
|
||||||
|
# Wait for all tasks to complete
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Get final statistics
|
||||||
|
stats = manager.get_stats()
|
||||||
|
logger.info("\n=== Final Statistics ===")
|
||||||
|
logger.info(f"Total connections: {stats['manager']['num_connections']}")
|
||||||
|
logger.info(f"Frames processed: {stats['model_controller']['total_frames_processed']}")
|
||||||
|
logger.info(f"Batches processed: {stats['model_controller']['total_batches_processed']}")
|
||||||
|
logger.info(f"Avg batch size: {stats['model_controller']['avg_batch_size']:.2f}")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await manager.shutdown()
|
||||||
|
logger.info("Example 2 complete\n")
|
||||||
|
|
||||||
|
|
||||||
|
# Example 3: Queue-based pattern
|
||||||
|
async def example_queue_pattern():
|
||||||
|
"""Demonstrates direct queue access for custom processing"""
|
||||||
|
logger.info("=== Example 3: Queue-Based Pattern ===")
|
||||||
|
|
||||||
|
# Load environment
|
||||||
|
load_dotenv()
|
||||||
|
camera_url = os.getenv('CAMERA_URL_1')
|
||||||
|
if not camera_url:
|
||||||
|
logger.error("CAMERA_URL_1 not found in .env file")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create manager
|
||||||
|
manager = StreamConnectionManager(gpu_id=0, batch_size=16)
|
||||||
|
|
||||||
|
# Initialize
|
||||||
|
model_path = "models/yolov8n.trt"
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
logger.error(f"Model file not found: {model_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
await manager.initialize(
|
||||||
|
model_path=model_path,
|
||||||
|
preprocess_fn=YOLOv8Utils.preprocess,
|
||||||
|
postprocess_fn=YOLOv8Utils.postprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Connect to stream (no callback)
|
||||||
|
connection = await manager.connect_stream(
|
||||||
|
rtsp_url=camera_url,
|
||||||
|
stream_id="main_camera",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the built-in queue directly
|
||||||
|
result_queue = connection.result_queue
|
||||||
|
|
||||||
|
# Process results from queue
|
||||||
|
processed_count = 0
|
||||||
|
while processed_count < 50: # Process 50 frames
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(result_queue.get(), timeout=5.0)
|
||||||
|
processed_count += 1
|
||||||
|
|
||||||
|
# Custom processing
|
||||||
|
has_person = any(obj.class_id == 0 for obj in result.tracked_objects)
|
||||||
|
has_car = any(obj.class_id == 2 for obj in result.tracked_objects)
|
||||||
|
|
||||||
|
if has_person or has_car:
|
||||||
|
logger.info(
|
||||||
|
f"Frame {processed_count}: "
|
||||||
|
f"Person={'Yes' if has_person else 'No'}, "
|
||||||
|
f"Car={'Yes' if has_car else 'No'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("Timeout waiting for result")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await manager.shutdown()
|
||||||
|
logger.info("Example 3 complete\n")
|
||||||
|
|
||||||
|
|
||||||
|
# Example 4: Performance monitoring
|
||||||
|
async def example_performance_monitoring():
|
||||||
|
"""Demonstrates real-time performance monitoring"""
|
||||||
|
logger.info("=== Example 4: Performance Monitoring ===")
|
||||||
|
|
||||||
|
# Load environment
|
||||||
|
load_dotenv()
|
||||||
|
camera_url = os.getenv('CAMERA_URL_1')
|
||||||
|
if not camera_url:
|
||||||
|
logger.error("CAMERA_URL_1 not found in .env file")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create manager
|
||||||
|
manager = StreamConnectionManager(
|
||||||
|
gpu_id=0,
|
||||||
|
batch_size=16,
|
||||||
|
force_timeout=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize
|
||||||
|
model_path = "models/yolov8n.trt"
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
logger.error(f"Model file not found: {model_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
await manager.initialize(
|
||||||
|
model_path=model_path,
|
||||||
|
preprocess_fn=YOLOv8Utils.preprocess,
|
||||||
|
postprocess_fn=YOLOv8Utils.postprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track performance metrics
|
||||||
|
frame_times = []
|
||||||
|
last_frame_time = None
|
||||||
|
|
||||||
|
def on_tracking_result(result):
|
||||||
|
nonlocal last_frame_time
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
if last_frame_time is not None:
|
||||||
|
frame_interval = current_time - last_frame_time
|
||||||
|
frame_times.append(frame_interval)
|
||||||
|
|
||||||
|
last_frame_time = current_time
|
||||||
|
|
||||||
|
# Connect
|
||||||
|
connection = await manager.connect_stream(
|
||||||
|
rtsp_url=camera_url,
|
||||||
|
on_tracking_result=on_tracking_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Monitor stats periodically
|
||||||
|
for i in range(6): # Monitor for 60 seconds
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
stats = manager.get_stats()
|
||||||
|
model_stats = stats['model_controller']
|
||||||
|
conn_stats = stats['connections'].get('stream_0', {})
|
||||||
|
|
||||||
|
logger.info(f"\n=== Stats Update {i+1} ===")
|
||||||
|
logger.info(f"Buffer A: {model_stats['buffer_a_size']} ({model_stats['buffer_a_state']})")
|
||||||
|
logger.info(f"Buffer B: {model_stats['buffer_b_size']} ({model_stats['buffer_b_state']})")
|
||||||
|
logger.info(f"Active buffer: {model_stats['active_buffer']}")
|
||||||
|
logger.info(f"Total frames: {model_stats['total_frames_processed']}")
|
||||||
|
logger.info(f"Total batches: {model_stats['total_batches_processed']}")
|
||||||
|
logger.info(f"Avg batch size: {model_stats['avg_batch_size']:.2f}")
|
||||||
|
logger.info(f"Decoder frames: {conn_stats.get('frame_count', 0)}")
|
||||||
|
|
||||||
|
if frame_times:
|
||||||
|
avg_fps = 1.0 / (sum(frame_times) / len(frame_times))
|
||||||
|
logger.info(f"Processing FPS: {avg_fps:.2f}")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await manager.shutdown()
|
||||||
|
logger.info("Example 4 complete\n")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Run all examples"""
|
||||||
|
logger.info("Starting event-driven stream processing tests\n")
|
||||||
|
|
||||||
|
# Choose which example to run
|
||||||
|
choice = os.getenv('EXAMPLE', '1')
|
||||||
|
|
||||||
|
if choice == '1':
|
||||||
|
await example_callback_pattern()
|
||||||
|
elif choice == '2':
|
||||||
|
await example_async_generator_pattern()
|
||||||
|
elif choice == '3':
|
||||||
|
await example_queue_pattern()
|
||||||
|
elif choice == '4':
|
||||||
|
await example_performance_monitoring()
|
||||||
|
elif choice == 'all':
|
||||||
|
await example_callback_pattern()
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
await example_async_generator_pattern()
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
await example_queue_pattern()
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
await example_performance_monitoring()
|
||||||
|
else:
|
||||||
|
logger.error(f"Invalid choice: {choice}")
|
||||||
|
logger.info("Set EXAMPLE env var to 1, 2, 3, 4, or 'all'")
|
||||||
|
|
||||||
|
logger.info("All tests complete!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("\nInterrupted by user")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error: {e}", exc_info=True)
|
||||||
117
test_event_driven_quick.py
Executable file
117
test_event_driven_quick.py
Executable file
|
|
@ -0,0 +1,117 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Quick test for event-driven stream processing - runs for 20 seconds.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from services import StreamConnectionManager, YOLOv8Utils, COCO_CLASSES
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Quick test with callback pattern"""
|
||||||
|
logger.info("=== Quick Event-Driven Test (20 seconds) ===")
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
camera_url = os.getenv('CAMERA_URL_1')
|
||||||
|
if not camera_url:
|
||||||
|
logger.error("CAMERA_URL_1 not found in .env file")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create manager
|
||||||
|
manager = StreamConnectionManager(
|
||||||
|
gpu_id=0,
|
||||||
|
batch_size=16,
|
||||||
|
force_timeout=0.05, # 50ms
|
||||||
|
poll_interval=0.01, # 100 FPS
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize with YOLOv8 model
|
||||||
|
model_path = "models/yolov8n.trt"
|
||||||
|
logger.info(f"Initializing with model: {model_path}")
|
||||||
|
|
||||||
|
await manager.initialize(
|
||||||
|
model_path=model_path,
|
||||||
|
model_id="yolo",
|
||||||
|
preprocess_fn=YOLOv8Utils.preprocess,
|
||||||
|
postprocess_fn=YOLOv8Utils.postprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
result_count = 0
|
||||||
|
|
||||||
|
# Define callback for tracking results
|
||||||
|
def on_tracking_result(result):
|
||||||
|
nonlocal result_count
|
||||||
|
result_count += 1
|
||||||
|
|
||||||
|
if result_count % 5 == 0: # Log every 5th result
|
||||||
|
logger.info(f"[{result.stream_id}] Frame {result.metadata.get('frame_number', 0)}")
|
||||||
|
logger.info(f" Tracked objects: {len(result.tracked_objects)}")
|
||||||
|
|
||||||
|
for obj in result.tracked_objects[:3]: # Show first 3
|
||||||
|
class_name = COCO_CLASSES.get(obj.class_id, f"Class {obj.class_id}")
|
||||||
|
logger.info(
|
||||||
|
f" Track ID {obj.track_id}: {class_name}, "
|
||||||
|
f"conf={obj.confidence:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_error(error):
|
||||||
|
logger.error(f"Stream error: {error}")
|
||||||
|
|
||||||
|
# Connect to stream
|
||||||
|
logger.info(f"Connecting to stream...")
|
||||||
|
connection = await manager.connect_stream(
|
||||||
|
rtsp_url=camera_url,
|
||||||
|
stream_id="test_camera",
|
||||||
|
on_tracking_result=on_tracking_result,
|
||||||
|
on_error=on_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Monitor for 20 seconds with stats updates
|
||||||
|
for i in range(4): # 4 x 5 seconds = 20 seconds
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
stats = manager.get_stats()
|
||||||
|
model_stats = stats['model_controller']
|
||||||
|
|
||||||
|
logger.info(f"\n=== Stats Update {i+1}/4 ===")
|
||||||
|
logger.info(f"Results received: {result_count}")
|
||||||
|
logger.info(f"Buffer A: {model_stats['buffer_a_size']} ({model_stats['buffer_a_state']})")
|
||||||
|
logger.info(f"Buffer B: {model_stats['buffer_b_size']} ({model_stats['buffer_b_state']})")
|
||||||
|
logger.info(f"Active buffer: {model_stats['active_buffer']}")
|
||||||
|
logger.info(f"Total frames processed: {model_stats['total_frames_processed']}")
|
||||||
|
logger.info(f"Total batches: {model_stats['total_batches_processed']}")
|
||||||
|
logger.info(f"Avg batch size: {model_stats['avg_batch_size']:.2f}")
|
||||||
|
|
||||||
|
# Final statistics
|
||||||
|
stats = manager.get_stats()
|
||||||
|
logger.info("\n=== Final Statistics ===")
|
||||||
|
logger.info(f"Total results received: {result_count}")
|
||||||
|
logger.info(f"Manager: {stats['manager']}")
|
||||||
|
logger.info(f"Model Controller: {stats['model_controller']}")
|
||||||
|
logger.info(f"Connection: {stats['connections']['test_camera']}")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
logger.info("\nShutting down...")
|
||||||
|
await manager.shutdown()
|
||||||
|
logger.info("Test complete!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("\nInterrupted by user")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error: {e}", exc_info=True)
|
||||||
|
|
@ -509,7 +509,7 @@ def main_multi_window():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Run single camera visualization
|
# Run single camera visualization
|
||||||
main()
|
# main()
|
||||||
|
|
||||||
# Uncomment to run multi-window visualization
|
# Uncomment to run multi-window visualization
|
||||||
# main_multi_window()
|
main_multi_window()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue