681 lines
No EOL
28 KiB
Python
681 lines
No EOL
28 KiB
Python
"""
|
|
Integration tests for complete detection workflow.
|
|
|
|
This module tests the full end-to-end detection pipeline from stream
|
|
to database update, ensuring all components work together correctly.
|
|
"""
|
|
import pytest
|
|
import asyncio
|
|
import uuid
|
|
import time
|
|
import json
|
|
import tempfile
|
|
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
|
from pathlib import Path
|
|
import numpy as np
|
|
import cv2
|
|
|
|
from detector_worker.app import create_app
|
|
from detector_worker.core.config import Configuration
|
|
from detector_worker.core.dependency_injection import ServiceContainer
|
|
from detector_worker.streams.stream_manager import StreamManager, StreamConfig
|
|
from detector_worker.models.model_manager import ModelManager
|
|
from detector_worker.pipeline.pipeline_executor import PipelineExecutor
|
|
from detector_worker.storage.database_manager import DatabaseManager
|
|
from detector_worker.storage.redis_client import RedisClient, RedisConfig
|
|
from detector_worker.communication.websocket_handler import WebSocketHandler
|
|
from detector_worker.communication.message_processor import MessageProcessor
|
|
|
|
|
|
@pytest.fixture
def temp_config_file():
    """Yield the path of a throwaway JSON configuration file.

    The file holds worker settings (polling, stream limits, retry policy)
    plus enabled ``database`` and ``redis`` sections pointing at localhost.
    It is deleted again once the consuming test has finished.
    """
    database_settings = {
        "enabled": True,
        "host": "localhost",
        "port": 5432,
        "database": "test_gas_station_1",
        "user": "test_user",
        "password": "test_pass",
    }
    redis_settings = {
        "enabled": True,
        "host": "localhost",
        "port": 6379,
        "db": 0,
    }
    settings = {
        "poll_interval_ms": 100,
        "max_streams": 5,
        "target_fps": 10,
        "reconnect_interval_sec": 1,
        "max_retries": 2,
        "database": database_settings,
        "redis": redis_settings,
    }

    # delete=False: the file must outlive this handle so the test can reopen
    # it by name; removal is done explicitly after the yield.
    handle = tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False)
    with handle:
        handle.write(json.dumps(settings))
    config_path = handle.name

    yield config_path

    # Teardown
    Path(config_path).unlink(missing_ok=True)
|
|
|
|
|
|
@pytest.fixture
def mock_frame():
    """Return a uniform mid-grey 480x640 three-channel ``uint8`` image."""
    height, width, channels = 480, 640, 3
    return np.full((height, width, channels), 128, dtype=np.uint8)
|
|
|
|
|
|
@pytest.fixture
def sample_mpta_file():
    """Yield the path of a temporary JSON file holding a sample pipeline.

    The pipeline describes a ``Car``/``Frontal`` detector with two actions
    (Redis image save, PostgreSQL record creation), two classification
    branches that crop the ``Frontal`` region, and one combined PostgreSQL
    update that waits for both branches. The file is removed after the test.
    """

    def classifier_branch(model_id, min_confidence):
        # Both branches share every setting except model id and threshold.
        return {
            "modelId": model_id,
            "modelFile": f"{model_id}.pt",
            "parallel": True,
            "crop": True,
            "cropClass": "Frontal",
            "triggerClasses": ["Frontal"],
            "minConfidence": min_confidence,
        }

    save_image_action = {
        "type": "redis_save_image",
        "region": "Frontal",
        "key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
        "expire_seconds": 600,
    }
    create_record_action = {
        "type": "postgresql_create_record",
        "table": "car_frontal_info",
        "fields": {
            "display_id": "{display_id}",
            "captured_timestamp": "{timestamp}",
            "session_id": "{session_id}",
            "license_character": None,
            "license_type": "No model available",
        },
    }
    combined_update_action = {
        "type": "postgresql_update_combined",
        "table": "car_frontal_info",
        "key_field": "session_id",
        "waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
        "fields": {
            "car_brand": "{car_brand_cls_v1.brand}",
            "car_body_type": "{car_bodytype_cls_v1.body_type}",
        },
    }

    # Key order is kept stable so the serialized document does not change.
    pipeline = {
        "modelId": "car_frontal_detection_v1",
        "modelFile": "car_frontal_detection_v1.pt",
        "multiClass": True,
        "expectedClasses": ["Car", "Frontal"],
        "triggerClasses": ["Car", "Frontal"],
        "minConfidence": 0.8,
        "actions": [save_image_action, create_record_action],
        "branches": [
            classifier_branch("car_brand_cls_v1", 0.85),
            classifier_branch("car_bodytype_cls_v1", 0.80),
        ],
        "parallelActions": [combined_update_action],
    }

    # delete=False: the test reopens the file by name; cleanup is explicit.
    handle = tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False)
    with handle:
        handle.write(json.dumps(pipeline))
    pipeline_path = handle.name

    yield pipeline_path

    # Teardown
    Path(pipeline_path).unlink(missing_ok=True)
