batch processing/event driven
This commit is contained in:
parent
e71316ef3d
commit
dd57b5a246
7 changed files with 2673 additions and 2 deletions
373
test_event_driven.py
Normal file
373
test_event_driven.py
Normal file
|
|
@ -0,0 +1,373 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for event-driven stream processing with batched inference.
|
||||
|
||||
This demonstrates the new AsyncIO-based API for connecting to RTSP streams,
|
||||
processing frames through batched inference, and receiving tracking results
|
||||
via callbacks and async generators.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from services import StreamConnectionManager, YOLOv8Utils, COCO_CLASSES
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Example 1: Simple callback pattern
|
||||
async def example_callback_pattern():
|
||||
"""Demonstrates the simple callback pattern for a single stream"""
|
||||
logger.info("=== Example 1: Callback Pattern ===")
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
camera_url = os.getenv('CAMERA_URL_1')
|
||||
if not camera_url:
|
||||
logger.error("CAMERA_URL_1 not found in .env file")
|
||||
return
|
||||
|
||||
# Create manager
|
||||
manager = StreamConnectionManager(
|
||||
gpu_id=0,
|
||||
batch_size=16,
|
||||
force_timeout=0.05, # 50ms
|
||||
poll_interval=0.01, # 100 FPS
|
||||
)
|
||||
|
||||
# Initialize with YOLOv8 model
|
||||
model_path = "models/yolov8n.trt" # Adjust path as needed
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f"Model file not found: {model_path}")
|
||||
return
|
||||
|
||||
await manager.initialize(
|
||||
model_path=model_path,
|
||||
model_id="yolo",
|
||||
preprocess_fn=YOLOv8Utils.preprocess,
|
||||
postprocess_fn=YOLOv8Utils.postprocess,
|
||||
)
|
||||
|
||||
# Define callback for tracking results
|
||||
def on_tracking_result(result):
|
||||
logger.info(f"[{result.stream_id}] Frame {result.metadata.get('frame_number', 0)}")
|
||||
logger.info(f" Timestamp: {result.timestamp:.3f}")
|
||||
logger.info(f" Tracked objects: {len(result.tracked_objects)}")
|
||||
|
||||
for obj in result.tracked_objects[:5]: # Show first 5
|
||||
class_name = COCO_CLASSES.get(obj.class_id, f"Class {obj.class_id}")
|
||||
logger.info(
|
||||
f" Track ID {obj.track_id}: {class_name}, "
|
||||
f"conf={obj.confidence:.2f}, bbox={obj.bbox}"
|
||||
)
|
||||
|
||||
def on_error(error):
|
||||
logger.error(f"Stream error: {error}")
|
||||
|
||||
# Connect to stream
|
||||
connection = await manager.connect_stream(
|
||||
rtsp_url=camera_url,
|
||||
stream_id="camera1",
|
||||
on_tracking_result=on_tracking_result,
|
||||
on_error=on_error,
|
||||
)
|
||||
|
||||
# Let it run for 30 seconds
|
||||
logger.info("Processing stream for 30 seconds...")
|
||||
await asyncio.sleep(30)
|
||||
|
||||
# Get statistics
|
||||
stats = manager.get_stats()
|
||||
logger.info("=== Statistics ===")
|
||||
logger.info(f"Manager: {stats['manager']}")
|
||||
logger.info(f"Model Controller: {stats['model_controller']}")
|
||||
logger.info(f"Connection: {stats['connections']['camera1']}")
|
||||
|
||||
# Cleanup
|
||||
await manager.shutdown()
|
||||
logger.info("Example 1 complete\n")
|
||||
|
||||
|
||||
# Example 2: Async generator pattern with multiple streams
|
||||
async def example_async_generator_pattern():
|
||||
"""Demonstrates async generator pattern for multiple streams"""
|
||||
logger.info("=== Example 2: Async Generator Pattern (Multiple Streams) ===")
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
camera_urls = []
|
||||
for i in range(1, 5): # Try to load 4 cameras
|
||||
url = os.getenv(f'CAMERA_URL_{i}')
|
||||
if url:
|
||||
camera_urls.append((url, f"camera{i}"))
|
||||
|
||||
if not camera_urls:
|
||||
logger.error("No camera URLs found in .env file")
|
||||
return
|
||||
|
||||
logger.info(f"Found {len(camera_urls)} camera(s)")
|
||||
|
||||
# Create manager with larger batch for multiple streams
|
||||
manager = StreamConnectionManager(
|
||||
gpu_id=0,
|
||||
batch_size=32, # Larger batch for multiple streams
|
||||
force_timeout=0.05,
|
||||
)
|
||||
|
||||
# Initialize
|
||||
model_path = "models/yolov8n.trt"
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f"Model file not found: {model_path}")
|
||||
return
|
||||
|
||||
await manager.initialize(
|
||||
model_path=model_path,
|
||||
preprocess_fn=YOLOv8Utils.preprocess,
|
||||
postprocess_fn=YOLOv8Utils.postprocess,
|
||||
)
|
||||
|
||||
# Connect to all streams
|
||||
connections = []
|
||||
for url, stream_id in camera_urls:
|
||||
try:
|
||||
connection = await manager.connect_stream(
|
||||
rtsp_url=url,
|
||||
stream_id=stream_id,
|
||||
)
|
||||
connections.append((connection, stream_id))
|
||||
logger.info(f"Connected to {stream_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to {stream_id}: {e}")
|
||||
|
||||
# Process each stream with async generator
|
||||
async def process_stream(connection, stream_name):
|
||||
"""Process results from a single stream"""
|
||||
frame_count = 0
|
||||
person_detections = 0
|
||||
|
||||
async for result in connection.tracking_results():
|
||||
frame_count += 1
|
||||
|
||||
# Count person detections (class_id 0 in COCO)
|
||||
for obj in result.tracked_objects:
|
||||
if obj.class_id == 0:
|
||||
person_detections += 1
|
||||
|
||||
# Log every 10th frame
|
||||
if frame_count % 10 == 0:
|
||||
logger.info(
|
||||
f"[{stream_name}] Processed {frame_count} frames, "
|
||||
f"{person_detections} person detections"
|
||||
)
|
||||
|
||||
# Stop after 100 frames
|
||||
if frame_count >= 100:
|
||||
break
|
||||
|
||||
# Run all streams concurrently
|
||||
tasks = [
|
||||
asyncio.create_task(process_stream(conn, name))
|
||||
for conn, name in connections
|
||||
]
|
||||
|
||||
# Wait for all tasks to complete
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Get final statistics
|
||||
stats = manager.get_stats()
|
||||
logger.info("\n=== Final Statistics ===")
|
||||
logger.info(f"Total connections: {stats['manager']['num_connections']}")
|
||||
logger.info(f"Frames processed: {stats['model_controller']['total_frames_processed']}")
|
||||
logger.info(f"Batches processed: {stats['model_controller']['total_batches_processed']}")
|
||||
logger.info(f"Avg batch size: {stats['model_controller']['avg_batch_size']:.2f}")
|
||||
|
||||
# Cleanup
|
||||
await manager.shutdown()
|
||||
logger.info("Example 2 complete\n")
|
||||
|
||||
|
||||
# Example 3: Queue-based pattern
|
||||
async def example_queue_pattern():
|
||||
"""Demonstrates direct queue access for custom processing"""
|
||||
logger.info("=== Example 3: Queue-Based Pattern ===")
|
||||
|
||||
# Load environment
|
||||
load_dotenv()
|
||||
camera_url = os.getenv('CAMERA_URL_1')
|
||||
if not camera_url:
|
||||
logger.error("CAMERA_URL_1 not found in .env file")
|
||||
return
|
||||
|
||||
# Create manager
|
||||
manager = StreamConnectionManager(gpu_id=0, batch_size=16)
|
||||
|
||||
# Initialize
|
||||
model_path = "models/yolov8n.trt"
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f"Model file not found: {model_path}")
|
||||
return
|
||||
|
||||
await manager.initialize(
|
||||
model_path=model_path,
|
||||
preprocess_fn=YOLOv8Utils.preprocess,
|
||||
postprocess_fn=YOLOv8Utils.postprocess,
|
||||
)
|
||||
|
||||
# Connect to stream (no callback)
|
||||
connection = await manager.connect_stream(
|
||||
rtsp_url=camera_url,
|
||||
stream_id="main_camera",
|
||||
)
|
||||
|
||||
# Use the built-in queue directly
|
||||
result_queue = connection.result_queue
|
||||
|
||||
# Process results from queue
|
||||
processed_count = 0
|
||||
while processed_count < 50: # Process 50 frames
|
||||
try:
|
||||
result = await asyncio.wait_for(result_queue.get(), timeout=5.0)
|
||||
processed_count += 1
|
||||
|
||||
# Custom processing
|
||||
has_person = any(obj.class_id == 0 for obj in result.tracked_objects)
|
||||
has_car = any(obj.class_id == 2 for obj in result.tracked_objects)
|
||||
|
||||
if has_person or has_car:
|
||||
logger.info(
|
||||
f"Frame {processed_count}: "
|
||||
f"Person={'Yes' if has_person else 'No'}, "
|
||||
f"Car={'Yes' if has_car else 'No'}"
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timeout waiting for result")
|
||||
break
|
||||
|
||||
# Cleanup
|
||||
await manager.shutdown()
|
||||
logger.info("Example 3 complete\n")
|
||||
|
||||
|
||||
# Example 4: Performance monitoring
|
||||
async def example_performance_monitoring():
|
||||
"""Demonstrates real-time performance monitoring"""
|
||||
logger.info("=== Example 4: Performance Monitoring ===")
|
||||
|
||||
# Load environment
|
||||
load_dotenv()
|
||||
camera_url = os.getenv('CAMERA_URL_1')
|
||||
if not camera_url:
|
||||
logger.error("CAMERA_URL_1 not found in .env file")
|
||||
return
|
||||
|
||||
# Create manager
|
||||
manager = StreamConnectionManager(
|
||||
gpu_id=0,
|
||||
batch_size=16,
|
||||
force_timeout=0.05,
|
||||
)
|
||||
|
||||
# Initialize
|
||||
model_path = "models/yolov8n.trt"
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f"Model file not found: {model_path}")
|
||||
return
|
||||
|
||||
await manager.initialize(
|
||||
model_path=model_path,
|
||||
preprocess_fn=YOLOv8Utils.preprocess,
|
||||
postprocess_fn=YOLOv8Utils.postprocess,
|
||||
)
|
||||
|
||||
# Track performance metrics
|
||||
frame_times = []
|
||||
last_frame_time = None
|
||||
|
||||
def on_tracking_result(result):
|
||||
nonlocal last_frame_time
|
||||
current_time = time.time()
|
||||
|
||||
if last_frame_time is not None:
|
||||
frame_interval = current_time - last_frame_time
|
||||
frame_times.append(frame_interval)
|
||||
|
||||
last_frame_time = current_time
|
||||
|
||||
# Connect
|
||||
connection = await manager.connect_stream(
|
||||
rtsp_url=camera_url,
|
||||
on_tracking_result=on_tracking_result,
|
||||
)
|
||||
|
||||
# Monitor stats periodically
|
||||
for i in range(6): # Monitor for 60 seconds
|
||||
await asyncio.sleep(10)
|
||||
|
||||
stats = manager.get_stats()
|
||||
model_stats = stats['model_controller']
|
||||
conn_stats = stats['connections'].get('stream_0', {})
|
||||
|
||||
logger.info(f"\n=== Stats Update {i+1} ===")
|
||||
logger.info(f"Buffer A: {model_stats['buffer_a_size']} ({model_stats['buffer_a_state']})")
|
||||
logger.info(f"Buffer B: {model_stats['buffer_b_size']} ({model_stats['buffer_b_state']})")
|
||||
logger.info(f"Active buffer: {model_stats['active_buffer']}")
|
||||
logger.info(f"Total frames: {model_stats['total_frames_processed']}")
|
||||
logger.info(f"Total batches: {model_stats['total_batches_processed']}")
|
||||
logger.info(f"Avg batch size: {model_stats['avg_batch_size']:.2f}")
|
||||
logger.info(f"Decoder frames: {conn_stats.get('frame_count', 0)}")
|
||||
|
||||
if frame_times:
|
||||
avg_fps = 1.0 / (sum(frame_times) / len(frame_times))
|
||||
logger.info(f"Processing FPS: {avg_fps:.2f}")
|
||||
|
||||
# Cleanup
|
||||
await manager.shutdown()
|
||||
logger.info("Example 4 complete\n")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all examples"""
|
||||
logger.info("Starting event-driven stream processing tests\n")
|
||||
|
||||
# Choose which example to run
|
||||
choice = os.getenv('EXAMPLE', '1')
|
||||
|
||||
if choice == '1':
|
||||
await example_callback_pattern()
|
||||
elif choice == '2':
|
||||
await example_async_generator_pattern()
|
||||
elif choice == '3':
|
||||
await example_queue_pattern()
|
||||
elif choice == '4':
|
||||
await example_performance_monitoring()
|
||||
elif choice == 'all':
|
||||
await example_callback_pattern()
|
||||
await asyncio.sleep(2)
|
||||
await example_async_generator_pattern()
|
||||
await asyncio.sleep(2)
|
||||
await example_queue_pattern()
|
||||
await asyncio.sleep(2)
|
||||
await example_performance_monitoring()
|
||||
else:
|
||||
logger.error(f"Invalid choice: {choice}")
|
||||
logger.info("Set EXAMPLE env var to 1, 2, 3, 4, or 'all'")
|
||||
|
||||
logger.info("All tests complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\nInterrupted by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}", exc_info=True)
|
||||
Loading…
Add table
Add a link
Reference in a new issue