event driven system
parent 0c5f56c8a6
commit 3a47920186

10 changed files with 782 additions and 253 deletions
@@ -1,16 +1,17 @@
 """
-StreamConnectionManager - Async orchestration for stream processing with batched inference.
+StreamConnectionManager - Event-driven orchestration for stream processing with batched inference.
 
 This module provides high-level connection management for multiple RTSP streams,
-coordinating decoders, batched inference, and tracking with an event-driven API.
+coordinating decoders, batched inference, and tracking with callbacks and threading.
 """
 
-import asyncio
+import threading
 import time
-from typing import Dict, Optional, Callable, AsyncIterator, Tuple, Any, List
+from typing import Dict, Optional, Callable, Tuple, Any, List
 from dataclasses import dataclass
 from enum import Enum
 import logging
+import queue
 
 import torch
 
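Note: the import swap above is the heart of this commit — the thread-safe `queue.Queue` replaces `asyncio.Queue`, and `threading` replaces the event loop. A minimal, self-contained sketch of the blocking-consumption pattern the rest of the diff builds on (all names here are illustrative, not from the repo):

```python
import queue
import threading

q: "queue.Queue[int]" = queue.Queue()

def consumer() -> None:
    while True:
        try:
            # Blocks up to 1s; no event loop needed. queue.Empty plays the
            # role asyncio.TimeoutError played in the old code.
            item = q.get(timeout=1.0)
        except queue.Empty:
            continue  # timeout: re-check the loop condition
        if item < 0:  # sentinel ends the loop
            break
        print(f"got {item}")

t = threading.Thread(target=consumer, daemon=True)
t.start()
for i in (1, 2, 3, -1):
    q.put(i)
t.join()
```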
@@ -44,7 +45,7 @@ class StreamConnection:
     """
     Represents a single stream connection with event emission.
 
-    This class wraps a StreamDecoder, polls frames asynchronously, submits them
+    This class wraps a StreamDecoder, receives frames via a decoder callback, submits them
     to the ModelController for batched inference, runs tracking, and emits results
     via queues or callbacks.
 
@@ -75,15 +76,19 @@ class StreamConnection:
         self.last_frame_time = 0.0
 
         # Event emission
-        self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue()
-        self.error_queue: asyncio.Queue[Exception] = asyncio.Queue()
+        self.result_queue: queue.Queue[TrackingResult] = queue.Queue()
+        self.error_queue: queue.Queue[Exception] = queue.Queue()
 
-        # Tasks
-        self.poller_task: Optional[asyncio.Task] = None
+        # Event-driven state
         self.running = False
 
-    async def start(self):
-        """Start the connection (decoder and frame polling)"""
+    def start(self):
+        """Start the connection (decoder with frame callback)"""
+        self.running = True
+
+        # Register callback for frame events from decoder
+        self.decoder.register_frame_callback(self._on_frame_decoded)
+
         # Start decoder (runs in background thread)
         self.decoder.start()
 
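Note: the new `start()` relies on `StreamDecoder.register_frame_callback`, whose implementation is not part of this diff. A hypothetical sketch of the dispatch that could back such an API, assuming callbacks are invoked from the decode thread and may be (un)registered mid-stream:

```python
# Hypothetical decoder-side dispatch; the real StreamDecoder is not shown here.
import threading
from typing import Callable, List

import torch

class CallbackDispatcher:
    def __init__(self) -> None:
        self._callbacks: List[Callable[[torch.Tensor], None]] = []
        self._lock = threading.Lock()  # callbacks can change while decoding

    def register_frame_callback(self, cb: Callable[[torch.Tensor], None]) -> None:
        with self._lock:
            self._callbacks.append(cb)

    def unregister_frame_callback(self, cb: Callable[[torch.Tensor], None]) -> None:
        with self._lock:
            if cb in self._callbacks:
                self._callbacks.remove(cb)

    def _emit(self, frame: torch.Tensor) -> None:
        # Called from the decode thread for every frame; snapshot under the
        # lock so a callback can unregister itself without breaking iteration.
        with self._lock:
            callbacks = list(self._callbacks)
        for cb in callbacks:
            cb(frame)
```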
@@ -93,7 +98,7 @@ class StreamConnection:
         elapsed = 0.0
 
         while elapsed < max_wait:
-            await asyncio.sleep(wait_interval)
+            time.sleep(wait_interval)
             elapsed += wait_interval
 
             if self.decoder.is_connected():
@@ -105,21 +110,13 @@ class StreamConnection:
             logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...")
             self.status = ConnectionStatus.CONNECTING
 
-        # Start frame polling task
-        self.running = True
-        self.poller_task = asyncio.create_task(self._frame_poller())
-
-    async def stop(self):
+    def stop(self):
         """Stop the connection and cleanup"""
         logger.info(f"Stopping stream {self.stream_id}...")
         self.running = False
 
-        if self.poller_task:
-            self.poller_task.cancel()
-            try:
-                await self.poller_task
-            except asyncio.CancelledError:
-                pass
+        # Unregister frame callback
+        self.decoder.unregister_frame_callback(self._on_frame_decoded)
 
         # Stop decoder
         self.decoder.stop()
@@ -130,55 +127,45 @@ class StreamConnection:
         self.status = ConnectionStatus.DISCONNECTED
         logger.info(f"Stream {self.stream_id} stopped")
 
-    async def _frame_poller(self):
-        """Poll frames from threaded decoder and submit to model controller"""
-        last_decoder_frame_count = -1
+    def _on_frame_decoded(self, frame: torch.Tensor):
+        """
+        Event handler called by decoder when a new frame is decoded.
+        This is the event-driven replacement for polling.
 
-        while self.running:
-            try:
-                # Get current decoder frame count (no data transfer, just counter)
-                decoder_frame_count = self.decoder.get_frame_count()
+        Args:
+            frame: RGB frame tensor on GPU (3, H, W)
+        """
+        if not self.running:
+            return
 
-                # Check if decoder has a new frame (avoid reprocessing same frame)
-                if decoder_frame_count > last_decoder_frame_count:
-                    # Poll frame from decoder (zero-copy - stays in VRAM)
-                    frame = self.decoder.get_latest_frame(rgb=True)
+        try:
+            self.last_frame_time = time.time()
+            self.frame_count += 1
 
-                    if frame is not None:
-                        last_decoder_frame_count = decoder_frame_count
-                        self.last_frame_time = time.time()
-                        self.frame_count += 1
+            # Submit to model controller for batched inference
+            self.model_controller.submit_frame(
+                stream_id=self.stream_id,
+                frame=frame,
+                metadata={
+                    "frame_number": self.frame_count,
+                    "shape": tuple(frame.shape),
+                }
+            )
 
-                        # Submit to model controller for batched inference
-                        await self.model_controller.submit_frame(
-                            stream_id=self.stream_id,
-                            frame=frame,
-                            metadata={
-                                "frame_number": self.frame_count,
-                                "shape": tuple(frame.shape),
-                            }
-                        )
+            # Update connection status based on decoder status
+            if self.decoder.is_connected() and self.status != ConnectionStatus.CONNECTED:
+                logger.info(f"Stream {self.stream_id} reconnected")
+                self.status = ConnectionStatus.CONNECTED
+            elif not self.decoder.is_connected() and self.status == ConnectionStatus.CONNECTED:
+                logger.warning(f"Stream {self.stream_id} disconnected")
+                self.status = ConnectionStatus.DISCONNECTED
 
-                # Check decoder status
-                if not self.decoder.is_connected():
-                    if self.status == ConnectionStatus.CONNECTED:
-                        logger.warning(f"Stream {self.stream_id} disconnected")
-                        self.status = ConnectionStatus.DISCONNECTED
-                    # Decoder will auto-reconnect, just update status
-                    await asyncio.sleep(1.0)
-                    if self.decoder.is_connected():
-                        logger.info(f"Stream {self.stream_id} reconnected")
-                        self.status = ConnectionStatus.CONNECTED
+        except Exception as e:
+            logger.error(f"Error processing frame for {self.stream_id}: {e}", exc_info=True)
+            self.error_queue.put(e)
+            self.status = ConnectionStatus.ERROR
 
-            except Exception as e:
-                logger.error(f"Error in frame poller for {self.stream_id}: {e}", exc_info=True)
-                await self.error_queue.put(e)
-                self.status = ConnectionStatus.ERROR
-
-            # Sleep until next poll
-            await asyncio.sleep(self.poll_interval)
-
-    async def _handle_inference_result(self, result: Dict[str, Any]):
+    def _handle_inference_result(self, result: Dict[str, Any]):
         """
         Callback invoked by ModelController when inference is done.
         Runs tracking and emits final result.
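Note: `_on_frame_decoded` now runs on the decoder's thread, so `ModelController.submit_frame` must be safe to call from any thread. `ModelController` is not shown in this file; a sketch of one way it could buffer submissions, assuming a dedicated inference worker drains the queue for batching:

```python
# Sketch only; the repo's ModelController may differ.
import queue
from typing import Any, Dict

import torch

class ModelControllerSketch:
    def __init__(self, max_pending: int = 64) -> None:
        # queue.Queue is safe to call from any decoder thread
        self._pending: queue.Queue = queue.Queue(maxsize=max_pending)

    def submit_frame(self, stream_id: str, frame: torch.Tensor,
                     metadata: Dict[str, Any]) -> None:
        try:
            # Never block the decode thread when inference lags
            self._pending.put_nowait((stream_id, frame, metadata))
        except queue.Full:
            pass  # backpressure policy: drop the newest frame
```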
@@ -190,12 +177,8 @@ class StreamConnection:
             # Extract detections
             detections = result["detections"]
 
-            # Run tracking (this is sync, so run in executor)
-            loop = asyncio.get_event_loop()
-            tracked_objects = await loop.run_in_executor(
-                None,
-                lambda: self._run_tracking_sync(detections)
-            )
+            # Run tracking (synchronous)
+            tracked_objects = self._run_tracking_sync(detections)
 
             # Create tracking result
             tracking_result = TrackingResult(
@@ -208,11 +191,11 @@ class StreamConnection:
             )
 
             # Emit to result queue
-            await self.result_queue.put(tracking_result)
+            self.result_queue.put(tracking_result)
 
         except Exception as e:
             logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
-            await self.error_queue.put(e)
+            self.error_queue.put(e)
 
     def _run_tracking_sync(self, detections, min_confidence=0.7):
         """
@@ -246,12 +229,12 @@ class StreamConnection:
         # Update tracker with detections (lightweight, no model dependency!)
         return self.tracking_controller.update(detection_list)
 
-    async def tracking_results(self) -> AsyncIterator[TrackingResult]:
+    def tracking_results(self):
         """
-        Async generator for tracking results.
+        Generator for tracking results (blocking iterator).
 
         Usage:
-            async for result in connection.tracking_results():
+            for result in connection.tracking_results():
                 print(result.tracked_objects)
 
         Yields:
@@ -259,23 +242,23 @@ class StreamConnection:
         """
         while self.running or not self.result_queue.empty():
             try:
-                result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0)
+                result = self.result_queue.get(timeout=1.0)
                 yield result
-            except asyncio.TimeoutError:
+            except queue.Empty:
                 continue
 
-    async def errors(self) -> AsyncIterator[Exception]:
+    def errors(self):
         """
-        Async generator for errors.
+        Generator for errors (blocking iterator).
 
         Yields:
             Exception objects as they occur
         """
         while self.running or not self.error_queue.empty():
             try:
-                error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0)
+                error = self.error_queue.get(timeout=1.0)
                 yield error
-            except asyncio.TimeoutError:
+            except queue.Empty:
                 continue
 
     def get_stats(self) -> Dict[str, Any]:
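Note: with the generators now blocking, consumers move to plain threads. An illustrative consumer, assuming a `connection` obtained from `connect_stream` (the `tracked_objects` field comes from the docstring above):

```python
import threading

def print_results(connection) -> None:
    # Iterates until connection.stop() is called and the queue drains
    for result in connection.tracking_results():
        print(len(result.tracked_objects), "tracked objects")

worker = threading.Thread(target=print_results, args=(connection,), daemon=True)
worker.start()
```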
@@ -342,7 +325,7 @@ class StreamConnectionManager:
         # State
         self.initialized = False
 
-    async def initialize(
+    def initialize(
         self,
         model_path: str,
         model_id: str = "detector",
@@ -368,18 +351,14 @@ class StreamConnectionManager:
         """
         logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
 
-        # Load model
-        loop = asyncio.get_event_loop()
-        await loop.run_in_executor(
-            None,
-            lambda: self.model_repository.load_model(
-                model_id,
-                model_path,
-                num_contexts=num_contexts,
-                pt_input_shapes=pt_input_shapes,
-                pt_precision=pt_precision,
-                **pt_conversion_kwargs
-            )
-        )
+        # Load model (synchronous)
+        self.model_repository.load_model(
+            model_id,
+            model_path,
+            num_contexts=num_contexts,
+            pt_input_shapes=pt_input_shapes,
+            pt_precision=pt_precision,
+            **pt_conversion_kwargs
+        )
         logger.info(f"Loaded model {model_id} from {model_path}")
 
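Note: for reference, a sketch of calling the now-synchronous `initialize()`. The keyword names come from this diff; the constructor signature, path, shapes, and precision value are placeholders:

```python
manager = StreamConnectionManager(gpu_id=0)   # constructor args assumed
manager.initialize(
    model_path="models/detector.pt",               # placeholder path
    model_id="detector",
    num_contexts=4,                                # placeholder value
    pt_input_shapes={"images": (1, 3, 640, 640)},  # placeholder shape
    pt_precision="fp16",                           # placeholder; actual type depends on the repo
)
```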
@@ -392,7 +371,7 @@ class StreamConnectionManager:
             preprocess_fn=preprocess_fn,
             postprocess_fn=postprocess_fn,
         )
-        await self.model_controller.start()
+        self.model_controller.start()
 
         # Don't create a shared tracking controller here
         # Each stream will get its own tracking controller to avoid track accumulation
@@ -402,7 +381,7 @@ class StreamConnectionManager:
         self.initialized = True
         logger.info("StreamConnectionManager initialized successfully")
 
-    async def connect_stream(
+    def connect_stream(
         self,
         rtsp_url: str,
         stream_id: Optional[str] = None,
@@ -416,8 +395,8 @@ class StreamConnectionManager:
         Args:
             rtsp_url: RTSP stream URL
             stream_id: Optional stream identifier (auto-generated if not provided)
-            on_tracking_result: Optional callback for tracking results (sync or async)
-            on_error: Optional callback for errors (sync or async)
+            on_tracking_result: Optional callback for tracking results (synchronous)
+            on_error: Optional callback for errors (synchronous)
             buffer_size: Decoder buffer size (default: 30)
 
         Returns:
@@ -466,22 +445,30 @@ class StreamConnectionManager:
         )
 
         # Start connection
-        await connection.start()
+        connection.start()
 
         # Store connection
         self.connections[stream_id] = connection
 
-        # Set up user callbacks if provided
+        # Set up user callbacks if provided (run in separate threads)
         if on_tracking_result:
-            asyncio.create_task(self._forward_results(connection, on_tracking_result))
+            threading.Thread(
+                target=self._forward_results,
+                args=(connection, on_tracking_result),
+                daemon=True
+            ).start()
 
         if on_error:
-            asyncio.create_task(self._forward_errors(connection, on_error))
+            threading.Thread(
+                target=self._forward_errors,
+                args=(connection, on_error),
+                daemon=True
+            ).start()
 
         logger.info(f"Stream {stream_id} connected successfully")
         return connection
 
-    async def disconnect_stream(self, stream_id: str):
+    def disconnect_stream(self, stream_id: str):
         """
         Disconnect and cleanup a stream.
 
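Note: since callbacks are now invoked from the forwarding threads rather than awaited, a plain function suffices. An illustrative `connect_stream` call (URL and id are placeholders, building on the `manager` sketch above):

```python
def on_result(result) -> None:
    # Runs on a forwarding thread, so keep it quick and thread-safe
    print(f"{len(result.tracked_objects)} objects tracked")

connection = manager.connect_stream(
    rtsp_url="rtsp://example.com/stream",  # placeholder URL
    stream_id="cam-1",                     # placeholder id
    on_tracking_result=on_result,
)
```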
@@ -490,27 +477,27 @@ class StreamConnectionManager:
         """
         connection = self.connections.get(stream_id)
         if connection:
-            await connection.stop()
+            connection.stop()
             del self.connections[stream_id]
             logger.info(f"Stream {stream_id} disconnected")
 
-    async def disconnect_all(self):
+    def disconnect_all(self):
         """Disconnect all streams"""
         logger.info("Disconnecting all streams...")
         stream_ids = list(self.connections.keys())
         for stream_id in stream_ids:
-            await self.disconnect_stream(stream_id)
+            self.disconnect_stream(stream_id)
 
-    async def shutdown(self):
+    def shutdown(self):
         """Shutdown the manager and cleanup all resources"""
         logger.info("Shutting down StreamConnectionManager...")
 
         # Disconnect all streams
-        await self.disconnect_all()
+        self.disconnect_all()
 
         # Stop model controller
         if self.model_controller:
-            await self.model_controller.stop()
+            self.model_controller.stop()
 
         # Note: Model repository cleanup is sync and may cause segfaults
         # Leaving cleanup to garbage collection for now
@@ -518,37 +505,31 @@ class StreamConnectionManager:
         self.initialized = False
         logger.info("StreamConnectionManager shutdown complete")
 
-    async def _forward_results(self, connection: StreamConnection, callback: Callable):
+    def _forward_results(self, connection: StreamConnection, callback: Callable):
         """
         Forward results from connection to user callback.
 
         Args:
             connection: StreamConnection to listen to
-            callback: User callback (sync or async)
+            callback: User callback (synchronous)
         """
         try:
-            async for result in connection.tracking_results():
-                if asyncio.iscoroutinefunction(callback):
-                    await callback(result)
-                else:
-                    callback(result)
+            for result in connection.tracking_results():
+                callback(result)
         except Exception as e:
             logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)
 
-    async def _forward_errors(self, connection: StreamConnection, callback: Callable):
+    def _forward_errors(self, connection: StreamConnection, callback: Callable):
         """
         Forward errors from connection to user callback.
 
         Args:
             connection: StreamConnection to listen to
-            callback: User callback (sync or async)
+            callback: User callback (synchronous)
         """
         try:
-            async for error in connection.errors():
-                if asyncio.iscoroutinefunction(callback):
-                    await callback(error)
-                else:
-                    callback(error)
+            for error in connection.errors():
+                callback(error)
         except Exception as e:
            logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)
 
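Note: a matching end-to-end teardown sketch for the synchronous API, building on the illustrative `manager` and `connection` above (placeholder workload; no event loop required anywhere):

```python
import time

try:
    time.sleep(60)  # placeholder workload while frames flow
finally:
    manager.disconnect_stream("cam-1")  # or manager.disconnect_all()
    manager.shutdown()
```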