event driven system

parent 0c5f56c8a6
commit 3a47920186

10 changed files with 782 additions and 253 deletions
@@ -8,4 +8,6 @@
 - Buffer size for EACH CAMERA should be one; batching is for when processing multiple cameras. When a new frame comes in, replace the old one at that camera's index if it exists. This way the real-time requirement is satisfied. We need a data structure to track this in addition to the ring buffer, though.

+- Buffer should flush after TARGET_FRAME_INTERVAL_MS
+
 - Blurry asyncio architecture, requires documentation
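
A minimal sketch of the note above. The names (LatestFrameBatch) and the TARGET_FRAME_INTERVAL_MS value are illustrative assumptions, not taken from this commit: one slot per camera, the newest frame overwrites the previous one at that camera's index, and the accumulated batch is flushed once the target interval has elapsed.

import time

TARGET_FRAME_INTERVAL_MS = 100  # assumed value for illustration

class LatestFrameBatch:
    def __init__(self):
        self.slot_index = {}      # camera_id -> index into self.frames
        self.frames = []          # at most one (latest) frame per camera
        self.last_flush = time.time()

    def submit(self, camera_id, frame):
        """Keep only the newest frame per camera (real-time requirement)."""
        if camera_id in self.slot_index:
            self.frames[self.slot_index[camera_id]] = frame  # replace the stale frame in place
        else:
            self.slot_index[camera_id] = len(self.frames)
            self.frames.append(frame)

    def maybe_flush(self):
        """Return the pending batch if the flush interval elapsed, else None."""
        now = time.time()
        if (now - self.last_flush) * 1000 >= TARGET_FRAME_INTERVAL_MS:
            batch, self.frames = self.frames, []
            self.slot_index.clear()
            self.last_flush = now
            return batch
        return None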
app.py (4 changed lines)

@@ -4,10 +4,10 @@ app = FastAPI()

 @app.get("/")
-async def root():
+def root():
     return {"message": "Hello World"}


 @app.get("/health")
-async def health_check():
+def health_check():
     return {"status": "healthy"}
scripts/decoder_test.py (new file, 91 lines added)

@@ -0,0 +1,91 @@
"""
Test decoder frame rate in isolation without any processing.
"""

import time
import os
from dotenv import load_dotenv

import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from services.stream_decoder import StreamDecoderFactory

load_dotenv()


def main():
    GPU_ID = 0
    STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
    MAX_FRAMES = 100

    print("=" * 80)
    print("Decoder Frame Rate Test (No Processing)")
    print("=" * 80)
    print(f"\nStream: {STREAM_URL}")
    print(f"Monitoring for {MAX_FRAMES} frames...\n")

    # Create decoder
    factory = StreamDecoderFactory(gpu_id=GPU_ID)
    decoder = factory.create_decoder(STREAM_URL, buffer_size=30)

    # Start decoder
    decoder.start()

    # Wait for connection
    print("Waiting for connection...")
    max_wait = 10
    waited = 0
    while not decoder.is_connected() and waited < max_wait:
        time.sleep(0.5)
        waited += 0.5

    if not decoder.is_connected():
        print(f"Failed to connect after {max_wait}s!")
        decoder.stop()
        return

    print(f"✓ Connected\n")
    print("Monitoring frame arrivals...")
    print("-" * 60)

    last_count = 0
    frame_times = []
    start_time = time.time()
    last_frame_time = start_time

    while decoder.get_frame_count() < MAX_FRAMES:
        current_count = decoder.get_frame_count()

        if current_count > last_count:
            current_time = time.time()
            interval = (current_time - last_frame_time) * 1000

            frame_times.append(interval)
            print(f"Frame {current_count:3d}: interval={interval:6.1f}ms")

            last_count = current_count
            last_frame_time = current_time

        time.sleep(0.001)  # 1ms poll

    # Stop decoder
    decoder.stop()

    # Analysis
    elapsed = time.time() - start_time
    actual_fps = MAX_FRAMES / elapsed

    print("\n" + "=" * 80)
    print("DECODER PERFORMANCE")
    print("=" * 80)
    print(f"\nFrames received: {MAX_FRAMES}")
    print(f"Time: {elapsed:.1f}s")
    print(f"Actual FPS: {actual_fps:.2f}")
    print(f"\nFrame Intervals:")
    print(f"  Min: {min(frame_times[1:]):.1f}ms")  # Skip first
    print(f"  Max: {max(frame_times[1:]):.1f}ms")
    print(f"  Avg: {sum(frame_times[1:])/len(frame_times[1:]):.1f}ms")
    print(f"  Expected (6 FPS): 166.7ms")


if __name__ == "__main__":
    main()
scripts/detailed_profiling.py (new file, 165 lines added)

@@ -0,0 +1,165 @@
"""
Detailed profiling with timing instrumentation to find the exact bottleneck.

This script adds detailed timing logs at each stage of the pipeline.
"""

import asyncio
import time
import os
import torch
from dotenv import load_dotenv
from collections import defaultdict

import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from services import (
    StreamConnectionManager,
    YOLOv8Utils,
)

load_dotenv()

# Timing statistics
timings = defaultdict(list)
frame_timestamps = {}


def log_timing(event, frame_id=None, extra_data=None):
    """Log timing event"""
    timestamp = time.time()
    timings[event].append(timestamp)
    if frame_id is not None:
        if frame_id not in frame_timestamps:
            frame_timestamps[frame_id] = {}
        frame_timestamps[frame_id][event] = timestamp
        if extra_data:
            frame_timestamps[frame_id].update(extra_data)


async def instrumented_main():
    """Instrumented version of profiling script"""
    print("=" * 80)
    print("Detailed Profiling: Event-Driven GPU-Accelerated Object Tracking")
    print("=" * 80)

    # Configuration
    GPU_ID = 0
    MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"
    STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
    BATCH_SIZE = 4
    FORCE_TIMEOUT = 0.05
    MAX_FRAMES = 50  # Fewer frames for detailed analysis

    print(f"\nConfiguration:")
    print(f"  GPU: {GPU_ID}")
    print(f"  Model: {MODEL_PATH}")
    print(f"  Stream: {STREAM_URL}")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Max frames: {MAX_FRAMES}\n")

    # Create manager
    print("[1/3] Creating StreamConnectionManager...")
    manager = StreamConnectionManager(
        gpu_id=GPU_ID,
        batch_size=BATCH_SIZE,
        force_timeout=FORCE_TIMEOUT,
        enable_pt_conversion=True
    )
    print("✓ Manager created")

    # Initialize
    print("\n[2/3] Initializing...")
    await manager.initialize(
        model_path=MODEL_PATH,
        model_id="detector",
        preprocess_fn=YOLOv8Utils.preprocess,
        postprocess_fn=YOLOv8Utils.postprocess,
        num_contexts=4,
        pt_input_shapes={"images": (1, 3, 640, 640)},
        pt_precision=torch.float16
    )
    print("✓ Initialized")

    # Connect stream
    print("\n[3/3] Connecting to stream...")
    connection = await manager.connect_stream(
        rtsp_url=STREAM_URL,
        stream_id="camera_1",
        buffer_size=30
    )
    print("✓ Connected\n")

    print(f"{'=' * 80}")
    print(f"Running instrumented profiling for {MAX_FRAMES} frames...")
    print(f"{'=' * 80}\n")

    result_count = 0
    start_time = time.time()
    last_result_time = start_time

    try:
        async for result in connection.tracking_results():
            current_time = time.time()
            result_interval = (current_time - last_result_time) * 1000

            result_count += 1
            frame_id = result_count

            log_timing('result_received', frame_id, {
                'interval_ms': result_interval,
                'num_objects': len(result.tracked_objects),
                'num_detections': len(result.detections)
            })

            print(f"Frame {result_count:3d}: interval={result_interval:6.1f}ms, "
                  f"objects={len(result.tracked_objects):2d}, "
                  f"detections={len(result.detections):2d}")

            last_result_time = current_time

            if result_count >= MAX_FRAMES:
                print(f"\n✓ Reached max frames limit ({MAX_FRAMES})")
                break

    except KeyboardInterrupt:
        print(f"\n✓ Interrupted by user")

    # Cleanup
    print(f"\n{'=' * 80}")
    print("Cleanup")
    print(f"{'=' * 80}")
    await connection.stop()
    await manager.shutdown()
    print("✓ Stopped")

    # Analysis
    elapsed = time.time() - start_time
    avg_fps = result_count / elapsed if elapsed > 0 else 0

    print(f"\n{'=' * 80}")
    print("TIMING ANALYSIS")
    print(f"{'=' * 80}")
    print(f"\nOverall:")
    print(f"  Results: {result_count}")
    print(f"  Time: {elapsed:.1f}s")
    print(f"  FPS: {avg_fps:.2f}")

    # Frame intervals
    if len(frame_timestamps) > 1:
        intervals = []
        for i in range(2, result_count + 1):
            if i in frame_timestamps and (i-1) in frame_timestamps:
                interval = (frame_timestamps[i]['result_received'] -
                            frame_timestamps[i-1]['result_received']) * 1000
                intervals.append(interval)

        if intervals:
            print(f"\nFrame Intervals:")
            print(f"  Min: {min(intervals):.1f}ms")
            print(f"  Max: {max(intervals):.1f}ms")
            print(f"  Avg: {sum(intervals)/len(intervals):.1f}ms")
            print(f"  Expected (6 FPS): 166.7ms")
            print(f"  Deviation: {(sum(intervals)/len(intervals) - 166.7):.1f}ms")


if __name__ == "__main__":
    asyncio.run(instrumented_main())
@@ -1,3 +1,4 @@
 """
 Profiling script for the real-time object tracking pipeline.
scripts/timing_instrumentation.py (new file, 149 lines added)

@@ -0,0 +1,149 @@
"""
Add timing instrumentation to track where time is spent in the pipeline.
"""

import asyncio
import time
import os
import torch
from dotenv import load_dotenv
import logging

import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Monkey patch to add timing
original_handle_result = None
original_run_tracking = None
original_infer = None

timings = []


def patch_timing():
    """Add timing instrumentation to key functions"""
    from services import stream_connection_manager, model_repository

    global original_handle_result, original_run_tracking, original_infer

    # Patch _handle_inference_result
    original_handle_result = stream_connection_manager.StreamConnection._handle_inference_result
    async def timed_handle_result(self, result):
        t0 = time.perf_counter()
        await original_handle_result(self, result)
        t1 = time.perf_counter()
        timings.append(('handle_result', (t1 - t0) * 1000))
    stream_connection_manager.StreamConnection._handle_inference_result = timed_handle_result

    # Patch _run_tracking_sync
    original_run_tracking = stream_connection_manager.StreamConnection._run_tracking_sync
    def timed_run_tracking(self, detections, min_confidence=0.7):
        t0 = time.perf_counter()
        result = original_run_tracking(self, detections, min_confidence)
        t1 = time.perf_counter()
        timings.append(('tracking', (t1 - t0) * 1000))
        return result
    stream_connection_manager.StreamConnection._run_tracking_sync = timed_run_tracking

    # Patch infer
    original_infer = model_repository.TensorRTModelRepository.infer
    def timed_infer(self, model_id, inputs, synchronize=True):
        t0 = time.perf_counter()
        result = original_infer(self, model_id, inputs, synchronize)
        t1 = time.perf_counter()
        timings.append(('infer', (t1 - t0) * 1000))
        return result
    model_repository.TensorRTModelRepository.infer = timed_infer


async def instrumented_main():
    """Instrumented profiling"""
    from services import StreamConnectionManager, YOLOv8Utils

    load_dotenv()

    print("=" * 80)
    print("Timing Instrumentation")
    print("=" * 80)

    # Patch before creating manager
    patch_timing()

    # Configuration
    GPU_ID = 0
    MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"
    STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
    BATCH_SIZE = 4
    FORCE_TIMEOUT = 0.05
    MAX_FRAMES = 30

    print(f"\nConfiguration: GPU={GPU_ID}, BATCH={BATCH_SIZE}, MAX={MAX_FRAMES}\n")

    # Create and initialize manager
    manager = StreamConnectionManager(
        gpu_id=GPU_ID,
        batch_size=BATCH_SIZE,
        force_timeout=FORCE_TIMEOUT,
        enable_pt_conversion=True
    )

    await manager.initialize(
        model_path=MODEL_PATH,
        model_id="detector",
        preprocess_fn=YOLOv8Utils.preprocess,
        postprocess_fn=YOLOv8Utils.postprocess,
        num_contexts=4,
        pt_input_shapes={"images": (1, 3, 640, 640)},
        pt_precision=torch.float16
    )

    connection = await manager.connect_stream(
        rtsp_url=STREAM_URL,
        stream_id="camera_1",
        buffer_size=30
    )
    print("✓ Connected\n")

    print(f"{'=' * 80}")
    print(f"Processing {MAX_FRAMES} frames with timing...")
    print(f"{'=' * 80}\n")

    result_count = 0
    start_time = time.time()

    try:
        async for result in connection.tracking_results():
            result_count += 1
            if result_count >= MAX_FRAMES:
                break

    except KeyboardInterrupt:
        pass

    # Cleanup
    await connection.stop()
    await manager.shutdown()

    # Analysis
    elapsed = time.time() - start_time
    print(f"\nProcessed {result_count} frames in {elapsed:.1f}s ({result_count/elapsed:.2f} FPS)\n")

    # Analyze timings
    from collections import defaultdict
    timing_stats = defaultdict(list)
    for operation, duration in timings:
        timing_stats[operation].append(duration)

    print("=" * 80)
    print("TIMING BREAKDOWN")
    print("=" * 80)
    for operation in ['infer', 'tracking', 'handle_result']:
        if operation in timing_stats:
            times = timing_stats[operation]
            print(f"\n{operation}:")
            print(f"  Calls: {len(times)}")
            print(f"  Min: {min(times):.2f}ms")
            print(f"  Max: {max(times):.2f}ms")
            print(f"  Avg: {sum(times)/len(times):.2f}ms")
            print(f"  Total: {sum(times):.2f}ms")


if __name__ == "__main__":
    asyncio.run(instrumented_main())
@@ -1,17 +1,18 @@
 """
-ModelController - Async batching layer with ping-pong buffers for inference.
+ModelController - Event-driven batching layer with ping-pong buffers for inference.
 
 This module provides batched inference coordination using ping-pong circular buffers
-with force-switch timeout mechanism.
+with force-switch timeout mechanism using threading and callbacks.
 """
 
-import asyncio
+import threading
 import torch
 from typing import Dict, List, Optional, Callable, Any
 from dataclasses import dataclass, field
 from enum import Enum
 import time
 import logging
+import queue
 
 logger = logging.getLogger(__name__)

@@ -43,7 +44,7 @@ class ModelController:
     Features:
         - Ping-pong circular buffers (BufferA/BufferB)
         - Force-switch timeout to prevent batch starvation
-        - Async event-driven processing
+        - Event-driven processing with callbacks
        - Thread-safe frame submission
 
     Args:

@@ -90,14 +91,15 @@ class ModelController:
         self.buffer_a_state = BufferState.IDLE
         self.buffer_b_state = BufferState.IDLE
 
-        # Async coordination
-        self.buffer_lock = asyncio.Lock()
+        # Threading coordination
+        self.buffer_lock = threading.RLock()
         self.last_submit_time = time.time()
 
-        # Tasks
-        self.timeout_task: Optional[asyncio.Task] = None
-        self.processor_task: Optional[asyncio.Task] = None
+        # Threads
+        self.timeout_thread: Optional[threading.Thread] = None
+        self.processor_threads: Dict[str, threading.Thread] = {}
         self.running = False
+        self.stop_event = threading.Event()
 
         # Result callbacks (stream_id -> callback)
         self.result_callbacks: Dict[str, Callable] = {}

@@ -130,42 +132,46 @@ class ModelController:
             logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
             return 1
 
-    async def start(self):
-        """Start the controller background tasks"""
+    def start(self):
+        """Start the controller background threads"""
         if self.running:
             logger.warning("ModelController already running")
             return
 
         self.running = True
-        self.timeout_task = asyncio.create_task(self._timeout_monitor())
-        self.processor_task = asyncio.create_task(self._batch_processor())
+        self.stop_event.clear()
+
+        # Start timeout monitor thread
+        self.timeout_thread = threading.Thread(target=self._timeout_monitor, daemon=True)
+        self.timeout_thread.start()
+
+        # Start processor threads for each buffer
+        self.processor_threads['A'] = threading.Thread(target=self._batch_processor, args=('A',), daemon=True)
+        self.processor_threads['B'] = threading.Thread(target=self._batch_processor, args=('B',), daemon=True)
+        self.processor_threads['A'].start()
+        self.processor_threads['B'].start()
+
         logger.info("ModelController started")
 
-    async def stop(self):
+    def stop(self):
         """Stop the controller and cleanup"""
         if not self.running:
             return
 
         logger.info("Stopping ModelController...")
         self.running = False
+        self.stop_event.set()
 
-        # Cancel tasks
-        if self.timeout_task:
-            self.timeout_task.cancel()
-            try:
-                await self.timeout_task
-            except asyncio.CancelledError:
-                pass
-
-        if self.processor_task:
-            self.processor_task.cancel()
-            try:
-                await self.processor_task
-            except asyncio.CancelledError:
-                pass
+        # Wait for threads to finish
+        if self.timeout_thread and self.timeout_thread.is_alive():
+            self.timeout_thread.join(timeout=2.0)
+
+        for thread in self.processor_threads.values():
+            if thread and thread.is_alive():
+                thread.join(timeout=2.0)
 
         # Process any remaining frames
-        await self._process_remaining_buffers()
+        self._process_remaining_buffers()
         logger.info("ModelController stopped")
 
     def register_callback(self, stream_id: str, callback: Callable):

@@ -189,7 +195,7 @@ class ModelController:
         self.result_callbacks.pop(stream_id, None)
         logger.debug(f"Unregistered callback for stream: {stream_id}")
 
-    async def submit_frame(
+    def submit_frame(
         self,
         stream_id: str,
         frame: torch.Tensor,

@@ -203,7 +209,7 @@ class ModelController:
             frame: GPU tensor (3, H, W) or (C, H, W)
             metadata: Optional metadata to attach to the frame
         """
-        async with self.buffer_lock:
+        with self.buffer_lock:
             batch_frame = BatchFrame(
                 stream_id=stream_id,
                 frame=frame,

@@ -225,23 +231,21 @@ class ModelController:
 
             # Check if we should immediately swap (batch full)
             if buffer_size >= self.batch_size:
-                await self._try_swap_buffers()
+                self._try_swap_buffers()
 
-    async def _timeout_monitor(self):
+    def _timeout_monitor(self):
         """Monitor force-switch timeout"""
-        while self.running:
-            await asyncio.sleep(0.01)  # Check every 10ms
-
-            async with self.buffer_lock:
+        while self.running and not self.stop_event.wait(0.01):  # Check every 10ms
+            with self.buffer_lock:
                 time_since_submit = time.time() - self.last_submit_time
 
                 # Check if timeout expired and we have frames waiting
                 if time_since_submit >= self.force_timeout:
                     active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
                     if len(active_buffer) > 0:
-                        await self._try_swap_buffers()
+                        self._try_swap_buffers()
 
-    async def _try_swap_buffers(self):
+    def _try_swap_buffers(self):
         """
         Attempt to swap ping-pong buffers.
         Only swaps if the inactive buffer is not currently processing.

@@ -266,20 +270,22 @@ class ModelController:
 
             logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")
 
-    async def _batch_processor(self):
-        """Background task that processes batches when available"""
-        while self.running:
-            await asyncio.sleep(0.001)  # Check every 1ms
+    def _batch_processor(self, buffer_name: str):
+        """Background thread that processes a specific buffer when available"""
+        while self.running and not self.stop_event.is_set():
+            time.sleep(0.001)  # Check every 1ms
 
-            # Check if buffer A needs processing
-            if self.buffer_a_state == BufferState.PROCESSING:
-                await self._process_buffer("A")
+            # Check if this buffer needs processing
+            with self.buffer_lock:
+                if buffer_name == "A":
+                    should_process = self.buffer_a_state == BufferState.PROCESSING
+                else:
+                    should_process = self.buffer_b_state == BufferState.PROCESSING
 
-            # Check if buffer B needs processing
-            if self.buffer_b_state == BufferState.PROCESSING:
-                await self._process_buffer("B")
+            if should_process:
+                self._process_buffer(buffer_name)
 
-    async def _process_buffer(self, buffer_name: str):
+    def _process_buffer(self, buffer_name: str):
         """
         Process a buffer through inference.

@@ -287,7 +293,7 @@ class ModelController:
             buffer_name: "A" or "B"
         """
         # Extract buffer to process
-        async with self.buffer_lock:
+        with self.buffer_lock:
             if buffer_name == "A":
                 batch = self.buffer_a.copy()
                 self.buffer_a.clear()

@@ -297,7 +303,7 @@ class ModelController:
 
         if len(batch) == 0:
             # Mark as idle and return
-            async with self.buffer_lock:
+            with self.buffer_lock:
                 if buffer_name == "A":
                     self.buffer_a_state = BufferState.IDLE
                 else:

@@ -307,7 +313,7 @@ class ModelController:
         # Process batch (outside lock to allow concurrent submissions)
         try:
             start_time = time.time()
-            results = await self._run_batch_inference(batch)
+            results = self._run_batch_inference(batch)
             inference_time = time.time() - start_time
 
             # Update statistics

@@ -323,27 +329,24 @@ class ModelController:
             for batch_frame, result in zip(batch, results):
                 callback = self.result_callbacks.get(batch_frame.stream_id)
                 if callback:
-                    # Schedule callback asynchronously
-                    if asyncio.iscoroutinefunction(callback):
-                        asyncio.create_task(callback(result))
-                    else:
-                        # Run sync callback in executor to avoid blocking
-                        loop = asyncio.get_event_loop()
-                        loop.call_soon(lambda cb=callback, r=result: cb(r))
+                    # Call callback directly (synchronous)
+                    try:
+                        callback(result)
+                    except Exception as e:
+                        logger.error(f"Error in callback for {batch_frame.stream_id}: {e}", exc_info=True)
 
         except Exception as e:
             logger.error(f"Error processing batch: {e}", exc_info=True)
-            # TODO: Emit error events to streams
 
         finally:
             # Mark buffer as idle
-            async with self.buffer_lock:
+            with self.buffer_lock:
                 if buffer_name == "A":
                     self.buffer_a_state = BufferState.IDLE
                 else:
                     self.buffer_b_state = BufferState.IDLE
 
-    async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
+    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
         """
         Run inference on a batch of frames.

@@ -353,17 +356,15 @@ class ModelController:
         Returns:
             List of detection results (one per frame)
         """
-        loop = asyncio.get_event_loop()
-
         # Check if model supports batching
         if self.model_batch_size == 1:
             # Process frames one at a time for batch_size=1 models
-            return await self._run_sequential_inference(batch, loop)
+            return self._run_sequential_inference(batch)
         else:
             # Use true batching for models that support it
-            return await self._run_batched_inference(batch, loop)
+            return self._run_batched_inference(batch)
 
-    async def _run_sequential_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
+    def _run_sequential_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
         """Run inference sequentially for batch_size=1 models"""
         results = []

@@ -376,14 +377,11 @@ class ModelController:
             processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame
 
             # Run inference for this frame
-            outputs = await loop.run_in_executor(
-                None,
-                lambda p=processed: self.model_repository.infer(
+            outputs = self.model_repository.infer(
                 self.model_id,
-                {"images": p},
+                {"images": processed},
                 synchronize=True
             )
-            )
 
             # Postprocess
             if self.postprocess_fn:

@@ -406,7 +404,7 @@ class ModelController:
 
         return results
 
-    async def _run_batched_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
+    def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
         """Run true batched inference for models that support it"""
         # Preprocess frames (on GPU)
         preprocessed = []

@@ -434,14 +432,11 @@ class ModelController:
             batch = batch[:self.model_batch_size]
 
         # Run inference
-        outputs = await loop.run_in_executor(
-            None,
-            lambda: self.model_repository.infer(
+        outputs = self.model_repository.infer(
             self.model_id,
             {"images": batch_tensor},
             synchronize=True
         )
-        )
 
         # Postprocess results (split batch back to individual results)
         results = []

@@ -472,14 +467,14 @@ class ModelController:
 
         return results
 
-    async def _process_remaining_buffers(self):
+    def _process_remaining_buffers(self):
         """Process any remaining frames in buffers during shutdown"""
         if len(self.buffer_a) > 0:
             logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
-            await self._process_buffer("A")
+            self._process_buffer("A")
         if len(self.buffer_b) > 0:
             logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
-            await self._process_buffer("B")
+            self._process_buffer("B")
 
     def get_stats(self) -> Dict[str, Any]:
         """Get current buffer statistics"""
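
For readers skimming the diff, the following is a standalone sketch of the ping-pong-buffer-with-force-timeout pattern this controller implements. The names (PingPongBatcher, process_fn) and the simplification of processing inside the lock are illustrative assumptions, not the project's API; in the committed code the swapped-out buffer is handed to dedicated per-buffer processor threads so submissions can continue concurrently.

import threading
import time

class PingPongBatcher:
    """Two buffers: fill the active one, swap when it is full or when
    force_timeout has elapsed since the last submission."""

    def __init__(self, batch_size=4, force_timeout=0.05, process_fn=print):
        self.batch_size = batch_size
        self.force_timeout = force_timeout   # seconds before a partial batch is forced out
        self.process_fn = process_fn         # receives the swapped-out batch
        self.buffers = {"A": [], "B": []}
        self.active = "A"
        self.lock = threading.RLock()
        self.last_submit = time.time()

    def submit(self, item):
        with self.lock:
            self.buffers[self.active].append(item)
            self.last_submit = time.time()
            if len(self.buffers[self.active]) >= self.batch_size:
                self._swap_and_process()

    def tick(self):
        # Call periodically (the real controller runs a 10 ms timeout-monitor thread)
        with self.lock:
            if self.buffers[self.active] and time.time() - self.last_submit >= self.force_timeout:
                self._swap_and_process()

    def _swap_and_process(self):
        batch = self.buffers[self.active]
        self.buffers[self.active] = []
        self.active = "B" if self.active == "A" else "A"
        # Simplification: processed inline here; the committed code lets a
        # processor thread run inference while the other buffer keeps filling.
        self.process_fn(batch)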
@@ -1,16 +1,17 @@
 """
-StreamConnectionManager - Async orchestration for stream processing with batched inference.
+StreamConnectionManager - Event-driven orchestration for stream processing with batched inference.
 
 This module provides high-level connection management for multiple RTSP streams,
-coordinating decoders, batched inference, and tracking with an event-driven API.
+coordinating decoders, batched inference, and tracking with callbacks and threading.
 """
 
-import asyncio
+import threading
 import time
-from typing import Dict, Optional, Callable, AsyncIterator, Tuple, Any, List
+from typing import Dict, Optional, Callable, Tuple, Any, List
 from dataclasses import dataclass
 from enum import Enum
 import logging
+import queue
 
 import torch

@@ -44,7 +45,7 @@ class StreamConnection:
     """
     Represents a single stream connection with event emission.
 
-    This class wraps a StreamDecoder, polls frames asynchronously, submits them
+    This class wraps a StreamDecoder, polls frames in a thread, submits them
     to the ModelController for batched inference, runs tracking, and emits results
     via queues or callbacks.

@@ -75,15 +76,19 @@ class StreamConnection:
         self.last_frame_time = 0.0
 
         # Event emission
-        self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue()
-        self.error_queue: asyncio.Queue[Exception] = asyncio.Queue()
+        self.result_queue: queue.Queue[TrackingResult] = queue.Queue()
+        self.error_queue: queue.Queue[Exception] = queue.Queue()
 
-        # Tasks
-        self.poller_task: Optional[asyncio.Task] = None
+        # Event-driven state
         self.running = False
 
-    async def start(self):
-        """Start the connection (decoder and frame polling)"""
+    def start(self):
+        """Start the connection (decoder with frame callback)"""
+        self.running = True
+
+        # Register callback for frame events from decoder
+        self.decoder.register_frame_callback(self._on_frame_decoded)
+
         # Start decoder (runs in background thread)
         self.decoder.start()

@@ -93,7 +98,7 @@ class StreamConnection:
         elapsed = 0.0
 
         while elapsed < max_wait:
-            await asyncio.sleep(wait_interval)
+            time.sleep(wait_interval)
             elapsed += wait_interval
 
             if self.decoder.is_connected():

@@ -105,21 +110,13 @@ class StreamConnection:
             logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...")
             self.status = ConnectionStatus.CONNECTING
 
-        # Start frame polling task
-        self.running = True
-        self.poller_task = asyncio.create_task(self._frame_poller())
-
-    async def stop(self):
+    def stop(self):
         """Stop the connection and cleanup"""
         logger.info(f"Stopping stream {self.stream_id}...")
         self.running = False
 
-        if self.poller_task:
-            self.poller_task.cancel()
-            try:
-                await self.poller_task
-            except asyncio.CancelledError:
-                pass
+        # Unregister frame callback
+        self.decoder.unregister_frame_callback(self._on_frame_decoded)
 
         # Stop decoder
         self.decoder.stop()

@@ -130,27 +127,23 @@ class StreamConnection:
         self.status = ConnectionStatus.DISCONNECTED
         logger.info(f"Stream {self.stream_id} stopped")
 
-    async def _frame_poller(self):
-        """Poll frames from threaded decoder and submit to model controller"""
-        last_decoder_frame_count = -1
+    def _on_frame_decoded(self, frame: torch.Tensor):
+        """
+        Event handler called by decoder when a new frame is decoded.
+        This is the event-driven replacement for polling.
+
+        Args:
+            frame: RGB frame tensor on GPU (3, H, W)
+        """
+        if not self.running:
+            return
 
-        while self.running:
-            try:
-                # Get current decoder frame count (no data transfer, just counter)
-                decoder_frame_count = self.decoder.get_frame_count()
-
-                # Check if decoder has a new frame (avoid reprocessing same frame)
-                if decoder_frame_count > last_decoder_frame_count:
-                    # Poll frame from decoder (zero-copy - stays in VRAM)
-                    frame = self.decoder.get_latest_frame(rgb=True)
-
-                    if frame is not None:
-                        last_decoder_frame_count = decoder_frame_count
+        try:
             self.last_frame_time = time.time()
             self.frame_count += 1
 
             # Submit to model controller for batched inference
-            await self.model_controller.submit_frame(
+            self.model_controller.submit_frame(
                 stream_id=self.stream_id,
                 frame=frame,
                 metadata={

@@ -159,26 +152,20 @@ class StreamConnection:
                 }
             )
 
-            # Check decoder status
-            if not self.decoder.is_connected():
-                if self.status == ConnectionStatus.CONNECTED:
-                    logger.warning(f"Stream {self.stream_id} disconnected")
-                    self.status = ConnectionStatus.DISCONNECTED
-                # Decoder will auto-reconnect, just update status
-                await asyncio.sleep(1.0)
-                if self.decoder.is_connected():
+            # Update connection status based on decoder status
+            if self.decoder.is_connected() and self.status != ConnectionStatus.CONNECTED:
                 logger.info(f"Stream {self.stream_id} reconnected")
                 self.status = ConnectionStatus.CONNECTED
+            elif not self.decoder.is_connected() and self.status == ConnectionStatus.CONNECTED:
+                logger.warning(f"Stream {self.stream_id} disconnected")
+                self.status = ConnectionStatus.DISCONNECTED
 
         except Exception as e:
-            logger.error(f"Error in frame poller for {self.stream_id}: {e}", exc_info=True)
-            await self.error_queue.put(e)
+            logger.error(f"Error processing frame for {self.stream_id}: {e}", exc_info=True)
+            self.error_queue.put(e)
             self.status = ConnectionStatus.ERROR
 
-            # Sleep until next poll
-            await asyncio.sleep(self.poll_interval)
-
-    async def _handle_inference_result(self, result: Dict[str, Any]):
+    def _handle_inference_result(self, result: Dict[str, Any]):
         """
         Callback invoked by ModelController when inference is done.
         Runs tracking and emits final result.

@@ -190,12 +177,8 @@ class StreamConnection:
             # Extract detections
             detections = result["detections"]
 
-            # Run tracking (this is sync, so run in executor)
-            loop = asyncio.get_event_loop()
-            tracked_objects = await loop.run_in_executor(
-                None,
-                lambda: self._run_tracking_sync(detections)
-            )
+            # Run tracking (synchronous)
+            tracked_objects = self._run_tracking_sync(detections)
 
             # Create tracking result
             tracking_result = TrackingResult(

@@ -208,11 +191,11 @@ class StreamConnection:
             )
 
             # Emit to result queue
-            await self.result_queue.put(tracking_result)
+            self.result_queue.put(tracking_result)
 
         except Exception as e:
             logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
-            await self.error_queue.put(e)
+            self.error_queue.put(e)
 
     def _run_tracking_sync(self, detections, min_confidence=0.7):
         """

@@ -246,12 +229,12 @@ class StreamConnection:
         # Update tracker with detections (lightweight, no model dependency!)
         return self.tracking_controller.update(detection_list)
 
-    async def tracking_results(self) -> AsyncIterator[TrackingResult]:
+    def tracking_results(self):
         """
-        Async generator for tracking results.
+        Generator for tracking results (blocking iterator).
 
         Usage:
-            async for result in connection.tracking_results():
+            for result in connection.tracking_results():
                 print(result.tracked_objects)
 
         Yields:

@@ -259,23 +242,23 @@ class StreamConnection:
         """
         while self.running or not self.result_queue.empty():
             try:
-                result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0)
+                result = self.result_queue.get(timeout=1.0)
                 yield result
-            except asyncio.TimeoutError:
+            except queue.Empty:
                 continue
 
-    async def errors(self) -> AsyncIterator[Exception]:
+    def errors(self):
         """
-        Async generator for errors.
+        Generator for errors (blocking iterator).
 
         Yields:
             Exception objects as they occur
         """
         while self.running or not self.error_queue.empty():
             try:
-                error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0)
+                error = self.error_queue.get(timeout=1.0)
                 yield error
-            except asyncio.TimeoutError:
+            except queue.Empty:
                 continue
 
     def get_stats(self) -> Dict[str, Any]:

@@ -342,7 +325,7 @@ class StreamConnectionManager:
         # State
         self.initialized = False
 
-    async def initialize(
+    def initialize(
         self,
         model_path: str,
         model_id: str = "detector",

@@ -368,11 +351,8 @@ class StreamConnectionManager:
         """
         logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
 
-        # Load model
-        loop = asyncio.get_event_loop()
-        await loop.run_in_executor(
-            None,
-            lambda: self.model_repository.load_model(
+        # Load model (synchronous)
+        self.model_repository.load_model(
             model_id,
             model_path,
             num_contexts=num_contexts,

@@ -380,7 +360,6 @@ class StreamConnectionManager:
             pt_precision=pt_precision,
             **pt_conversion_kwargs
         )
-        )
         logger.info(f"Loaded model {model_id} from {model_path}")
 
         # Create model controller

@@ -392,7 +371,7 @@ class StreamConnectionManager:
             preprocess_fn=preprocess_fn,
             postprocess_fn=postprocess_fn,
         )
-        await self.model_controller.start()
+        self.model_controller.start()
 
         # Don't create a shared tracking controller here
         # Each stream will get its own tracking controller to avoid track accumulation

@@ -402,7 +381,7 @@ class StreamConnectionManager:
         self.initialized = True
         logger.info("StreamConnectionManager initialized successfully")
 
-    async def connect_stream(
+    def connect_stream(
         self,
         rtsp_url: str,
         stream_id: Optional[str] = None,

@@ -416,8 +395,8 @@ class StreamConnectionManager:
         Args:
             rtsp_url: RTSP stream URL
             stream_id: Optional stream identifier (auto-generated if not provided)
-            on_tracking_result: Optional callback for tracking results (sync or async)
-            on_error: Optional callback for errors (sync or async)
+            on_tracking_result: Optional callback for tracking results (synchronous)
+            on_error: Optional callback for errors (synchronous)
             buffer_size: Decoder buffer size (default: 30)
 
         Returns:

@@ -466,22 +445,30 @@ class StreamConnectionManager:
         )
 
         # Start connection
-        await connection.start()
+        connection.start()
 
         # Store connection
         self.connections[stream_id] = connection
 
-        # Set up user callbacks if provided
+        # Set up user callbacks if provided (run in separate threads)
         if on_tracking_result:
-            asyncio.create_task(self._forward_results(connection, on_tracking_result))
+            threading.Thread(
+                target=self._forward_results,
+                args=(connection, on_tracking_result),
+                daemon=True
+            ).start()
 
         if on_error:
-            asyncio.create_task(self._forward_errors(connection, on_error))
+            threading.Thread(
+                target=self._forward_errors,
+                args=(connection, on_error),
+                daemon=True
+            ).start()
 
         logger.info(f"Stream {stream_id} connected successfully")
         return connection
 
-    async def disconnect_stream(self, stream_id: str):
+    def disconnect_stream(self, stream_id: str):
         """
         Disconnect and cleanup a stream.

@@ -490,27 +477,27 @@ class StreamConnectionManager:
         """
         connection = self.connections.get(stream_id)
         if connection:
-            await connection.stop()
+            connection.stop()
             del self.connections[stream_id]
             logger.info(f"Stream {stream_id} disconnected")
 
-    async def disconnect_all(self):
+    def disconnect_all(self):
         """Disconnect all streams"""
         logger.info("Disconnecting all streams...")
         stream_ids = list(self.connections.keys())
         for stream_id in stream_ids:
-            await self.disconnect_stream(stream_id)
+            self.disconnect_stream(stream_id)
 
-    async def shutdown(self):
+    def shutdown(self):
         """Shutdown the manager and cleanup all resources"""
         logger.info("Shutting down StreamConnectionManager...")
 
         # Disconnect all streams
-        await self.disconnect_all()
+        self.disconnect_all()
 
         # Stop model controller
         if self.model_controller:
-            await self.model_controller.stop()
+            self.model_controller.stop()
 
         # Note: Model repository cleanup is sync and may cause segfaults
         # Leaving cleanup to garbage collection for now

@@ -518,36 +505,30 @@ class StreamConnectionManager:
         self.initialized = False
         logger.info("StreamConnectionManager shutdown complete")
 
-    async def _forward_results(self, connection: StreamConnection, callback: Callable):
+    def _forward_results(self, connection: StreamConnection, callback: Callable):
         """
         Forward results from connection to user callback.
 
         Args:
             connection: StreamConnection to listen to
-            callback: User callback (sync or async)
+            callback: User callback (synchronous)
         """
         try:
-            async for result in connection.tracking_results():
-                if asyncio.iscoroutinefunction(callback):
-                    await callback(result)
-                else:
-                    callback(result)
+            for result in connection.tracking_results():
+                callback(result)
         except Exception as e:
             logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)
 
-    async def _forward_errors(self, connection: StreamConnection, callback: Callable):
+    def _forward_errors(self, connection: StreamConnection, callback: Callable):
         """
         Forward errors from connection to user callback.
 
         Args:
             connection: StreamConnection to listen to
-            callback: User callback (sync or async)
+            callback: User callback (synchronous)
         """
         try:
-            async for error in connection.errors():
-                if asyncio.iscoroutinefunction(callback):
-                    await callback(error)
-                else:
-                    callback(error)
+            for error in connection.errors():
+                callback(error)
         except Exception as e:
             logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)
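
Putting the converted API together, a hedged sketch of how the now-synchronous manager might be driven. Signatures and parameter values mirror the hunks above and the profiling scripts in this commit (which still await these calls), so treat this as an assumption about intended synchronous usage; the model path and RTSP URL are placeholders.

import torch
from services import StreamConnectionManager, YOLOv8Utils

manager = StreamConnectionManager(gpu_id=0, batch_size=4, force_timeout=0.05, enable_pt_conversion=True)
manager.initialize(
    model_path="bangchak/models/frontal_detection_v5.pt",
    model_id="detector",
    preprocess_fn=YOLOv8Utils.preprocess,
    postprocess_fn=YOLOv8Utils.postprocess,
    num_contexts=4,
    pt_input_shapes={"images": (1, 3, 640, 640)},
    pt_precision=torch.float16,
)

def on_result(result):
    # Runs on the daemon forwarding thread started by connect_stream
    print(f"objects={len(result.tracked_objects)} detections={len(result.detections)}")

connection = manager.connect_stream(
    rtsp_url="rtsp://localhost:8554/test",
    stream_id="camera_1",
    buffer_size=30,
    on_tracking_result=on_result,
)

# Alternatively, consume the blocking generator directly:
for result in connection.tracking_results():
    print(len(result.tracked_objects))
    break

manager.shutdown()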
@ -1,5 +1,5 @@
|
||||||
import threading
|
import threading
|
||||||
from typing import Optional
|
from typing import Optional, Callable
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -10,6 +10,35 @@ from cuda.bindings import driver as cuda_driver
|
||||||
from .jpeg_encoder import encode_frame_to_jpeg
|
from .jpeg_encoder import encode_frame_to_jpeg
|
||||||
|
|
||||||
|
|
||||||
|
class FrameReference:
|
||||||
|
"""
|
||||||
|
CPU-side reference object for a GPU frame.
|
||||||
|
|
||||||
|
This object holds a cloned RGB tensor that is independent of PyNvVideoCodec's
|
||||||
|
DecodedFrame lifecycle. We don't keep the DecodedFrame to avoid conflicts
|
||||||
|
with PyNvVideoCodec's internal frame pool management.
|
||||||
|
"""
|
||||||
|
def __init__(self, rgb_tensor: torch.Tensor, buffer_index: int, decoder):
|
||||||
|
self.rgb_tensor = rgb_tensor # Cloned RGB tensor (independent copy)
|
||||||
|
self.buffer_index = buffer_index
|
||||||
|
self.decoder = decoder # Reference to decoder for marking as free
|
||||||
|
self._freed = False
|
||||||
|
|
||||||
|
def free(self):
|
||||||
|
"""Mark this frame as no longer in use"""
|
||||||
|
if not self._freed:
|
||||||
|
self._freed = True
|
||||||
|
self.decoder._mark_frame_free(self.buffer_index)
|
||||||
|
|
||||||
|
def is_freed(self) -> bool:
|
||||||
|
"""Check if this frame has been freed"""
|
||||||
|
return self._freed
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""Auto-free on garbage collection"""
|
||||||
|
self.free()
|
||||||
|
|
||||||
|
|
||||||
def nv12_to_rgb_gpu(nv12_tensor: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
def nv12_to_rgb_gpu(nv12_tensor: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Convert NV12 format to RGB on GPU using PyTorch operations.
|
Convert NV12 format to RGB on GPU using PyTorch operations.
|
||||||
|
|
@@ -183,10 +212,13 @@ class StreamDecoder:
         self.status = ConnectionStatus.DISCONNECTED
         self._status_lock = threading.Lock()
 
-        # Frame buffer (ring buffer) - stores CUDA device pointers
+        # Frame buffer (ring buffer) - stores FrameReference objects
         self.frame_buffer = deque(maxlen=buffer_size)
         self._buffer_lock = threading.RLock()
 
+        # Track which buffer slots are in use (list of FrameReference objects)
+        self._in_use_frames = []  # List of FrameReference objects currently held by callbacks
+
         # Decoder and container instances
         self.decoder = None
         self.container = None
@@ -200,6 +232,45 @@ class StreamDecoder:
         self.frame_height: Optional[int] = None
         self.frame_count: int = 0
 
+        # Frame callbacks - event-driven notification
+        self._frame_callbacks = []
+        self._callback_lock = threading.Lock()
+
+    def register_frame_callback(self, callback: Callable):
+        """
+        Register a callback to be called when a new frame is decoded.
+
+        The callback will be called with the decoded frame tensor (GPU) as argument.
+        Callback signature: callback(frame: torch.Tensor) -> None
+
+        Args:
+            callback: Function to call when new frame arrives
+        """
+        with self._callback_lock:
+            self._frame_callbacks.append(callback)
+
+    def unregister_frame_callback(self, callback: Callable):
+        """
+        Unregister a frame callback.
+
+        Args:
+            callback: The callback function to remove
+        """
+        with self._callback_lock:
+            if callback in self._frame_callbacks:
+                self._frame_callbacks.remove(callback)
+
+    def _mark_frame_free(self, buffer_index: int):
+        """
+        Mark a frame as freed (called by FrameReference when it's no longer in use).
+
+        Args:
+            buffer_index: Index in the buffer for tracking purposes
+        """
+        with self._buffer_lock:
+            # Remove from in-use tracking
+            self._in_use_frames = [f for f in self._in_use_frames if f.buffer_index != buffer_index]
+
     def start(self):
         """Start the RTSP stream decoding in background thread"""
         if self._decode_thread is not None and self._decode_thread.is_alive():
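A short usage sketch of the new callback API, assuming `decoder` is a started StreamDecoder instance; the callback body is hypothetical:

    import torch

    def on_frame(frame: torch.Tensor):
        # frame is the cloned (3, H, W) uint8 RGB tensor on the GPU;
        # this runs on the decode thread, so keep it cheap.
        print(f"new frame {tuple(frame.shape)} on {frame.device}")

    decoder.register_frame_callback(on_frame)
    # ... later, when no longer interested ...
    decoder.unregister_frame_callback(on_frame)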
@@ -278,6 +349,9 @@ class StreamDecoder:
 
     def _decode_loop(self):
         """Main decode loop running in background thread"""
+        # Set the CUDA device for this thread
+        torch.cuda.set_device(self.gpu_id)
+
         retry_count = 0
         max_retries = 5
@@ -319,12 +393,61 @@ class StreamDecoder:
                 if not decoded_frames:
                     continue
 
-                # Add frames to ring buffer (thread-safe)
+                # Add frames to ring buffer and fire callbacks
                 with self._buffer_lock:
                     for frame in decoded_frames:
-                        self.frame_buffer.append(frame)
-                        self.frame_count += 1
+                        # Check for buffer overflow - discard oldest if needed
+                        if len(self.frame_buffer) >= self.buffer_size:
+                            # Check if oldest frame is still in use
+                            if len(self._in_use_frames) > 0:
+                                oldest_ref = self.frame_buffer[0] if len(self.frame_buffer) > 0 else None
+                                if oldest_ref and not oldest_ref.is_freed():
+                                    # Force free the oldest frame to prevent overflow
+                                    print(f"[WARNING] Buffer overflow, force-freeing oldest frame (buffer_index={oldest_ref.buffer_index})")
+                                    oldest_ref.free()
+
+                            # Deque will automatically remove oldest when at maxlen
+
+                        # Convert to tensor
+                        try:
+                            # Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
+                            nv12_tensor = torch.from_dlpack(frame)
+
+                            # Convert NV12 to RGB on GPU
+                            if self.frame_height is not None and self.frame_width is not None:
+                                rgb_tensor = nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width)
+
+                                # CRITICAL: Clone the RGB tensor to break CUDA memory dependency
+                                # The nv12_to_rgb_gpu creates a new tensor, but it still references
+                                # the same CUDA context/stream. We need an independent copy.
+                                rgb_tensor_cloned = rgb_tensor.clone()
+
+                                # Create FrameReference object for C++-style memory management
+                                # We don't keep the DecodedFrame to avoid conflicts with PyNvVideoCodec's
+                                # internal frame pool - the clone is fully independent
+                                buffer_index = self.frame_count
+                                frame_ref = FrameReference(
+                                    rgb_tensor=rgb_tensor_cloned,  # Independent cloned tensor
+                                    buffer_index=buffer_index,
+                                    decoder=self
+                                )
+
+                                # Add to buffer and in-use tracking
+                                self.frame_buffer.append(frame_ref)
+                                self._in_use_frames.append(frame_ref)
+                                self.frame_count += 1
+
+                                # Fire callbacks with the cloned RGB tensor from FrameReference
+                                # The tensor is now independent of the DecodedFrame lifecycle
+                                with self._callback_lock:
+                                    for callback in self._frame_callbacks:
+                                        try:
+                                            callback(frame_ref.rgb_tensor)
+                                        except Exception as e:
+                                            print(f"Error in frame callback: {e}")
+                        except Exception as e:
+                            print(f"Error converting frame for callback: {e}")
 
             except Exception as e:
                 print(f"Error in decode loop for {self.rtsp_url}: {e}")
                 self._set_status(ConnectionStatus.RECONNECTING)
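The callbacks above fire inside the decode loop while `_buffer_lock` and `_callback_lock` are held, so a slow callback stalls decoding. One common mitigation (not part of this diff) is to hand frames off to a bounded queue and process them on a separate thread; a sketch under that assumption:

    import queue
    import threading

    frame_queue = queue.Queue(maxsize=4)

    def on_frame(frame):
        # Never block the decode thread: drop the frame if the consumer is behind.
        try:
            frame_queue.put_nowait(frame)
        except queue.Full:
            pass

    def consumer():
        while True:
            frame = frame_queue.get()
            # ... run inference / tracking on the GPU tensor here ...

    threading.Thread(target=consumer, daemon=True).start()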
@@ -351,35 +474,25 @@ class StreamDecoder:
         Args:
             index: Frame index in buffer (-1 for latest, -2 for second latest, etc.)
-            rgb: If True, convert NV12 to RGB. If False, return raw NV12 format.
+            rgb: If True, return RGB tensor. If False, not supported (returns None).
 
         Returns:
             torch.Tensor in CUDA memory (device tensor) or None if buffer empty
             - If rgb=True: Shape (3, H, W) in RGB format, dtype uint8
-            - If rgb=False: Shape (H*3/2, W) in NV12 format, dtype uint8
+            - If rgb=False: Not supported with FrameReference (returns None)
         """
         with self._buffer_lock:
             if len(self.frame_buffer) == 0:
                 return None
 
+            if not rgb:
+                print("Warning: NV12 format not supported with FrameReference, only RGB")
+                return None
+
             try:
-                decoded_frame = self.frame_buffer[index]
-
-                # Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
-                # This keeps the data in GPU memory
-                nv12_tensor = torch.from_dlpack(decoded_frame)
-
-                if not rgb:
-                    # Return raw NV12 format
-                    return nv12_tensor
-
-                # Convert NV12 to RGB on GPU
-                if self.frame_height is None or self.frame_width is None:
-                    print("Frame dimensions not available")
-                    return None
-
-                rgb_tensor = nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width)
-                return rgb_tensor
+                frame_ref = self.frame_buffer[index]
+                # Return the RGB tensor from FrameReference (cloned, independent)
+                return frame_ref.rgb_tensor
 
             except (IndexError, Exception) as e:
                 print(f"Error getting frame: {e}")
@@ -448,6 +561,39 @@ class StreamDecoder:
         with self._buffer_lock:
             return len(self.frame_buffer)
 
+    def get_all_frames(self, rgb: bool = True) -> list:
+        """
+        Get all frames currently in the buffer as CUDA tensors.
+        This drains the buffer and returns all frames.
+
+        Args:
+            rgb: If True, return RGB tensors. If False, not supported (returns empty list).
+
+        Returns:
+            List of torch.Tensor objects in CUDA memory
+        """
+        if not rgb:
+            print("Warning: NV12 format not supported with FrameReference, only RGB")
+            return []
+
+        frames = []
+        with self._buffer_lock:
+            # Get all frames from buffer
+            for frame_ref in self.frame_buffer:
+                try:
+                    # Get RGB tensor from FrameReference
+                    frames.append(frame_ref.rgb_tensor)
+                except Exception as e:
+                    print(f"Error getting frame: {e}")
+                    continue
+
+            # Clear the buffer after reading all frames and free all references
+            for frame_ref in self.frame_buffer:
+                frame_ref.free()
+            self.frame_buffer.clear()
+
+        return frames
+
     def get_frame_count(self) -> int:
         """Get total number of frames decoded since start"""
         return self.frame_count
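A hedged example of how the drain-style `get_all_frames` might feed multi-camera batching; the `decoders` list, flush interval, and `run_batch_inference` are assumptions for illustration, not part of this diff:

    import time

    FLUSH_INTERVAL_S = 0.1  # assumed pacing between batches

    while True:
        batch = []
        for dec in decoders:                      # one StreamDecoder per camera (assumed list)
            batch.extend(dec.get_all_frames())    # drains the buffer and frees the references
        if batch:
            run_batch_inference(batch)            # hypothetical downstream step
        time.sleep(FLUSH_INTERVAL_S)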
@@ -5,11 +5,10 @@ This script demonstrates:
 - Event-driven stream processing with StreamConnectionManager
 - Batched GPU inference with ModelController
 - Ping-pong buffer architecture for optimal throughput
-- Async/await pattern for multiple RTSP streams
+- Callback-based event-driven pattern for RTSP streams
 - Automatic PT to TensorRT conversion
 """
 
-import asyncio
 import time
 import os
 import torch
@@ -26,7 +25,7 @@ from services import (
 load_dotenv()
 
 
-async def main_single_stream():
+def main_single_stream():
     """Single stream example with event-driven architecture."""
     print("=" * 80)
     print("Event-Driven GPU-Accelerated Object Tracking - Single Stream")

@@ -66,7 +65,7 @@ async def main_single_stream():
     print("Subsequent loads will use cached TensorRT engine\n")
 
     try:
-        await manager.initialize(
+        manager.initialize(
             model_path=MODEL_PATH,
             model_id="detector",
             preprocess_fn=YOLOv8Utils.preprocess,

@@ -85,7 +84,7 @@ async def main_single_stream():
     # Connect stream
     print("\n[3/3] Connecting to stream...")
     try:
-        connection = await manager.connect_stream(
+        connection = manager.connect_stream(
             rtsp_url=STREAM_URL,
             stream_id="camera_1",
             buffer_size=30
@@ -110,7 +109,7 @@ async def main_single_stream():
         cv2.resizeWindow("Object Tracking", 1280, 720)
 
     try:
-        async for result in connection.tracking_results():
+        for result in connection.tracking_results():
             result_count += 1
 
             # Check if we've reached max frames

@@ -189,8 +188,8 @@ async def main_single_stream():
     if ENABLE_DISPLAY:
         cv2.destroyAllWindows()
 
-    await connection.stop()
-    await manager.shutdown()
+    connection.stop()
+    manager.shutdown()
     print("✓ Stopped")
 
     # Final stats
@@ -199,7 +198,7 @@ async def main_single_stream():
     print(f"\nFinal: {result_count} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")
 
 
-async def main_multi_stream():
+def main_multi_stream():
     """Multi-stream example with batched inference."""
     print("=" * 80)
     print("Event-Driven GPU-Accelerated Object Tracking - Multi-Stream")

@@ -245,7 +244,7 @@ async def main_multi_stream():
     # Initialize with PT model
     print("\n[2/3] Initializing with PT model...")
     try:
-        await manager.initialize(
+        manager.initialize(
             model_path=MODEL_PATH,
             model_id="detector",
             preprocess_fn=YOLOv8Utils.preprocess,

@@ -266,7 +265,7 @@ async def main_multi_stream():
     connections = {}
     for stream_id, rtsp_url in camera_urls:
         try:
-            conn = await manager.connect_stream(
+            conn = manager.connect_stream(
                 rtsp_url=rtsp_url,
                 stream_id=stream_id,
                 buffer_size=30

@@ -295,7 +294,7 @@ async def main_multi_stream():
     # Simple approach: iterate over first connection's results
     # In production, you'd properly merge all result streams
     for conn in connections.values():
-        async for result in conn.tracking_results():
+        for result in conn.tracking_results():
             total_results += 1
             stream_id = result.stream_id
 

@@ -322,8 +321,8 @@ async def main_multi_stream():
     print(f"{'=' * 80}")
 
     for conn in connections.values():
-        await conn.stop()
-    await manager.shutdown()
+        conn.stop()
+    manager.shutdown()
     print("✓ Stopped")
 
     # Final stats

@@ -335,6 +334,6 @@ async def main_multi_stream():
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1 and sys.argv[1] == "single":
-        asyncio.run(main_single_stream())
+        main_single_stream()
     else:
-        asyncio.run(main_multi_stream())
+        main_multi_stream()
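The multi-stream loop above still consumes connections one after another; as the in-diff comment notes, production code would merge the result streams. A sketch of one way to do that with the now-synchronous `tracking_results()` generators, using a thread per connection and a shared queue (everything here other than `tracking_results` is an assumption):

    import queue
    import threading

    def merge_results(connections):
        """Yield results from all connections as they arrive."""
        merged = queue.Queue()

        def pump(conn):
            for result in conn.tracking_results():   # blocking generator per stream
                merged.put(result)

        for conn in connections.values():
            threading.Thread(target=pump, args=(conn,), daemon=True).start()

        while True:
            yield merged.get()

    # Usage sketch:
    # for result in merge_results(connections):
    #     print(result.stream_id)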