|
|
|
|
|
|
class TestCompleteDetectionWorkflow:
    """Test complete detection workflow from stream to database.

    Each test patches the external backends it touches (``cv2.VideoCapture``,
    ``torch.load``, ``psycopg2.connect``, ``redis.Redis``, ``requests.get``)
    so no real camera, model file or service is needed.
    """

    @pytest.mark.asyncio
    async def test_complete_rtsp_detection_workflow(self, temp_config_file, sample_mpta_file, mock_frame):
        """Test complete workflow: RTSP stream -> detection -> classification -> database."""

        # Initialize configuration
        config = Configuration()
        config.load_from_file(temp_config_file)

        # Create service container
        container = ServiceContainer()

        # Mock all external dependencies
        with patch('cv2.VideoCapture') as mock_video_cap, \
                patch('torch.load') as mock_torch_load, \
                patch('psycopg2.connect') as mock_db_connect, \
                patch('redis.Redis') as mock_redis:

            # Setup video capture mock
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True
            mock_cap_instance.read.return_value = (True, mock_frame)

            # Setup model loading mock
            mock_detection_model = Mock()
            mock_brand_model = Mock()
            mock_bodytype_model = Mock()

            def mock_model_load(path, **kwargs):
                # Dispatch on a substring of the requested model path.
                if "detection" in path:
                    return mock_detection_model
                elif "brand" in path:
                    return mock_brand_model
                elif "bodytype" in path:
                    return mock_bodytype_model
                return Mock()

            mock_torch_load.side_effect = mock_model_load

            # Setup detection model predictions - Car and Frontal detected
            detection_output = Mock()
            mock_detection_model.return_value = detection_output
            detection_output.names = {0: "Car", 1: "Frontal"}
            detection_output.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
                [100, 200, 300, 400],  # Car bbox
                [150, 250, 250, 350],  # Frontal bbox
            ])
            detection_output.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
            detection_output.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])

            # Setup classification model predictions
            mock_brand_result = Mock()
            mock_brand_result.probs.top1 = 5  # Toyota index
            mock_brand_result.probs.top1conf.item.return_value = 0.87
            mock_brand_result.names = {5: "Toyota"}
            mock_brand_model.return_value = mock_brand_result

            mock_bodytype_result = Mock()
            mock_bodytype_result.probs.top1 = 2  # Sedan index
            mock_bodytype_result.probs.top1conf.item.return_value = 0.82
            mock_bodytype_result.names = {2: "Sedan"}
            mock_bodytype_model.return_value = mock_bodytype_result

            # Setup database mock
            mock_db_conn = Mock()
            mock_db_connect.return_value = mock_db_conn
            mock_cursor = Mock()
            mock_db_conn.cursor.return_value = mock_cursor

            # Setup Redis mock
            mock_redis_instance = Mock()
            mock_redis.return_value = mock_redis_instance
            mock_redis_instance.ping.return_value = True
            mock_redis_instance.set.return_value = True
            mock_redis_instance.expire.return_value = True

            # Initialize managers
            stream_manager = StreamManager()
            model_manager = ModelManager()
            pipeline_executor = PipelineExecutor()

            # Register services in container
            container.register_singleton(StreamManager, lambda: stream_manager)
            container.register_singleton(ModelManager, lambda: model_manager)
            container.register_singleton(PipelineExecutor, lambda: pipeline_executor)

            try:
                # 1. Create RTSP stream
                stream_config = StreamConfig(
                    stream_url="rtsp://example.com/stream",
                    stream_type="rtsp",
                    target_fps=10
                )

                await stream_manager.create_stream(
                    camera_id="camera_001",
                    config=stream_config,
                    subscription_id="sub_001"
                )

                # 2. Load pipeline and models
                pipeline_config = json.loads(Path(sample_mpta_file).read_text())

                # Mock model file paths exist
                with patch('os.path.exists', return_value=True):
                    await model_manager.load_models_from_config(pipeline_config)

                # 3. Get frame from stream
                frame = stream_manager.get_latest_frame("camera_001")
                assert frame is not None

                # 4. Run detection pipeline
                detection_context = {
                    "camera_id": "camera_001",
                    "display_id": "display_001",
                    "frame": mock_frame,
                    "timestamp": int(time.time() * 1000),
                    "session_id": str(uuid.uuid4())
                }

                # Mock cv2.imencode for Redis image storage
                with patch('cv2.imencode') as mock_imencode:
                    encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
                    mock_imencode.return_value = (True, encoded_data)

                    pipeline_result = await pipeline_executor.execute_pipeline(
                        pipeline_config,
                        detection_context
                    )

                # 5. Verify pipeline execution
                assert pipeline_result is not None
                assert pipeline_result.get("status") == "completed"

                # 6. Verify database operations
                executed = mock_cursor.execute.call_args_list

                # Should have called create record
                create_calls = [call for call in executed if "INSERT" in str(call)]
                assert len(create_calls) >= 1

                # Should have called update with classification results
                update_calls = [call for call in executed if "UPDATE" in str(call)]
                assert len(update_calls) >= 1

                # 7. Verify Redis operations
                # Should have saved cropped image
                assert mock_redis_instance.set.called
                assert mock_redis_instance.expire.called

                # 8. Verify final results
                assert "detections" in pipeline_result
                detections = pipeline_result["detections"]
                assert len(detections) >= 2  # Car and Frontal

                # Check detection classes
                detected_classes = [d.get("class") for d in detections]
                assert "Car" in detected_classes
                assert "Frontal" in detected_classes

                # 9. Verify classification results in context
                classification_results = pipeline_result.get("classification_results", {})
                assert "car_brand_cls_v1" in classification_results
                assert "car_bodytype_cls_v1" in classification_results

                # approx: the confidence may pass through float conversions
                # on its way out of the pipeline.
                brand_result = classification_results["car_brand_cls_v1"]
                assert brand_result.get("brand") == "Toyota"
                assert brand_result.get("confidence") == pytest.approx(0.87)

                bodytype_result = classification_results["car_bodytype_cls_v1"]
                assert bodytype_result.get("body_type") == "Sedan"
                assert bodytype_result.get("confidence") == pytest.approx(0.82)

            finally:
                # Cleanup
                await stream_manager.stop_all_streams()

    @pytest.mark.asyncio
    async def test_websocket_subscription_workflow(self, temp_config_file, sample_mpta_file):
        """Test complete WebSocket subscription workflow.

        Feeds a subscribe / requestState / unsubscribe sequence to the
        handler through a mocked socket, then checks the replies it sent.
        """

        # Mock WebSocket for testing
        mock_websocket = Mock()
        mock_websocket.accept = AsyncMock()
        mock_websocket.send_json = AsyncMock()
        mock_websocket.receive_json = AsyncMock()

        # Initialize configuration
        config = Configuration()
        config.load_from_file(temp_config_file)

        # Initialize message processor and WebSocket handler
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        with patch('cv2.VideoCapture') as mock_video_cap, \
                patch('torch.load') as mock_torch_load, \
                patch('psycopg2.connect') as mock_db_connect, \
                patch('redis.Redis') as mock_redis:

            # Setup mocks
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True
            mock_cap_instance.read.return_value = (True, np.ones((480, 640, 3), dtype=np.uint8))

            mock_torch_load.return_value = Mock()
            mock_db_connect.return_value = Mock()
            mock_redis.return_value = Mock()

            # Simulate WebSocket message sequence
            subscription_message = {
                "type": "subscribe",
                "payload": {
                    "subscriptionIdentifier": "display-001;cam-001",
                    "rtspUrl": "rtsp://example.com/stream",
                    "modelUrl": f"file://{sample_mpta_file}",
                    "modelId": 101,
                    "modelName": "Vehicle Detection",
                    "cropX1": 100, "cropY1": 200,
                    "cropX2": 300, "cropY2": 400
                }
            }

            mock_websocket.receive_json.side_effect = [
                subscription_message,
                {"type": "requestState"},
                {"type": "unsubscribe", "payload": {"subscriptionIdentifier": "display-001;cam-001"}}
            ]

            # Read the pipeline text before ``open`` is replaced below.
            pipeline_json = Path(sample_mpta_file).read_text()

            # Mock file operations
            with patch('builtins.open', mock_open(read_data=pipeline_json)), \
                    patch('os.path.exists', return_value=True):
                try:
                    # Start WebSocket handler (will process messages)
                    await websocket_handler.handle_websocket(mock_websocket, "client_001")
                except Exception:
                    # Once the scripted messages run out the mocked
                    # receive_json raises, which stands in for a client
                    # disconnect. Only the handler call is guarded; the
                    # assertions below must be able to fail the test.
                    pass

        # Verify WebSocket interactions
        mock_websocket.accept.assert_called_once()

        # Should have sent subscription acknowledgment
        sent_messages = [call.args[0] for call in mock_websocket.send_json.call_args_list]
        assert len(sent_messages) >= 1

        # Check subscription acknowledgment
        first_response = sent_messages[0]
        assert first_response.get("type") == "subscribeAck"
        assert first_response.get("status") == "success"

        # Should have sent state report
        state_responses = [msg for msg in sent_messages if msg.get("type") == "stateReport"]
        assert len(state_responses) >= 1

    @pytest.mark.asyncio
    async def test_http_snapshot_workflow(self, temp_config_file, mock_frame):
        """Test HTTP snapshot workflow."""

        config = Configuration()
        config.load_from_file(temp_config_file)

        stream_manager = StreamManager()

        with patch('requests.get') as mock_requests, \
                patch('cv2.imdecode') as mock_imdecode:
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.content = b"fake_jpeg_data"
            mock_requests.return_value = mock_response

            mock_imdecode.return_value = mock_frame

            try:
                # Create HTTP snapshot stream
                stream_config = StreamConfig(
                    stream_url="http://camera.example.com/snapshot.jpg",
                    stream_type="http_snapshot",
                    snapshot_interval=1.0
                )

                await stream_manager.create_stream(
                    camera_id="camera_002",
                    config=stream_config,
                    subscription_id="sub_002"
                )

                # Wait for snapshot capture (just over one snapshot_interval)
                await asyncio.sleep(1.2)

                # Verify frame was captured
                frame = stream_manager.get_latest_frame("camera_002")
                assert frame is not None
                assert np.array_equal(frame, mock_frame)

                # Verify HTTP request was made
                mock_requests.assert_called()
                mock_imdecode.assert_called()

            finally:
                await stream_manager.stop_all_streams()

    @pytest.mark.asyncio
    async def test_error_recovery_workflow(self, temp_config_file):
        """Test error recovery and resilience.

        The first ``cv2.VideoCapture`` handed out reports itself closed and
        the second one open, so the first ``create_stream`` call must raise
        and the retry must succeed.
        """

        config = Configuration()
        config.load_from_file(temp_config_file)

        stream_manager = StreamManager()

        with patch('cv2.VideoCapture') as mock_video_cap:
            # Simulate connection failure then success

            # First attempt fails
            mock_cap_fail = Mock()
            mock_cap_fail.isOpened.return_value = False

            # Second attempt succeeds
            mock_cap_success = Mock()
            mock_cap_success.isOpened.return_value = True
            mock_cap_success.read.return_value = (True, np.ones((480, 640, 3), dtype=np.uint8))

            mock_video_cap.side_effect = [mock_cap_fail, mock_cap_success]

            try:
                stream_config = StreamConfig(
                    stream_url="rtsp://unreliable.example.com/stream",
                    stream_type="rtsp",
                    max_retries=2
                )

                # First attempt should fail. pytest.raises is used instead of
                # ``assert False`` inside ``try/except Exception`` because the
                # latter swallows its own AssertionError.
                with pytest.raises(Exception):
                    await stream_manager.create_stream(
                        camera_id="camera_003",
                        config=stream_config,
                        subscription_id="sub_003"
                    )

                # Retry should succeed
                stream_info = await stream_manager.create_stream(
                    camera_id="camera_003",
                    config=stream_config,
                    subscription_id="sub_003"
                )

                assert stream_info is not None
                assert mock_video_cap.call_count == 2

            finally:
                await stream_manager.stop_all_streams()

    @pytest.mark.asyncio
    async def test_concurrent_streams_workflow(self, temp_config_file, mock_frame):
        """Test handling multiple concurrent streams."""

        config = Configuration()
        config.load_from_file(temp_config_file)

        stream_manager = StreamManager()

        with patch('cv2.VideoCapture') as mock_video_cap, \
                patch('requests.get') as mock_requests, \
                patch('cv2.imdecode', return_value=mock_frame):

            # Setup RTSP mock
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True
            mock_cap_instance.read.return_value = (True, mock_frame)

            # Setup HTTP mock
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.content = b"fake_jpeg_data"
            mock_requests.return_value = mock_response

            try:
                # Create multiple streams concurrently
                stream_tasks = []

                # RTSP streams
                for i in range(3):
                    stream_config = StreamConfig(
                        stream_url=f"rtsp://camera{i}.example.com/stream",
                        stream_type="rtsp"
                    )
                    stream_tasks.append(stream_manager.create_stream(
                        camera_id=f"camera_rtsp_{i}",
                        config=stream_config,
                        subscription_id=f"sub_rtsp_{i}"
                    ))

                # HTTP snapshot streams
                for i in range(2):
                    stream_config = StreamConfig(
                        stream_url=f"http://camera{i}.example.com/snapshot.jpg",
                        stream_type="http_snapshot"
                    )
                    stream_tasks.append(stream_manager.create_stream(
                        camera_id=f"camera_http_{i}",
                        config=stream_config,
                        subscription_id=f"sub_http_{i}"
                    ))

                # Wait for all streams to be created
                stream_results = await asyncio.gather(*stream_tasks, return_exceptions=True)

                # Verify all streams were created successfully
                successful_streams = [r for r in stream_results if not isinstance(r, Exception)]
                assert len(successful_streams) == 5

                # Verify frames can be retrieved from all streams
                await asyncio.sleep(0.5)  # Allow time for frame capture

                for i in range(3):
                    frame = stream_manager.get_latest_frame(f"camera_rtsp_{i}")
                    assert frame is not None

                for i in range(2):
                    # HTTP snapshots might not have frames immediately, so
                    # only check that the lookup itself works.
                    stream_manager.get_latest_frame(f"camera_http_{i}")

                # Verify stream statistics
                stats = stream_manager.get_stream_statistics()
                assert stats["total_streams"] == 5
                assert stats["active_streams"] >= 3  # At least RTSP streams should be active

            finally:
                await stream_manager.stop_all_streams()

    @pytest.mark.asyncio
    async def test_memory_usage_workflow(self, temp_config_file):
        """Test memory usage tracking and cleanup."""

        config = Configuration()
        config.load_from_file(temp_config_file)

        # Create managers with small limits for testing
        stream_manager = StreamManager({"max_streams": 10})
        model_manager = ModelManager({"cache_max_size": 5})

        with patch('cv2.VideoCapture') as mock_video_cap, \
                patch('torch.load') as mock_torch_load:

            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True
            mock_cap_instance.read.return_value = (True, np.ones((100, 100, 3), dtype=np.uint8))

            # Mock model loading
            def create_mock_model():
                model = Mock()
                # Mock model parameters for memory estimation
                param1 = Mock()
                param1.numel.return_value = 1000
                param1.element_size.return_value = 4
                model.parameters.return_value = [param1]
                return model

            mock_torch_load.side_effect = lambda *args, **kwargs: create_mock_model()

            try:
                # Create streams up to limit
                for i in range(8):
                    stream_config = StreamConfig(
                        stream_url=f"rtsp://test{i}.example.com/stream",
                        stream_type="rtsp"
                    )
                    await stream_manager.create_stream(
                        camera_id=f"test_camera_{i}",
                        config=stream_config,
                        subscription_id=f"test_sub_{i}"
                    )

                # Load more models than the cache holds (7 > 5)
                with patch('os.path.exists', return_value=True):
                    for i in range(7):
                        model_config = {
                            "model_id": f"test_model_{i}",
                            "model_path": f"/fake/path/model_{i}.pt",
                            "model_type": "detection",
                            "device": "cpu"
                        }
                        await model_manager.load_model_from_dict(model_config)

                # Check memory usage tracking
                stream_stats = stream_manager.get_stream_statistics()
                model_stats = model_manager.get_cache_statistics()

                assert stream_stats["total_streams"] == 8
                assert model_stats["size"] <= 5  # Should be limited by cache size

                # Test cleanup
                cleaned_models = model_manager.cleanup_unused_models()
                assert cleaned_models >= 0

                stopped_streams = await stream_manager.stop_all_streams()
                assert stopped_streams == 8

                # Verify cleanup
                final_stream_stats = stream_manager.get_stream_statistics()
                assert final_stream_stats["total_streams"] == 0

            finally:
                await stream_manager.stop_all_streams()
|
|
|
|
|
|
def mock_open(read_data=""):
    """Return a ``unittest.mock.mock_open`` stand-in for ``open``.

    Files opened through the returned mock read back ``read_data``.
    """
    import unittest.mock as _stdlib_mock

    opener = _stdlib_mock.mock_open(read_data=read_data)
    return opener