event driven system

This commit is contained in:
Siwat Sirichai 2025-11-10 11:51:06 +07:00
parent 0c5f56c8a6
commit 3a47920186
10 changed files with 782 additions and 253 deletions

View file

@ -1,16 +1,17 @@
"""
StreamConnectionManager - Async orchestration for stream processing with batched inference.
StreamConnectionManager - Event-driven orchestration for stream processing with batched inference.
This module provides high-level connection management for multiple RTSP streams,
coordinating decoders, batched inference, and tracking with an event-driven API.
coordinating decoders, batched inference, and tracking with callbacks and threading.
"""
import asyncio
import threading
import time
from typing import Dict, Optional, Callable, AsyncIterator, Tuple, Any, List
from typing import Dict, Optional, Callable, Tuple, Any, List
from dataclasses import dataclass
from enum import Enum
import logging
import queue
import torch
@ -44,7 +45,7 @@ class StreamConnection:
"""
Represents a single stream connection with event emission.
This class wraps a StreamDecoder, polls frames asynchronously, submits them
This class wraps a StreamDecoder, polls frames in a thread, submits them
to the ModelController for batched inference, runs tracking, and emits results
via queues or callbacks.
@ -75,15 +76,19 @@ class StreamConnection:
self.last_frame_time = 0.0
# Event emission
self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue()
self.error_queue: asyncio.Queue[Exception] = asyncio.Queue()
self.result_queue: queue.Queue[TrackingResult] = queue.Queue()
self.error_queue: queue.Queue[Exception] = queue.Queue()
# Tasks
self.poller_task: Optional[asyncio.Task] = None
# Event-driven state
self.running = False
async def start(self):
"""Start the connection (decoder and frame polling)"""
def start(self):
"""Start the connection (decoder with frame callback)"""
self.running = True
# Register callback for frame events from decoder
self.decoder.register_frame_callback(self._on_frame_decoded)
# Start decoder (runs in background thread)
self.decoder.start()
@ -93,7 +98,7 @@ class StreamConnection:
elapsed = 0.0
while elapsed < max_wait:
await asyncio.sleep(wait_interval)
time.sleep(wait_interval)
elapsed += wait_interval
if self.decoder.is_connected():
@ -105,21 +110,13 @@ class StreamConnection:
logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...")
self.status = ConnectionStatus.CONNECTING
# Start frame polling task
self.running = True
self.poller_task = asyncio.create_task(self._frame_poller())
async def stop(self):
def stop(self):
"""Stop the connection and cleanup"""
logger.info(f"Stopping stream {self.stream_id}...")
self.running = False
if self.poller_task:
self.poller_task.cancel()
try:
await self.poller_task
except asyncio.CancelledError:
pass
# Unregister frame callback
self.decoder.unregister_frame_callback(self._on_frame_decoded)
# Stop decoder
self.decoder.stop()
@ -130,55 +127,45 @@ class StreamConnection:
self.status = ConnectionStatus.DISCONNECTED
logger.info(f"Stream {self.stream_id} stopped")
async def _frame_poller(self):
"""Poll frames from threaded decoder and submit to model controller"""
last_decoder_frame_count = -1
def _on_frame_decoded(self, frame: torch.Tensor):
"""
Event handler called by decoder when a new frame is decoded.
This is the event-driven replacement for polling.
while self.running:
try:
# Get current decoder frame count (no data transfer, just counter)
decoder_frame_count = self.decoder.get_frame_count()
Args:
frame: RGB frame tensor on GPU (3, H, W)
"""
if not self.running:
return
# Check if decoder has a new frame (avoid reprocessing same frame)
if decoder_frame_count > last_decoder_frame_count:
# Poll frame from decoder (zero-copy - stays in VRAM)
frame = self.decoder.get_latest_frame(rgb=True)
try:
self.last_frame_time = time.time()
self.frame_count += 1
if frame is not None:
last_decoder_frame_count = decoder_frame_count
self.last_frame_time = time.time()
self.frame_count += 1
# Submit to model controller for batched inference
self.model_controller.submit_frame(
stream_id=self.stream_id,
frame=frame,
metadata={
"frame_number": self.frame_count,
"shape": tuple(frame.shape),
}
)
# Submit to model controller for batched inference
await self.model_controller.submit_frame(
stream_id=self.stream_id,
frame=frame,
metadata={
"frame_number": self.frame_count,
"shape": tuple(frame.shape),
}
)
# Update connection status based on decoder status
if self.decoder.is_connected() and self.status != ConnectionStatus.CONNECTED:
logger.info(f"Stream {self.stream_id} reconnected")
self.status = ConnectionStatus.CONNECTED
elif not self.decoder.is_connected() and self.status == ConnectionStatus.CONNECTED:
logger.warning(f"Stream {self.stream_id} disconnected")
self.status = ConnectionStatus.DISCONNECTED
# Check decoder status
if not self.decoder.is_connected():
if self.status == ConnectionStatus.CONNECTED:
logger.warning(f"Stream {self.stream_id} disconnected")
self.status = ConnectionStatus.DISCONNECTED
# Decoder will auto-reconnect, just update status
await asyncio.sleep(1.0)
if self.decoder.is_connected():
logger.info(f"Stream {self.stream_id} reconnected")
self.status = ConnectionStatus.CONNECTED
except Exception as e:
logger.error(f"Error processing frame for {self.stream_id}: {e}", exc_info=True)
self.error_queue.put(e)
self.status = ConnectionStatus.ERROR
except Exception as e:
logger.error(f"Error in frame poller for {self.stream_id}: {e}", exc_info=True)
await self.error_queue.put(e)
self.status = ConnectionStatus.ERROR
# Sleep until next poll
await asyncio.sleep(self.poll_interval)
async def _handle_inference_result(self, result: Dict[str, Any]):
def _handle_inference_result(self, result: Dict[str, Any]):
"""
Callback invoked by ModelController when inference is done.
Runs tracking and emits final result.
@ -190,12 +177,8 @@ class StreamConnection:
# Extract detections
detections = result["detections"]
# Run tracking (this is sync, so run in executor)
loop = asyncio.get_event_loop()
tracked_objects = await loop.run_in_executor(
None,
lambda: self._run_tracking_sync(detections)
)
# Run tracking (synchronous)
tracked_objects = self._run_tracking_sync(detections)
# Create tracking result
tracking_result = TrackingResult(
@ -208,11 +191,11 @@ class StreamConnection:
)
# Emit to result queue
await self.result_queue.put(tracking_result)
self.result_queue.put(tracking_result)
except Exception as e:
logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
await self.error_queue.put(e)
self.error_queue.put(e)
def _run_tracking_sync(self, detections, min_confidence=0.7):
"""
@ -246,12 +229,12 @@ class StreamConnection:
# Update tracker with detections (lightweight, no model dependency!)
return self.tracking_controller.update(detection_list)
async def tracking_results(self) -> AsyncIterator[TrackingResult]:
def tracking_results(self):
"""
Async generator for tracking results.
Generator for tracking results (blocking iterator).
Usage:
async for result in connection.tracking_results():
for result in connection.tracking_results():
print(result.tracked_objects)
Yields:
@ -259,23 +242,23 @@ class StreamConnection:
"""
while self.running or not self.result_queue.empty():
try:
result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0)
result = self.result_queue.get(timeout=1.0)
yield result
except asyncio.TimeoutError:
except queue.Empty:
continue
async def errors(self) -> AsyncIterator[Exception]:
def errors(self):
"""
Async generator for errors.
Generator for errors (blocking iterator).
Yields:
Exception objects as they occur
"""
while self.running or not self.error_queue.empty():
try:
error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0)
error = self.error_queue.get(timeout=1.0)
yield error
except asyncio.TimeoutError:
except queue.Empty:
continue
def get_stats(self) -> Dict[str, Any]:
@ -342,7 +325,7 @@ class StreamConnectionManager:
# State
self.initialized = False
async def initialize(
def initialize(
self,
model_path: str,
model_id: str = "detector",
@ -368,18 +351,14 @@ class StreamConnectionManager:
"""
logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
# Load model
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
lambda: self.model_repository.load_model(
model_id,
model_path,
num_contexts=num_contexts,
pt_input_shapes=pt_input_shapes,
pt_precision=pt_precision,
**pt_conversion_kwargs
)
# Load model (synchronous)
self.model_repository.load_model(
model_id,
model_path,
num_contexts=num_contexts,
pt_input_shapes=pt_input_shapes,
pt_precision=pt_precision,
**pt_conversion_kwargs
)
logger.info(f"Loaded model {model_id} from {model_path}")
@ -392,7 +371,7 @@ class StreamConnectionManager:
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
await self.model_controller.start()
self.model_controller.start()
# Don't create a shared tracking controller here
# Each stream will get its own tracking controller to avoid track accumulation
@ -402,7 +381,7 @@ class StreamConnectionManager:
self.initialized = True
logger.info("StreamConnectionManager initialized successfully")
async def connect_stream(
def connect_stream(
self,
rtsp_url: str,
stream_id: Optional[str] = None,
@ -416,8 +395,8 @@ class StreamConnectionManager:
Args:
rtsp_url: RTSP stream URL
stream_id: Optional stream identifier (auto-generated if not provided)
on_tracking_result: Optional callback for tracking results (sync or async)
on_error: Optional callback for errors (sync or async)
on_tracking_result: Optional callback for tracking results (synchronous)
on_error: Optional callback for errors (synchronous)
buffer_size: Decoder buffer size (default: 30)
Returns:
@ -466,22 +445,30 @@ class StreamConnectionManager:
)
# Start connection
await connection.start()
connection.start()
# Store connection
self.connections[stream_id] = connection
# Set up user callbacks if provided
# Set up user callbacks if provided (run in separate threads)
if on_tracking_result:
asyncio.create_task(self._forward_results(connection, on_tracking_result))
threading.Thread(
target=self._forward_results,
args=(connection, on_tracking_result),
daemon=True
).start()
if on_error:
asyncio.create_task(self._forward_errors(connection, on_error))
threading.Thread(
target=self._forward_errors,
args=(connection, on_error),
daemon=True
).start()
logger.info(f"Stream {stream_id} connected successfully")
return connection
async def disconnect_stream(self, stream_id: str):
def disconnect_stream(self, stream_id: str):
"""
Disconnect and cleanup a stream.
@ -490,27 +477,27 @@ class StreamConnectionManager:
"""
connection = self.connections.get(stream_id)
if connection:
await connection.stop()
connection.stop()
del self.connections[stream_id]
logger.info(f"Stream {stream_id} disconnected")
async def disconnect_all(self):
def disconnect_all(self):
"""Disconnect all streams"""
logger.info("Disconnecting all streams...")
stream_ids = list(self.connections.keys())
for stream_id in stream_ids:
await self.disconnect_stream(stream_id)
self.disconnect_stream(stream_id)
async def shutdown(self):
def shutdown(self):
"""Shutdown the manager and cleanup all resources"""
logger.info("Shutting down StreamConnectionManager...")
# Disconnect all streams
await self.disconnect_all()
self.disconnect_all()
# Stop model controller
if self.model_controller:
await self.model_controller.stop()
self.model_controller.stop()
# Note: Model repository cleanup is sync and may cause segfaults
# Leaving cleanup to garbage collection for now
@ -518,37 +505,31 @@ class StreamConnectionManager:
self.initialized = False
logger.info("StreamConnectionManager shutdown complete")
async def _forward_results(self, connection: StreamConnection, callback: Callable):
def _forward_results(self, connection: StreamConnection, callback: Callable):
"""
Forward results from connection to user callback.
Args:
connection: StreamConnection to listen to
callback: User callback (sync or async)
callback: User callback (synchronous)
"""
try:
async for result in connection.tracking_results():
if asyncio.iscoroutinefunction(callback):
await callback(result)
else:
callback(result)
for result in connection.tracking_results():
callback(result)
except Exception as e:
logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)
async def _forward_errors(self, connection: StreamConnection, callback: Callable):
def _forward_errors(self, connection: StreamConnection, callback: Callable):
"""
Forward errors from connection to user callback.
Args:
connection: StreamConnection to listen to
callback: User callback (sync or async)
callback: User callback (synchronous)
"""
try:
async for error in connection.errors():
if asyncio.iscoroutinefunction(callback):
await callback(error)
else:
callback(error)
for error in connection.errors():
callback(error)
except Exception as e:
logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)