batch processing/event driven
This commit is contained in:
parent
e71316ef3d
commit
dd57b5a246
7 changed files with 2673 additions and 2 deletions
1108
EVENT_DRIVEN_DESIGN.md
Normal file
1108
EVENT_DRIVEN_DESIGN.md
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -8,6 +8,8 @@ from .model_repository import TensorRTModelRepository, ModelMetadata, ExecutionC
|
|||
from .tracking_controller import TrackingController, TrackedObject
|
||||
from .tracking_factory import TrackingFactory
|
||||
from .yolo import YOLOv8Utils, COCO_CLASSES
|
||||
from .model_controller import ModelController, BatchFrame, BufferState
|
||||
from .stream_connection_manager import StreamConnectionManager, StreamConnection, TrackingResult
|
||||
|
||||
__all__ = [
|
||||
'StreamDecoderFactory',
|
||||
|
|
@ -24,4 +26,10 @@ __all__ = [
|
|||
'TrackingFactory',
|
||||
'YOLOv8Utils',
|
||||
'COCO_CLASSES',
|
||||
'ModelController',
|
||||
'BatchFrame',
|
||||
'BufferState',
|
||||
'StreamConnectionManager',
|
||||
'StreamConnection',
|
||||
'TrackingResult',
|
||||
]
|
||||
|
|
|
|||
499
services/model_controller.py
Normal file
499
services/model_controller.py
Normal file
|
|
@ -0,0 +1,499 @@
|
|||
"""
|
||||
ModelController - Async batching layer with ping-pong buffers for inference.
|
||||
|
||||
This module provides batched inference coordination using ping-pong circular buffers
|
||||
with force-switch timeout mechanism.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import torch
|
||||
from typing import Dict, List, Optional, Callable, Any
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import time
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchFrame:
|
||||
"""Represents a frame in the batch buffer"""
|
||||
stream_id: str
|
||||
frame: torch.Tensor # GPU tensor (3, H, W)
|
||||
timestamp: float
|
||||
metadata: Dict = field(default_factory=dict)
|
||||
|
||||
|
||||
class BufferState(Enum):
|
||||
"""State of a ping-pong buffer"""
|
||||
IDLE = "idle"
|
||||
FILLING = "filling"
|
||||
PROCESSING = "processing"
|
||||
|
||||
|
||||
class ModelController:
|
||||
"""
|
||||
Manages batched inference with ping-pong buffers and force-switch timeout.
|
||||
|
||||
This controller accumulates frames from multiple streams into batches,
|
||||
processes them through a model repository, and routes results back to
|
||||
stream-specific callbacks.
|
||||
|
||||
Features:
|
||||
- Ping-pong circular buffers (BufferA/BufferB)
|
||||
- Force-switch timeout to prevent batch starvation
|
||||
- Async event-driven processing
|
||||
- Thread-safe frame submission
|
||||
|
||||
Args:
|
||||
model_repository: TensorRT model repository for inference
|
||||
model_id: Model identifier in the repository
|
||||
batch_size: Maximum frames per batch (default: 16)
|
||||
force_timeout: Max wait time before forcing buffer switch in seconds (default: 0.05)
|
||||
preprocess_fn: Optional preprocessing function for frames
|
||||
postprocess_fn: Optional postprocessing function for model outputs
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_repository,
|
||||
model_id: str,
|
||||
batch_size: int = 16,
|
||||
force_timeout: float = 0.05,
|
||||
preprocess_fn: Optional[Callable] = None,
|
||||
postprocess_fn: Optional[Callable] = None,
|
||||
):
|
||||
self.model_repository = model_repository
|
||||
self.model_id = model_id
|
||||
self.batch_size = batch_size
|
||||
self.force_timeout = force_timeout
|
||||
self.preprocess_fn = preprocess_fn
|
||||
self.postprocess_fn = postprocess_fn
|
||||
|
||||
# Detect model's actual batch size from input shape
|
||||
self.model_batch_size = self._detect_model_batch_size()
|
||||
if self.model_batch_size == 1:
|
||||
logger.warning(
|
||||
f"Model '{model_id}' has fixed batch_size=1. "
|
||||
f"Will process frames sequentially. Consider rebuilding model with dynamic batching."
|
||||
)
|
||||
else:
|
||||
logger.info(f"Model '{model_id}' supports batch_size={self.model_batch_size}")
|
||||
|
||||
# Ping-pong buffers
|
||||
self.buffer_a: List[BatchFrame] = []
|
||||
self.buffer_b: List[BatchFrame] = []
|
||||
|
||||
# Buffer states
|
||||
self.active_buffer = "A" # Which buffer is currently active (filling)
|
||||
self.buffer_a_state = BufferState.IDLE
|
||||
self.buffer_b_state = BufferState.IDLE
|
||||
|
||||
# Async coordination
|
||||
self.buffer_lock = asyncio.Lock()
|
||||
self.last_submit_time = time.time()
|
||||
|
||||
# Tasks
|
||||
self.timeout_task: Optional[asyncio.Task] = None
|
||||
self.processor_task: Optional[asyncio.Task] = None
|
||||
self.running = False
|
||||
|
||||
# Result callbacks (stream_id -> callback)
|
||||
self.result_callbacks: Dict[str, Callable] = {}
|
||||
|
||||
# Statistics
|
||||
self.total_frames_processed = 0
|
||||
self.total_batches_processed = 0
|
||||
|
||||
def _detect_model_batch_size(self) -> int:
|
||||
"""
|
||||
Detect the model's batch size from its input shape.
|
||||
|
||||
Returns:
|
||||
Maximum batch size supported by the model (1 for fixed batch size models)
|
||||
"""
|
||||
try:
|
||||
metadata = self.model_repository.get_metadata(self.model_id)
|
||||
# Get first input tensor shape
|
||||
first_input = list(metadata.inputs.values())[0]
|
||||
batch_dim = first_input["shape"][0]
|
||||
|
||||
# batch_dim can be -1 (dynamic), 1 (fixed), or N (fixed batch size)
|
||||
if batch_dim == -1:
|
||||
# Dynamic batch size - use user-specified batch_size
|
||||
return self.batch_size
|
||||
else:
|
||||
# Fixed batch size
|
||||
return batch_dim
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
|
||||
return 1
|
||||
|
||||
async def start(self):
|
||||
"""Start the controller background tasks"""
|
||||
if self.running:
|
||||
logger.warning("ModelController already running")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self.timeout_task = asyncio.create_task(self._timeout_monitor())
|
||||
self.processor_task = asyncio.create_task(self._batch_processor())
|
||||
logger.info("ModelController started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the controller and cleanup"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
logger.info("Stopping ModelController...")
|
||||
self.running = False
|
||||
|
||||
# Cancel tasks
|
||||
if self.timeout_task:
|
||||
self.timeout_task.cancel()
|
||||
try:
|
||||
await self.timeout_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self.processor_task:
|
||||
self.processor_task.cancel()
|
||||
try:
|
||||
await self.processor_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Process any remaining frames
|
||||
await self._process_remaining_buffers()
|
||||
logger.info("ModelController stopped")
|
||||
|
||||
def register_callback(self, stream_id: str, callback: Callable):
|
||||
"""
|
||||
Register a callback for inference results from a stream.
|
||||
|
||||
Args:
|
||||
stream_id: Unique stream identifier
|
||||
callback: Callback function to receive results (can be sync or async)
|
||||
"""
|
||||
self.result_callbacks[stream_id] = callback
|
||||
logger.debug(f"Registered callback for stream: {stream_id}")
|
||||
|
||||
def unregister_callback(self, stream_id: str):
|
||||
"""
|
||||
Unregister a stream callback.
|
||||
|
||||
Args:
|
||||
stream_id: Stream identifier to unregister
|
||||
"""
|
||||
self.result_callbacks.pop(stream_id, None)
|
||||
logger.debug(f"Unregistered callback for stream: {stream_id}")
|
||||
|
||||
async def submit_frame(
|
||||
self,
|
||||
stream_id: str,
|
||||
frame: torch.Tensor,
|
||||
metadata: Optional[Dict] = None
|
||||
):
|
||||
"""
|
||||
Submit a frame for batched inference.
|
||||
|
||||
Args:
|
||||
stream_id: Unique stream identifier
|
||||
frame: GPU tensor (3, H, W) or (C, H, W)
|
||||
metadata: Optional metadata to attach to the frame
|
||||
"""
|
||||
async with self.buffer_lock:
|
||||
batch_frame = BatchFrame(
|
||||
stream_id=stream_id,
|
||||
frame=frame,
|
||||
timestamp=time.time(),
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
# Add to active buffer
|
||||
if self.active_buffer == "A":
|
||||
self.buffer_a.append(batch_frame)
|
||||
self.buffer_a_state = BufferState.FILLING
|
||||
buffer_size = len(self.buffer_a)
|
||||
else:
|
||||
self.buffer_b.append(batch_frame)
|
||||
self.buffer_b_state = BufferState.FILLING
|
||||
buffer_size = len(self.buffer_b)
|
||||
|
||||
self.last_submit_time = time.time()
|
||||
|
||||
# Check if we should immediately swap (batch full)
|
||||
if buffer_size >= self.batch_size:
|
||||
await self._try_swap_buffers()
|
||||
|
||||
async def _timeout_monitor(self):
|
||||
"""Monitor force-switch timeout"""
|
||||
while self.running:
|
||||
await asyncio.sleep(0.01) # Check every 10ms
|
||||
|
||||
async with self.buffer_lock:
|
||||
time_since_submit = time.time() - self.last_submit_time
|
||||
|
||||
# Check if timeout expired and we have frames waiting
|
||||
if time_since_submit >= self.force_timeout:
|
||||
active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
|
||||
if len(active_buffer) > 0:
|
||||
await self._try_swap_buffers()
|
||||
|
||||
async def _try_swap_buffers(self):
|
||||
"""
|
||||
Attempt to swap ping-pong buffers.
|
||||
Only swaps if the inactive buffer is not currently processing.
|
||||
|
||||
This method should be called with buffer_lock held.
|
||||
"""
|
||||
# Check if inactive buffer is available
|
||||
inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
|
||||
|
||||
if inactive_state != BufferState.PROCESSING:
|
||||
# Swap active buffer
|
||||
old_active = self.active_buffer
|
||||
self.active_buffer = "B" if old_active == "A" else "A"
|
||||
|
||||
# Mark old active buffer as ready for processing
|
||||
if old_active == "A":
|
||||
self.buffer_a_state = BufferState.PROCESSING
|
||||
buffer_size = len(self.buffer_a)
|
||||
else:
|
||||
self.buffer_b_state = BufferState.PROCESSING
|
||||
buffer_size = len(self.buffer_b)
|
||||
|
||||
logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")
|
||||
|
||||
async def _batch_processor(self):
|
||||
"""Background task that processes batches when available"""
|
||||
while self.running:
|
||||
await asyncio.sleep(0.001) # Check every 1ms
|
||||
|
||||
# Check if buffer A needs processing
|
||||
if self.buffer_a_state == BufferState.PROCESSING:
|
||||
await self._process_buffer("A")
|
||||
|
||||
# Check if buffer B needs processing
|
||||
if self.buffer_b_state == BufferState.PROCESSING:
|
||||
await self._process_buffer("B")
|
||||
|
||||
async def _process_buffer(self, buffer_name: str):
|
||||
"""
|
||||
Process a buffer through inference.
|
||||
|
||||
Args:
|
||||
buffer_name: "A" or "B"
|
||||
"""
|
||||
# Extract buffer to process
|
||||
async with self.buffer_lock:
|
||||
if buffer_name == "A":
|
||||
batch = self.buffer_a.copy()
|
||||
self.buffer_a.clear()
|
||||
else:
|
||||
batch = self.buffer_b.copy()
|
||||
self.buffer_b.clear()
|
||||
|
||||
if len(batch) == 0:
|
||||
# Mark as idle and return
|
||||
async with self.buffer_lock:
|
||||
if buffer_name == "A":
|
||||
self.buffer_a_state = BufferState.IDLE
|
||||
else:
|
||||
self.buffer_b_state = BufferState.IDLE
|
||||
return
|
||||
|
||||
# Process batch (outside lock to allow concurrent submissions)
|
||||
try:
|
||||
start_time = time.time()
|
||||
results = await self._run_batch_inference(batch)
|
||||
inference_time = time.time() - start_time
|
||||
|
||||
# Update statistics
|
||||
self.total_frames_processed += len(batch)
|
||||
self.total_batches_processed += 1
|
||||
|
||||
logger.debug(
|
||||
f"Processed batch of {len(batch)} frames in {inference_time*1000:.2f}ms "
|
||||
f"({inference_time*1000/len(batch):.2f}ms per frame)"
|
||||
)
|
||||
|
||||
# Emit results to callbacks
|
||||
for batch_frame, result in zip(batch, results):
|
||||
callback = self.result_callbacks.get(batch_frame.stream_id)
|
||||
if callback:
|
||||
# Schedule callback asynchronously
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
asyncio.create_task(callback(result))
|
||||
else:
|
||||
# Run sync callback in executor to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.call_soon(lambda cb=callback, r=result: cb(r))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing batch: {e}", exc_info=True)
|
||||
# TODO: Emit error events to streams
|
||||
|
||||
finally:
|
||||
# Mark buffer as idle
|
||||
async with self.buffer_lock:
|
||||
if buffer_name == "A":
|
||||
self.buffer_a_state = BufferState.IDLE
|
||||
else:
|
||||
self.buffer_b_state = BufferState.IDLE
|
||||
|
||||
async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Run inference on a batch of frames.
|
||||
|
||||
Args:
|
||||
batch: List of BatchFrame objects
|
||||
|
||||
Returns:
|
||||
List of detection results (one per frame)
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Check if model supports batching
|
||||
if self.model_batch_size == 1:
|
||||
# Process frames one at a time for batch_size=1 models
|
||||
return await self._run_sequential_inference(batch, loop)
|
||||
else:
|
||||
# Use true batching for models that support it
|
||||
return await self._run_batched_inference(batch, loop)
|
||||
|
||||
async def _run_sequential_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
|
||||
"""Run inference sequentially for batch_size=1 models"""
|
||||
results = []
|
||||
|
||||
for batch_frame in batch:
|
||||
# Preprocess frame
|
||||
if self.preprocess_fn:
|
||||
processed = self.preprocess_fn(batch_frame.frame)
|
||||
else:
|
||||
# Ensure we have batch dimension
|
||||
processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame
|
||||
|
||||
# Run inference for this frame
|
||||
outputs = await loop.run_in_executor(
|
||||
None,
|
||||
lambda p=processed: self.model_repository.infer(
|
||||
self.model_id,
|
||||
{"images": p},
|
||||
synchronize=True
|
||||
)
|
||||
)
|
||||
|
||||
# Postprocess
|
||||
if self.postprocess_fn:
|
||||
try:
|
||||
detections = self.postprocess_fn(outputs)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
|
||||
# Return empty detections on error
|
||||
detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
|
||||
else:
|
||||
detections = outputs
|
||||
|
||||
result = {
|
||||
"stream_id": batch_frame.stream_id,
|
||||
"timestamp": batch_frame.timestamp,
|
||||
"detections": detections,
|
||||
"metadata": batch_frame.metadata,
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
async def _run_batched_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
|
||||
"""Run true batched inference for models that support it"""
|
||||
# Preprocess frames (on GPU)
|
||||
preprocessed = []
|
||||
for batch_frame in batch:
|
||||
if self.preprocess_fn:
|
||||
processed = self.preprocess_fn(batch_frame.frame)
|
||||
# Preprocess may return (1, C, H, W), squeeze to (C, H, W)
|
||||
if processed.dim() == 4 and processed.shape[0] == 1:
|
||||
processed = processed.squeeze(0)
|
||||
else:
|
||||
processed = batch_frame.frame
|
||||
preprocessed.append(processed)
|
||||
|
||||
# Stack into batch tensor: (N, C, H, W)
|
||||
batch_tensor = torch.stack(preprocessed, dim=0)
|
||||
|
||||
# Limit batch size to model's max batch size
|
||||
if batch_tensor.shape[0] > self.model_batch_size:
|
||||
logger.warning(
|
||||
f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}, "
|
||||
f"will split into sub-batches"
|
||||
)
|
||||
# TODO: Handle splitting into sub-batches
|
||||
batch_tensor = batch_tensor[:self.model_batch_size]
|
||||
batch = batch[:self.model_batch_size]
|
||||
|
||||
# Run inference
|
||||
outputs = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.model_repository.infer(
|
||||
self.model_id,
|
||||
{"images": batch_tensor},
|
||||
synchronize=True
|
||||
)
|
||||
)
|
||||
|
||||
# Postprocess results (split batch back to individual results)
|
||||
results = []
|
||||
for i, batch_frame in enumerate(batch):
|
||||
# Extract single frame output from batch
|
||||
frame_output = {}
|
||||
for k, v in outputs.items():
|
||||
# v has shape (N, ...), extract index i and keep batch dimension
|
||||
frame_output[k] = v[i:i+1] # Shape: (1, ...)
|
||||
|
||||
if self.postprocess_fn:
|
||||
try:
|
||||
detections = self.postprocess_fn(frame_output)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
|
||||
# Return empty detections on error
|
||||
detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
|
||||
else:
|
||||
detections = frame_output
|
||||
|
||||
result = {
|
||||
"stream_id": batch_frame.stream_id,
|
||||
"timestamp": batch_frame.timestamp,
|
||||
"detections": detections,
|
||||
"metadata": batch_frame.metadata,
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
async def _process_remaining_buffers(self):
|
||||
"""Process any remaining frames in buffers during shutdown"""
|
||||
if len(self.buffer_a) > 0:
|
||||
logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
|
||||
await self._process_buffer("A")
|
||||
if len(self.buffer_b) > 0:
|
||||
logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
|
||||
await self._process_buffer("B")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get current buffer statistics"""
|
||||
return {
|
||||
"active_buffer": self.active_buffer,
|
||||
"buffer_a_size": len(self.buffer_a),
|
||||
"buffer_b_size": len(self.buffer_b),
|
||||
"buffer_a_state": self.buffer_a_state.value,
|
||||
"buffer_b_state": self.buffer_b_state.value,
|
||||
"registered_streams": len(self.result_callbacks),
|
||||
"total_frames_processed": self.total_frames_processed,
|
||||
"total_batches_processed": self.total_batches_processed,
|
||||
"avg_batch_size": (
|
||||
self.total_frames_processed / self.total_batches_processed
|
||||
if self.total_batches_processed > 0 else 0
|
||||
),
|
||||
}
|
||||
566
services/stream_connection_manager.py
Normal file
566
services/stream_connection_manager.py
Normal file
|
|
@ -0,0 +1,566 @@
|
|||
"""
|
||||
StreamConnectionManager - Async orchestration for stream processing with batched inference.
|
||||
|
||||
This module provides high-level connection management for multiple RTSP streams,
|
||||
coordinating decoders, batched inference, and tracking with an event-driven API.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional, Callable, AsyncIterator, Tuple, Any, List
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from .model_controller import ModelController
|
||||
from .stream_decoder import StreamDecoderFactory
|
||||
from .tracking_factory import TrackingFactory
|
||||
from .model_repository import TensorRTModelRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionStatus(Enum):
|
||||
"""Status of a stream connection"""
|
||||
CONNECTING = "connecting"
|
||||
CONNECTED = "connected"
|
||||
DISCONNECTED = "disconnected"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrackingResult:
|
||||
"""Result emitted to user callbacks"""
|
||||
stream_id: str
|
||||
timestamp: float
|
||||
tracked_objects: List # List of TrackedObject from TrackingController
|
||||
detections: List # Raw detections
|
||||
frame_shape: Tuple[int, int, int]
|
||||
metadata: Dict
|
||||
|
||||
|
||||
class StreamConnection:
|
||||
"""
|
||||
Represents a single stream connection with event emission.
|
||||
|
||||
This class wraps a StreamDecoder, polls frames asynchronously, submits them
|
||||
to the ModelController for batched inference, runs tracking, and emits results
|
||||
via queues or callbacks.
|
||||
|
||||
Args:
|
||||
stream_id: Unique identifier for this stream
|
||||
decoder: StreamDecoder instance
|
||||
model_controller: ModelController for batched inference
|
||||
tracking_controller: TrackingController for object tracking
|
||||
poll_interval: Frame polling interval in seconds (default: 0.01)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream_id: str,
|
||||
decoder,
|
||||
model_controller: ModelController,
|
||||
tracking_controller,
|
||||
poll_interval: float = 0.01,
|
||||
):
|
||||
self.stream_id = stream_id
|
||||
self.decoder = decoder
|
||||
self.model_controller = model_controller
|
||||
self.tracking_controller = tracking_controller
|
||||
self.poll_interval = poll_interval
|
||||
|
||||
self.status = ConnectionStatus.CONNECTING
|
||||
self.frame_count = 0
|
||||
self.last_frame_time = 0.0
|
||||
|
||||
# Event emission
|
||||
self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue()
|
||||
self.error_queue: asyncio.Queue[Exception] = asyncio.Queue()
|
||||
|
||||
# Tasks
|
||||
self.poller_task: Optional[asyncio.Task] = None
|
||||
self.running = False
|
||||
|
||||
async def start(self):
|
||||
"""Start the connection (decoder and frame polling)"""
|
||||
# Start decoder (runs in background thread)
|
||||
self.decoder.start()
|
||||
|
||||
# Wait for initial connection (try for up to 10 seconds)
|
||||
max_wait = 10.0
|
||||
wait_interval = 0.5
|
||||
elapsed = 0.0
|
||||
|
||||
while elapsed < max_wait:
|
||||
await asyncio.sleep(wait_interval)
|
||||
elapsed += wait_interval
|
||||
|
||||
if self.decoder.is_connected():
|
||||
self.status = ConnectionStatus.CONNECTED
|
||||
logger.info(f"Stream {self.stream_id} connected after {elapsed:.1f}s")
|
||||
break
|
||||
else:
|
||||
# Timeout - but don't fail hard, let it try to connect in background
|
||||
logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...")
|
||||
self.status = ConnectionStatus.CONNECTING
|
||||
|
||||
# Start frame polling task
|
||||
self.running = True
|
||||
self.poller_task = asyncio.create_task(self._frame_poller())
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the connection and cleanup"""
|
||||
logger.info(f"Stopping stream {self.stream_id}...")
|
||||
self.running = False
|
||||
|
||||
if self.poller_task:
|
||||
self.poller_task.cancel()
|
||||
try:
|
||||
await self.poller_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Stop decoder
|
||||
self.decoder.stop()
|
||||
|
||||
# Unregister from model controller
|
||||
self.model_controller.unregister_callback(self.stream_id)
|
||||
|
||||
self.status = ConnectionStatus.DISCONNECTED
|
||||
logger.info(f"Stream {self.stream_id} stopped")
|
||||
|
||||
async def _frame_poller(self):
|
||||
"""Poll frames from threaded decoder and submit to model controller"""
|
||||
last_frame_ptr = None
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Poll frame from decoder (runs in thread)
|
||||
frame = self.decoder.get_latest_frame(rgb=True)
|
||||
|
||||
# Check if we got a new frame (avoid reprocessing same frame)
|
||||
if frame is not None and frame.data_ptr() != last_frame_ptr:
|
||||
last_frame_ptr = frame.data_ptr()
|
||||
self.last_frame_time = time.time()
|
||||
self.frame_count += 1
|
||||
|
||||
# Submit to model controller for batched inference
|
||||
await self.model_controller.submit_frame(
|
||||
stream_id=self.stream_id,
|
||||
frame=frame,
|
||||
metadata={
|
||||
"frame_number": self.frame_count,
|
||||
"shape": tuple(frame.shape),
|
||||
}
|
||||
)
|
||||
|
||||
# Check decoder status
|
||||
if not self.decoder.is_connected():
|
||||
if self.status == ConnectionStatus.CONNECTED:
|
||||
logger.warning(f"Stream {self.stream_id} disconnected")
|
||||
self.status = ConnectionStatus.DISCONNECTED
|
||||
# Decoder will auto-reconnect, just update status
|
||||
await asyncio.sleep(1.0)
|
||||
if self.decoder.is_connected():
|
||||
logger.info(f"Stream {self.stream_id} reconnected")
|
||||
self.status = ConnectionStatus.CONNECTED
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in frame poller for {self.stream_id}: {e}", exc_info=True)
|
||||
await self.error_queue.put(e)
|
||||
self.status = ConnectionStatus.ERROR
|
||||
|
||||
# Sleep until next poll
|
||||
await asyncio.sleep(self.poll_interval)
|
||||
|
||||
async def _handle_inference_result(self, result: Dict[str, Any]):
|
||||
"""
|
||||
Callback invoked by ModelController when inference is done.
|
||||
Runs tracking and emits final result.
|
||||
|
||||
Args:
|
||||
result: Inference result dictionary
|
||||
"""
|
||||
try:
|
||||
# Extract detections
|
||||
detections = result["detections"]
|
||||
|
||||
# Run tracking (this is sync, so run in executor)
|
||||
loop = asyncio.get_event_loop()
|
||||
tracked_objects = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._run_tracking_sync(detections)
|
||||
)
|
||||
|
||||
# Create tracking result
|
||||
tracking_result = TrackingResult(
|
||||
stream_id=self.stream_id,
|
||||
timestamp=result["timestamp"],
|
||||
tracked_objects=tracked_objects,
|
||||
detections=detections,
|
||||
frame_shape=result["metadata"].get("shape"),
|
||||
metadata=result["metadata"],
|
||||
)
|
||||
|
||||
# Emit to result queue
|
||||
await self.result_queue.put(tracking_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
|
||||
await self.error_queue.put(e)
|
||||
|
||||
def _run_tracking_sync(self, detections):
|
||||
"""
|
||||
Run tracking synchronously (called from executor).
|
||||
|
||||
Args:
|
||||
detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id]
|
||||
|
||||
Returns:
|
||||
List of TrackedObject instances
|
||||
"""
|
||||
# Use the TrackingController's internal tracking with detections
|
||||
# We need to manually update tracks since we already have detections
|
||||
import torch
|
||||
|
||||
with self.tracking_controller._lock:
|
||||
self.tracking_controller._frame_count += 1
|
||||
|
||||
# If no detections, just cleanup and return current tracks
|
||||
if len(detections) == 0:
|
||||
self.tracking_controller._cleanup_stale_tracks()
|
||||
return list(self.tracking_controller._tracks.values())
|
||||
|
||||
# Run IoU tracking to associate detections with existing tracks
|
||||
associations = self.tracking_controller._iou_tracking(detections)
|
||||
|
||||
# Update or create tracks
|
||||
for (det_idx, track_id), detection in zip(associations, detections):
|
||||
bbox = detection[:4].cpu().tolist()
|
||||
confidence = float(detection[4])
|
||||
class_id = int(detection[5]) if detection.shape[0] > 5 else 0
|
||||
|
||||
if track_id == -1:
|
||||
# Create new track
|
||||
new_track = self.tracking_controller._create_track(
|
||||
bbox, confidence, class_id, self.tracking_controller._frame_count
|
||||
)
|
||||
self.tracking_controller._tracks[new_track.track_id] = new_track
|
||||
else:
|
||||
# Update existing track
|
||||
self.tracking_controller._tracks[track_id].update(
|
||||
bbox, confidence, self.tracking_controller._frame_count
|
||||
)
|
||||
|
||||
# Cleanup stale tracks
|
||||
self.tracking_controller._cleanup_stale_tracks()
|
||||
|
||||
return list(self.tracking_controller._tracks.values())
|
||||
|
||||
async def tracking_results(self) -> AsyncIterator[TrackingResult]:
|
||||
"""
|
||||
Async generator for tracking results.
|
||||
|
||||
Usage:
|
||||
async for result in connection.tracking_results():
|
||||
print(result.tracked_objects)
|
||||
|
||||
Yields:
|
||||
TrackingResult objects as they become available
|
||||
"""
|
||||
while self.running or not self.result_queue.empty():
|
||||
try:
|
||||
result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0)
|
||||
yield result
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
async def errors(self) -> AsyncIterator[Exception]:
|
||||
"""
|
||||
Async generator for errors.
|
||||
|
||||
Yields:
|
||||
Exception objects as they occur
|
||||
"""
|
||||
while self.running or not self.error_queue.empty():
|
||||
try:
|
||||
error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0)
|
||||
yield error
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get connection statistics"""
|
||||
return {
|
||||
"stream_id": self.stream_id,
|
||||
"status": self.status.value,
|
||||
"frame_count": self.frame_count,
|
||||
"last_frame_time": self.last_frame_time,
|
||||
"decoder_connected": self.decoder.is_connected(),
|
||||
"decoder_buffer_size": self.decoder.get_buffer_size(),
|
||||
"result_queue_size": self.result_queue.qsize(),
|
||||
"error_queue_size": self.error_queue.qsize(),
|
||||
}
|
||||
|
||||
|
||||
class StreamConnectionManager:
|
||||
"""
|
||||
High-level manager for stream connections with batched inference.
|
||||
|
||||
This manager coordinates multiple RTSP streams, batched model inference,
|
||||
and object tracking through an async event-driven API.
|
||||
|
||||
Args:
|
||||
gpu_id: GPU device ID (default: 0)
|
||||
batch_size: Maximum batch size for inference (default: 16)
|
||||
force_timeout: Force buffer switch timeout in seconds (default: 0.05)
|
||||
poll_interval: Frame polling interval in seconds (default: 0.01)
|
||||
|
||||
Example:
|
||||
manager = StreamConnectionManager(gpu_id=0, batch_size=16)
|
||||
await manager.initialize(model_path="yolov8n.trt", ...)
|
||||
connection = await manager.connect_stream(rtsp_url, on_tracking_result=callback)
|
||||
await asyncio.sleep(60)
|
||||
await manager.shutdown()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gpu_id: int = 0,
|
||||
batch_size: int = 16,
|
||||
force_timeout: float = 0.05,
|
||||
poll_interval: float = 0.01,
|
||||
):
|
||||
self.gpu_id = gpu_id
|
||||
self.batch_size = batch_size
|
||||
self.force_timeout = force_timeout
|
||||
self.poll_interval = poll_interval
|
||||
|
||||
# Factories
|
||||
self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
|
||||
self.tracking_factory = TrackingFactory(gpu_id=gpu_id)
|
||||
self.model_repository = TensorRTModelRepository(gpu_id=gpu_id)
|
||||
|
||||
# Controllers
|
||||
self.model_controller: Optional[ModelController] = None
|
||||
self.tracking_controller = None
|
||||
|
||||
# Connections
|
||||
self.connections: Dict[str, StreamConnection] = {}
|
||||
|
||||
# State
|
||||
self.initialized = False
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
model_path: str,
|
||||
model_id: str = "detector",
|
||||
preprocess_fn: Optional[Callable] = None,
|
||||
postprocess_fn: Optional[Callable] = None,
|
||||
num_contexts: int = 4,
|
||||
):
|
||||
"""
|
||||
Initialize the manager with a model.
|
||||
|
||||
Args:
|
||||
model_path: Path to TensorRT model file
|
||||
model_id: Model identifier (default: "detector")
|
||||
preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess)
|
||||
postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess)
|
||||
num_contexts: Number of TensorRT execution contexts (default: 4)
|
||||
"""
|
||||
logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
|
||||
|
||||
# Load model
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.model_repository.load_model(model_id, model_path, num_contexts=num_contexts)
|
||||
)
|
||||
logger.info(f"Loaded model {model_id} from {model_path}")
|
||||
|
||||
# Create model controller
|
||||
self.model_controller = ModelController(
|
||||
model_repository=self.model_repository,
|
||||
model_id=model_id,
|
||||
batch_size=self.batch_size,
|
||||
force_timeout=self.force_timeout,
|
||||
preprocess_fn=preprocess_fn,
|
||||
postprocess_fn=postprocess_fn,
|
||||
)
|
||||
await self.model_controller.start()
|
||||
|
||||
# Create tracking controller
|
||||
self.tracking_controller = self.tracking_factory.create_controller(
|
||||
model_repository=self.model_repository,
|
||||
model_id=model_id,
|
||||
tracker_type="iou",
|
||||
)
|
||||
logger.info("TrackingController created")
|
||||
|
||||
self.initialized = True
|
||||
logger.info("StreamConnectionManager initialized successfully")
|
||||
|
||||
async def connect_stream(
|
||||
self,
|
||||
rtsp_url: str,
|
||||
stream_id: Optional[str] = None,
|
||||
on_tracking_result: Optional[Callable] = None,
|
||||
on_error: Optional[Callable] = None,
|
||||
buffer_size: int = 30,
|
||||
) -> StreamConnection:
|
||||
"""
|
||||
Connect to a stream and start processing.
|
||||
|
||||
Args:
|
||||
rtsp_url: RTSP stream URL
|
||||
stream_id: Optional stream identifier (auto-generated if not provided)
|
||||
on_tracking_result: Optional callback for tracking results (sync or async)
|
||||
on_error: Optional callback for errors (sync or async)
|
||||
buffer_size: Decoder buffer size (default: 30)
|
||||
|
||||
Returns:
|
||||
StreamConnection object for this stream
|
||||
|
||||
Raises:
|
||||
RuntimeError: If manager is not initialized
|
||||
ConnectionError: If stream connection fails
|
||||
"""
|
||||
if not self.initialized:
|
||||
raise RuntimeError("Manager not initialized. Call initialize() first.")
|
||||
|
||||
# Generate stream ID if not provided
|
||||
if stream_id is None:
|
||||
stream_id = f"stream_{len(self.connections)}"
|
||||
|
||||
logger.info(f"Connecting to stream {stream_id}: {rtsp_url}")
|
||||
|
||||
# Create decoder
|
||||
decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size)
|
||||
|
||||
# Create connection
|
||||
connection = StreamConnection(
|
||||
stream_id=stream_id,
|
||||
decoder=decoder,
|
||||
model_controller=self.model_controller,
|
||||
tracking_controller=self.tracking_controller,
|
||||
poll_interval=self.poll_interval,
|
||||
)
|
||||
|
||||
# Register callback with model controller
|
||||
self.model_controller.register_callback(
|
||||
stream_id,
|
||||
connection._handle_inference_result
|
||||
)
|
||||
|
||||
# Start connection
|
||||
await connection.start()
|
||||
|
||||
# Store connection
|
||||
self.connections[stream_id] = connection
|
||||
|
||||
# Set up user callbacks if provided
|
||||
if on_tracking_result:
|
||||
asyncio.create_task(self._forward_results(connection, on_tracking_result))
|
||||
|
||||
if on_error:
|
||||
asyncio.create_task(self._forward_errors(connection, on_error))
|
||||
|
||||
logger.info(f"Stream {stream_id} connected successfully")
|
||||
return connection
|
||||
|
||||
async def disconnect_stream(self, stream_id: str):
|
||||
"""
|
||||
Disconnect and cleanup a stream.
|
||||
|
||||
Args:
|
||||
stream_id: Stream identifier to disconnect
|
||||
"""
|
||||
connection = self.connections.get(stream_id)
|
||||
if connection:
|
||||
await connection.stop()
|
||||
del self.connections[stream_id]
|
||||
logger.info(f"Stream {stream_id} disconnected")
|
||||
|
||||
async def disconnect_all(self):
|
||||
"""Disconnect all streams"""
|
||||
logger.info("Disconnecting all streams...")
|
||||
stream_ids = list(self.connections.keys())
|
||||
for stream_id in stream_ids:
|
||||
await self.disconnect_stream(stream_id)
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown the manager and cleanup all resources"""
|
||||
logger.info("Shutting down StreamConnectionManager...")
|
||||
|
||||
# Disconnect all streams
|
||||
await self.disconnect_all()
|
||||
|
||||
# Stop model controller
|
||||
if self.model_controller:
|
||||
await self.model_controller.stop()
|
||||
|
||||
# Note: Model repository cleanup is sync and may cause segfaults
|
||||
# Leaving cleanup to garbage collection for now
|
||||
|
||||
self.initialized = False
|
||||
logger.info("StreamConnectionManager shutdown complete")
|
||||
|
||||
async def _forward_results(self, connection: StreamConnection, callback: Callable):
|
||||
"""
|
||||
Forward results from connection to user callback.
|
||||
|
||||
Args:
|
||||
connection: StreamConnection to listen to
|
||||
callback: User callback (sync or async)
|
||||
"""
|
||||
try:
|
||||
async for result in connection.tracking_results():
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(result)
|
||||
else:
|
||||
callback(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)
|
||||
|
||||
async def _forward_errors(self, connection: StreamConnection, callback: Callable):
|
||||
"""
|
||||
Forward errors from connection to user callback.
|
||||
|
||||
Args:
|
||||
connection: StreamConnection to listen to
|
||||
callback: User callback (sync or async)
|
||||
"""
|
||||
try:
|
||||
async for error in connection.errors():
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(error)
|
||||
else:
|
||||
callback(error)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get statistics for all connections.
|
||||
|
||||
Returns:
|
||||
Dictionary with manager and connection statistics
|
||||
"""
|
||||
return {
|
||||
"manager": {
|
||||
"initialized": self.initialized,
|
||||
"gpu_id": self.gpu_id,
|
||||
"num_connections": len(self.connections),
|
||||
"batch_size": self.batch_size,
|
||||
"force_timeout": self.force_timeout,
|
||||
"poll_interval": self.poll_interval,
|
||||
},
|
||||
"model_controller": self.model_controller.get_stats() if self.model_controller else {},
|
||||
"connections": {
|
||||
stream_id: conn.get_stats()
|
||||
for stream_id, conn in self.connections.items()
|
||||
},
|
||||
}
|
||||
373
test_event_driven.py
Normal file
373
test_event_driven.py
Normal file
|
|
@ -0,0 +1,373 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for event-driven stream processing with batched inference.
|
||||
|
||||
This demonstrates the new AsyncIO-based API for connecting to RTSP streams,
|
||||
processing frames through batched inference, and receiving tracking results
|
||||
via callbacks and async generators.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from services import StreamConnectionManager, YOLOv8Utils, COCO_CLASSES
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Example 1: Simple callback pattern
|
||||
async def example_callback_pattern():
|
||||
"""Demonstrates the simple callback pattern for a single stream"""
|
||||
logger.info("=== Example 1: Callback Pattern ===")
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
camera_url = os.getenv('CAMERA_URL_1')
|
||||
if not camera_url:
|
||||
logger.error("CAMERA_URL_1 not found in .env file")
|
||||
return
|
||||
|
||||
# Create manager
|
||||
manager = StreamConnectionManager(
|
||||
gpu_id=0,
|
||||
batch_size=16,
|
||||
force_timeout=0.05, # 50ms
|
||||
poll_interval=0.01, # 100 FPS
|
||||
)
|
||||
|
||||
# Initialize with YOLOv8 model
|
||||
model_path = "models/yolov8n.trt" # Adjust path as needed
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f"Model file not found: {model_path}")
|
||||
return
|
||||
|
||||
await manager.initialize(
|
||||
model_path=model_path,
|
||||
model_id="yolo",
|
||||
preprocess_fn=YOLOv8Utils.preprocess,
|
||||
postprocess_fn=YOLOv8Utils.postprocess,
|
||||
)
|
||||
|
||||
# Define callback for tracking results
|
||||
def on_tracking_result(result):
|
||||
logger.info(f"[{result.stream_id}] Frame {result.metadata.get('frame_number', 0)}")
|
||||
logger.info(f" Timestamp: {result.timestamp:.3f}")
|
||||
logger.info(f" Tracked objects: {len(result.tracked_objects)}")
|
||||
|
||||
for obj in result.tracked_objects[:5]: # Show first 5
|
||||
class_name = COCO_CLASSES.get(obj.class_id, f"Class {obj.class_id}")
|
||||
logger.info(
|
||||
f" Track ID {obj.track_id}: {class_name}, "
|
||||
f"conf={obj.confidence:.2f}, bbox={obj.bbox}"
|
||||
)
|
||||
|
||||
def on_error(error):
|
||||
logger.error(f"Stream error: {error}")
|
||||
|
||||
# Connect to stream
|
||||
connection = await manager.connect_stream(
|
||||
rtsp_url=camera_url,
|
||||
stream_id="camera1",
|
||||
on_tracking_result=on_tracking_result,
|
||||
on_error=on_error,
|
||||
)
|
||||
|
||||
# Let it run for 30 seconds
|
||||
logger.info("Processing stream for 30 seconds...")
|
||||
await asyncio.sleep(30)
|
||||
|
||||
# Get statistics
|
||||
stats = manager.get_stats()
|
||||
logger.info("=== Statistics ===")
|
||||
logger.info(f"Manager: {stats['manager']}")
|
||||
logger.info(f"Model Controller: {stats['model_controller']}")
|
||||
logger.info(f"Connection: {stats['connections']['camera1']}")
|
||||
|
||||
# Cleanup
|
||||
await manager.shutdown()
|
||||
logger.info("Example 1 complete\n")
|
||||
|
||||
|
||||
# Example 2: Async generator pattern with multiple streams
|
||||
async def example_async_generator_pattern():
|
||||
"""Demonstrates async generator pattern for multiple streams"""
|
||||
logger.info("=== Example 2: Async Generator Pattern (Multiple Streams) ===")
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
camera_urls = []
|
||||
for i in range(1, 5): # Try to load 4 cameras
|
||||
url = os.getenv(f'CAMERA_URL_{i}')
|
||||
if url:
|
||||
camera_urls.append((url, f"camera{i}"))
|
||||
|
||||
if not camera_urls:
|
||||
logger.error("No camera URLs found in .env file")
|
||||
return
|
||||
|
||||
logger.info(f"Found {len(camera_urls)} camera(s)")
|
||||
|
||||
# Create manager with larger batch for multiple streams
|
||||
manager = StreamConnectionManager(
|
||||
gpu_id=0,
|
||||
batch_size=32, # Larger batch for multiple streams
|
||||
force_timeout=0.05,
|
||||
)
|
||||
|
||||
# Initialize
|
||||
model_path = "models/yolov8n.trt"
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f"Model file not found: {model_path}")
|
||||
return
|
||||
|
||||
await manager.initialize(
|
||||
model_path=model_path,
|
||||
preprocess_fn=YOLOv8Utils.preprocess,
|
||||
postprocess_fn=YOLOv8Utils.postprocess,
|
||||
)
|
||||
|
||||
# Connect to all streams
|
||||
connections = []
|
||||
for url, stream_id in camera_urls:
|
||||
try:
|
||||
connection = await manager.connect_stream(
|
||||
rtsp_url=url,
|
||||
stream_id=stream_id,
|
||||
)
|
||||
connections.append((connection, stream_id))
|
||||
logger.info(f"Connected to {stream_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to {stream_id}: {e}")
|
||||
|
||||
# Process each stream with async generator
|
||||
async def process_stream(connection, stream_name):
|
||||
"""Process results from a single stream"""
|
||||
frame_count = 0
|
||||
person_detections = 0
|
||||
|
||||
async for result in connection.tracking_results():
|
||||
frame_count += 1
|
||||
|
||||
# Count person detections (class_id 0 in COCO)
|
||||
for obj in result.tracked_objects:
|
||||
if obj.class_id == 0:
|
||||
person_detections += 1
|
||||
|
||||
# Log every 10th frame
|
||||
if frame_count % 10 == 0:
|
||||
logger.info(
|
||||
f"[{stream_name}] Processed {frame_count} frames, "
|
||||
f"{person_detections} person detections"
|
||||
)
|
||||
|
||||
# Stop after 100 frames
|
||||
if frame_count >= 100:
|
||||
break
|
||||
|
||||
# Run all streams concurrently
|
||||
tasks = [
|
||||
asyncio.create_task(process_stream(conn, name))
|
||||
for conn, name in connections
|
||||
]
|
||||
|
||||
# Wait for all tasks to complete
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Get final statistics
|
||||
stats = manager.get_stats()
|
||||
logger.info("\n=== Final Statistics ===")
|
||||
logger.info(f"Total connections: {stats['manager']['num_connections']}")
|
||||
logger.info(f"Frames processed: {stats['model_controller']['total_frames_processed']}")
|
||||
logger.info(f"Batches processed: {stats['model_controller']['total_batches_processed']}")
|
||||
logger.info(f"Avg batch size: {stats['model_controller']['avg_batch_size']:.2f}")
|
||||
|
||||
# Cleanup
|
||||
await manager.shutdown()
|
||||
logger.info("Example 2 complete\n")
|
||||
|
||||
|
||||
# Example 3: Queue-based pattern
|
||||
async def example_queue_pattern():
|
||||
"""Demonstrates direct queue access for custom processing"""
|
||||
logger.info("=== Example 3: Queue-Based Pattern ===")
|
||||
|
||||
# Load environment
|
||||
load_dotenv()
|
||||
camera_url = os.getenv('CAMERA_URL_1')
|
||||
if not camera_url:
|
||||
logger.error("CAMERA_URL_1 not found in .env file")
|
||||
return
|
||||
|
||||
# Create manager
|
||||
manager = StreamConnectionManager(gpu_id=0, batch_size=16)
|
||||
|
||||
# Initialize
|
||||
model_path = "models/yolov8n.trt"
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f"Model file not found: {model_path}")
|
||||
return
|
||||
|
||||
await manager.initialize(
|
||||
model_path=model_path,
|
||||
preprocess_fn=YOLOv8Utils.preprocess,
|
||||
postprocess_fn=YOLOv8Utils.postprocess,
|
||||
)
|
||||
|
||||
# Connect to stream (no callback)
|
||||
connection = await manager.connect_stream(
|
||||
rtsp_url=camera_url,
|
||||
stream_id="main_camera",
|
||||
)
|
||||
|
||||
# Use the built-in queue directly
|
||||
result_queue = connection.result_queue
|
||||
|
||||
# Process results from queue
|
||||
processed_count = 0
|
||||
while processed_count < 50: # Process 50 frames
|
||||
try:
|
||||
result = await asyncio.wait_for(result_queue.get(), timeout=5.0)
|
||||
processed_count += 1
|
||||
|
||||
# Custom processing
|
||||
has_person = any(obj.class_id == 0 for obj in result.tracked_objects)
|
||||
has_car = any(obj.class_id == 2 for obj in result.tracked_objects)
|
||||
|
||||
if has_person or has_car:
|
||||
logger.info(
|
||||
f"Frame {processed_count}: "
|
||||
f"Person={'Yes' if has_person else 'No'}, "
|
||||
f"Car={'Yes' if has_car else 'No'}"
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timeout waiting for result")
|
||||
break
|
||||
|
||||
# Cleanup
|
||||
await manager.shutdown()
|
||||
logger.info("Example 3 complete\n")
|
||||
|
||||
|
||||
# Example 4: Performance monitoring
|
||||
async def example_performance_monitoring():
|
||||
"""Demonstrates real-time performance monitoring"""
|
||||
logger.info("=== Example 4: Performance Monitoring ===")
|
||||
|
||||
# Load environment
|
||||
load_dotenv()
|
||||
camera_url = os.getenv('CAMERA_URL_1')
|
||||
if not camera_url:
|
||||
logger.error("CAMERA_URL_1 not found in .env file")
|
||||
return
|
||||
|
||||
# Create manager
|
||||
manager = StreamConnectionManager(
|
||||
gpu_id=0,
|
||||
batch_size=16,
|
||||
force_timeout=0.05,
|
||||
)
|
||||
|
||||
# Initialize
|
||||
model_path = "models/yolov8n.trt"
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f"Model file not found: {model_path}")
|
||||
return
|
||||
|
||||
await manager.initialize(
|
||||
model_path=model_path,
|
||||
preprocess_fn=YOLOv8Utils.preprocess,
|
||||
postprocess_fn=YOLOv8Utils.postprocess,
|
||||
)
|
||||
|
||||
# Track performance metrics
|
||||
frame_times = []
|
||||
last_frame_time = None
|
||||
|
||||
def on_tracking_result(result):
|
||||
nonlocal last_frame_time
|
||||
current_time = time.time()
|
||||
|
||||
if last_frame_time is not None:
|
||||
frame_interval = current_time - last_frame_time
|
||||
frame_times.append(frame_interval)
|
||||
|
||||
last_frame_time = current_time
|
||||
|
||||
# Connect
|
||||
connection = await manager.connect_stream(
|
||||
rtsp_url=camera_url,
|
||||
on_tracking_result=on_tracking_result,
|
||||
)
|
||||
|
||||
# Monitor stats periodically
|
||||
for i in range(6): # Monitor for 60 seconds
|
||||
await asyncio.sleep(10)
|
||||
|
||||
stats = manager.get_stats()
|
||||
model_stats = stats['model_controller']
|
||||
conn_stats = stats['connections'].get('stream_0', {})
|
||||
|
||||
logger.info(f"\n=== Stats Update {i+1} ===")
|
||||
logger.info(f"Buffer A: {model_stats['buffer_a_size']} ({model_stats['buffer_a_state']})")
|
||||
logger.info(f"Buffer B: {model_stats['buffer_b_size']} ({model_stats['buffer_b_state']})")
|
||||
logger.info(f"Active buffer: {model_stats['active_buffer']}")
|
||||
logger.info(f"Total frames: {model_stats['total_frames_processed']}")
|
||||
logger.info(f"Total batches: {model_stats['total_batches_processed']}")
|
||||
logger.info(f"Avg batch size: {model_stats['avg_batch_size']:.2f}")
|
||||
logger.info(f"Decoder frames: {conn_stats.get('frame_count', 0)}")
|
||||
|
||||
if frame_times:
|
||||
avg_fps = 1.0 / (sum(frame_times) / len(frame_times))
|
||||
logger.info(f"Processing FPS: {avg_fps:.2f}")
|
||||
|
||||
# Cleanup
|
||||
await manager.shutdown()
|
||||
logger.info("Example 4 complete\n")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all examples"""
|
||||
logger.info("Starting event-driven stream processing tests\n")
|
||||
|
||||
# Choose which example to run
|
||||
choice = os.getenv('EXAMPLE', '1')
|
||||
|
||||
if choice == '1':
|
||||
await example_callback_pattern()
|
||||
elif choice == '2':
|
||||
await example_async_generator_pattern()
|
||||
elif choice == '3':
|
||||
await example_queue_pattern()
|
||||
elif choice == '4':
|
||||
await example_performance_monitoring()
|
||||
elif choice == 'all':
|
||||
await example_callback_pattern()
|
||||
await asyncio.sleep(2)
|
||||
await example_async_generator_pattern()
|
||||
await asyncio.sleep(2)
|
||||
await example_queue_pattern()
|
||||
await asyncio.sleep(2)
|
||||
await example_performance_monitoring()
|
||||
else:
|
||||
logger.error(f"Invalid choice: {choice}")
|
||||
logger.info("Set EXAMPLE env var to 1, 2, 3, 4, or 'all'")
|
||||
|
||||
logger.info("All tests complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\nInterrupted by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}", exc_info=True)
|
||||
117
test_event_driven_quick.py
Executable file
117
test_event_driven_quick.py
Executable file
|
|
@ -0,0 +1,117 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick test for event-driven stream processing - runs for 20 seconds.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from services import StreamConnectionManager, YOLOv8Utils, COCO_CLASSES
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Quick test with callback pattern"""
|
||||
logger.info("=== Quick Event-Driven Test (20 seconds) ===")
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
camera_url = os.getenv('CAMERA_URL_1')
|
||||
if not camera_url:
|
||||
logger.error("CAMERA_URL_1 not found in .env file")
|
||||
return
|
||||
|
||||
# Create manager
|
||||
manager = StreamConnectionManager(
|
||||
gpu_id=0,
|
||||
batch_size=16,
|
||||
force_timeout=0.05, # 50ms
|
||||
poll_interval=0.01, # 100 FPS
|
||||
)
|
||||
|
||||
# Initialize with YOLOv8 model
|
||||
model_path = "models/yolov8n.trt"
|
||||
logger.info(f"Initializing with model: {model_path}")
|
||||
|
||||
await manager.initialize(
|
||||
model_path=model_path,
|
||||
model_id="yolo",
|
||||
preprocess_fn=YOLOv8Utils.preprocess,
|
||||
postprocess_fn=YOLOv8Utils.postprocess,
|
||||
)
|
||||
|
||||
result_count = 0
|
||||
|
||||
# Define callback for tracking results
|
||||
def on_tracking_result(result):
|
||||
nonlocal result_count
|
||||
result_count += 1
|
||||
|
||||
if result_count % 5 == 0: # Log every 5th result
|
||||
logger.info(f"[{result.stream_id}] Frame {result.metadata.get('frame_number', 0)}")
|
||||
logger.info(f" Tracked objects: {len(result.tracked_objects)}")
|
||||
|
||||
for obj in result.tracked_objects[:3]: # Show first 3
|
||||
class_name = COCO_CLASSES.get(obj.class_id, f"Class {obj.class_id}")
|
||||
logger.info(
|
||||
f" Track ID {obj.track_id}: {class_name}, "
|
||||
f"conf={obj.confidence:.2f}"
|
||||
)
|
||||
|
||||
def on_error(error):
|
||||
logger.error(f"Stream error: {error}")
|
||||
|
||||
# Connect to stream
|
||||
logger.info(f"Connecting to stream...")
|
||||
connection = await manager.connect_stream(
|
||||
rtsp_url=camera_url,
|
||||
stream_id="test_camera",
|
||||
on_tracking_result=on_tracking_result,
|
||||
on_error=on_error,
|
||||
)
|
||||
|
||||
# Monitor for 20 seconds with stats updates
|
||||
for i in range(4): # 4 x 5 seconds = 20 seconds
|
||||
await asyncio.sleep(5)
|
||||
|
||||
stats = manager.get_stats()
|
||||
model_stats = stats['model_controller']
|
||||
|
||||
logger.info(f"\n=== Stats Update {i+1}/4 ===")
|
||||
logger.info(f"Results received: {result_count}")
|
||||
logger.info(f"Buffer A: {model_stats['buffer_a_size']} ({model_stats['buffer_a_state']})")
|
||||
logger.info(f"Buffer B: {model_stats['buffer_b_size']} ({model_stats['buffer_b_state']})")
|
||||
logger.info(f"Active buffer: {model_stats['active_buffer']}")
|
||||
logger.info(f"Total frames processed: {model_stats['total_frames_processed']}")
|
||||
logger.info(f"Total batches: {model_stats['total_batches_processed']}")
|
||||
logger.info(f"Avg batch size: {model_stats['avg_batch_size']:.2f}")
|
||||
|
||||
# Final statistics
|
||||
stats = manager.get_stats()
|
||||
logger.info("\n=== Final Statistics ===")
|
||||
logger.info(f"Total results received: {result_count}")
|
||||
logger.info(f"Manager: {stats['manager']}")
|
||||
logger.info(f"Model Controller: {stats['model_controller']}")
|
||||
logger.info(f"Connection: {stats['connections']['test_camera']}")
|
||||
|
||||
# Cleanup
|
||||
logger.info("\nShutting down...")
|
||||
await manager.shutdown()
|
||||
logger.info("Test complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\nInterrupted by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}", exc_info=True)
|
||||
|
|
@ -509,7 +509,7 @@ def main_multi_window():
|
|||
|
||||
if __name__ == "__main__":
|
||||
# Run single camera visualization
|
||||
main()
|
||||
# main()
|
||||
|
||||
# Uncomment to run multi-window visualization
|
||||
# main_multi_window()
|
||||
main_multi_window()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue