Refactor: PHASE 8: Testing & Integration

This commit is contained in:
parent af34f4fd08
commit 9e8c6804a7

32 changed files with 17128 additions and 0 deletions
959  tests/unit/pipeline/test_action_executor.py  Normal file
@@ -0,0 +1,959 @@
"""
|
||||
Unit tests for action execution functionality.
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, MagicMock, patch, AsyncMock
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from detector_worker.pipeline.action_executor import (
|
||||
ActionExecutor,
|
||||
ActionResult,
|
||||
ActionType,
|
||||
RedisAction,
|
||||
PostgreSQLAction,
|
||||
FileAction
|
||||
)
|
||||
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
|
||||
from detector_worker.core.exceptions import ActionError, RedisError, DatabaseError
|
||||
|
||||
|
||||
class TestActionResult:
|
||||
"""Test action execution result."""
|
||||
|
||||
def test_creation_success(self):
|
||||
"""Test successful action result creation."""
|
||||
result = ActionResult(
|
||||
action_type=ActionType.REDIS_SAVE,
|
||||
success=True,
|
||||
execution_time=0.05,
|
||||
metadata={"key": "saved_image_key", "expiry": 600}
|
||||
)
|
||||
|
||||
assert result.action_type == ActionType.REDIS_SAVE
|
||||
assert result.success is True
|
||||
assert result.execution_time == 0.05
|
||||
assert result.metadata["key"] == "saved_image_key"
|
||||
assert result.error is None
|
||||
|
||||
def test_creation_failure(self):
|
||||
"""Test failed action result creation."""
|
||||
result = ActionResult(
|
||||
action_type=ActionType.POSTGRESQL_INSERT,
|
||||
success=False,
|
||||
error="Database connection failed",
|
||||
execution_time=0.02
|
||||
)
|
||||
|
||||
assert result.action_type == ActionType.POSTGRESQL_INSERT
|
||||
assert result.success is False
|
||||
assert result.error == "Database connection failed"
|
||||
assert result.metadata == {}
|
||||
|
||||
|
||||

class TestRedisAction:
    """Test Redis action implementations."""

    def test_creation(self):
        """Test Redis action creation."""
        action_config = {
            "type": "redis_save_image",
            "region": "car",
            "key": "inference:{display_id}:{timestamp}:{session_id}",
            "expire_seconds": 600
        }

        action = RedisAction(action_config)

        assert action.action_type == ActionType.REDIS_SAVE
        assert action.region == "car"
        assert action.key_template == "inference:{display_id}:{timestamp}:{session_id}"
        assert action.expire_seconds == 600

    def test_resolve_key_template(self):
        """Test key template resolution."""
        action_config = {
            "type": "redis_save_image",
            "region": "car",
            "key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
            "expire_seconds": 600
        }

        action = RedisAction(action_config)

        context = {
            "display_id": "display_001",
            "timestamp": "1640995200000",
            "session_id": "session_123",
            "filename": "detection.jpg"
        }

        resolved_key = action.resolve_key(context)
        expected_key = "inference:display_001:1640995200000:session_123:detection.jpg"

        assert resolved_key == expected_key

    def test_resolve_key_missing_variable(self):
        """Test key resolution with missing variable."""
        action_config = {
            "type": "redis_save_image",
            "region": "car",
            "key": "inference:{display_id}:{missing_var}",
            "expire_seconds": 600
        }

        action = RedisAction(action_config)

        context = {"display_id": "display_001"}

        with pytest.raises(ActionError):
            action.resolve_key(context)


class TestPostgreSQLAction:
    """Test PostgreSQL action implementations."""

    def test_creation_insert(self):
        """Test PostgreSQL insert action creation."""
        action_config = {
            "type": "postgresql_insert",
            "table": "detections",
            "fields": {
                "camera_id": "{camera_id}",
                "session_id": "{session_id}",
                "detection_class": "{class}",
                "confidence": "{confidence}",
                "bbox_x1": "{bbox.x1}",
                "created_at": "NOW()"
            }
        }

        action = PostgreSQLAction(action_config)

        assert action.action_type == ActionType.POSTGRESQL_INSERT
        assert action.table == "detections"
        assert len(action.fields) == 6
        assert action.key_field is None

    def test_creation_update(self):
        """Test PostgreSQL update action creation."""
        action_config = {
            "type": "postgresql_update_combined",
            "table": "car_info",
            "key_field": "session_id",
            "fields": {
                "car_brand": "{car_brand_cls.brand}",
                "car_body_type": "{car_bodytype_cls.body_type}",
                "updated_at": "NOW()"
            },
            "waitForBranches": ["car_brand_cls", "car_bodytype_cls"]
        }

        action = PostgreSQLAction(action_config)

        assert action.action_type == ActionType.POSTGRESQL_UPDATE
        assert action.table == "car_info"
        assert action.key_field == "session_id"
        assert action.wait_for_branches == ["car_brand_cls", "car_bodytype_cls"]

    def test_resolve_field_values(self):
        """Test field value resolution."""
        action_config = {
            "type": "postgresql_insert",
            "table": "detections",
            "fields": {
                "camera_id": "{camera_id}",
                "detection_class": "{class}",
                "confidence": "{confidence}",
                "brand": "{car_brand_cls.brand}"
            }
        }

        action = PostgreSQLAction(action_config)

        context = {
            "camera_id": "camera_001",
            "class": "car",
            "confidence": 0.85
        }

        branch_results = {
            "car_brand_cls": {"brand": "Toyota", "confidence": 0.78}
        }

        resolved_fields = action.resolve_field_values(context, branch_results)

        assert resolved_fields["camera_id"] == "camera_001"
        assert resolved_fields["detection_class"] == "car"
        assert resolved_fields["confidence"] == 0.85
        assert resolved_fields["brand"] == "Toyota"


class TestFileAction:
    """Test file action implementations."""

    def test_creation(self):
        """Test file action creation."""
        action_config = {
            "type": "save_image",
            "path": "/tmp/detections/{camera_id}_{timestamp}.jpg",
            "region": "car",
            "format": "jpeg",
            "quality": 85
        }

        action = FileAction(action_config)

        assert action.action_type == ActionType.SAVE_IMAGE
        assert action.path_template == "/tmp/detections/{camera_id}_{timestamp}.jpg"
        assert action.region == "car"
        assert action.format == "jpeg"
        assert action.quality == 85

    def test_resolve_path_template(self):
        """Test path template resolution."""
        action_config = {
            "type": "save_image",
            "path": "/tmp/detections/{camera_id}/{date}/{timestamp}.jpg"
        }

        action = FileAction(action_config)

        context = {
            "camera_id": "camera_001",
            "timestamp": "1640995200000",
            "date": "2022-01-01"
        }

        resolved_path = action.resolve_path(context)
        expected_path = "/tmp/detections/camera_001/2022-01-01/1640995200000.jpg"

        assert resolved_path == expected_path


class TestActionExecutor:
    """Test action execution functionality."""

    def test_initialization(self):
        """Test action executor initialization."""
        executor = ActionExecutor()

        assert executor.redis_client is None
        assert executor.db_manager is None
        assert executor.max_concurrent_actions == 10
        assert executor.action_timeout == 30.0

    def test_initialization_with_clients(self, mock_redis_client, mock_database_connection):
        """Test initialization with client instances."""
        executor = ActionExecutor(
            redis_client=mock_redis_client,
            db_manager=mock_database_connection
        )

        assert executor.redis_client is mock_redis_client
        assert executor.db_manager is mock_database_connection

    @pytest.mark.asyncio
    async def test_execute_actions_empty_list(self):
        """Test executing empty action list."""
        executor = ActionExecutor()

        context = {
            "camera_id": "camera_001",
            "session_id": "session_123"
        }

        results = await executor.execute_actions([], {}, context)

        assert results == []

    @pytest.mark.asyncio
    async def test_execute_redis_save_action(self, mock_redis_client, mock_frame):
        """Test executing Redis save image action."""
        executor = ActionExecutor(redis_client=mock_redis_client)

        actions = [
            {
                "type": "redis_save_image",
                "region": "car",
                "key": "inference:{camera_id}:{session_id}",
                "expire_seconds": 600
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "camera_id": "camera_001",
            "session_id": "session_123",
            "frame_data": mock_frame
        }

        # Mock successful Redis operations
        mock_redis_client.set.return_value = True
        mock_redis_client.expire.return_value = True

        results = await executor.execute_actions(actions, regions, context)

        assert len(results) == 1
        assert results[0].success is True
        assert results[0].action_type == ActionType.REDIS_SAVE

        # Verify Redis calls
        mock_redis_client.set.assert_called_once()
        mock_redis_client.expire.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_postgresql_insert_action(self, mock_database_connection):
        """Test executing PostgreSQL insert action."""
        # Mock database manager
        mock_db_manager = Mock()
        mock_db_manager.execute_query = AsyncMock(return_value=True)

        executor = ActionExecutor(db_manager=mock_db_manager)

        actions = [
            {
                "type": "postgresql_insert",
                "table": "detections",
                "fields": {
                    "camera_id": "{camera_id}",
                    "session_id": "{session_id}",
                    "detection_class": "{class}",
                    "confidence": "{confidence}",
                    "created_at": "NOW()"
                }
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "camera_id": "camera_001",
            "session_id": "session_123",
            "class": "car",
            "confidence": 0.9
        }

        results = await executor.execute_actions(actions, regions, context)

        assert len(results) == 1
        assert results[0].success is True
        assert results[0].action_type == ActionType.POSTGRESQL_INSERT

        # Verify database call
        mock_db_manager.execute_query.assert_called_once()
        call_args = mock_db_manager.execute_query.call_args[0]
        assert "INSERT INTO detections" in call_args[0]

    @pytest.mark.asyncio
    async def test_execute_postgresql_update_action(self, mock_database_connection):
        """Test executing PostgreSQL update action."""
        mock_db_manager = Mock()
        mock_db_manager.execute_query = AsyncMock(return_value=True)

        executor = ActionExecutor(db_manager=mock_db_manager)

        actions = [
            {
                "type": "postgresql_update_combined",
                "table": "car_info",
                "key_field": "session_id",
                "fields": {
                    "car_brand": "{car_brand_cls.brand}",
                    "car_body_type": "{car_bodytype_cls.body_type}",
                    "updated_at": "NOW()"
                },
                "waitForBranches": ["car_brand_cls", "car_bodytype_cls"]
            }
        ]

        regions = {}

        context = {
            "session_id": "session_123"
        }

        branch_results = {
            "car_brand_cls": {"brand": "Toyota"},
            "car_bodytype_cls": {"body_type": "Sedan"}
        }

        results = await executor.execute_actions(actions, regions, context, branch_results)

        assert len(results) == 1
        assert results[0].success is True
        assert results[0].action_type == ActionType.POSTGRESQL_UPDATE

        # Verify database call
        mock_db_manager.execute_query.assert_called_once()
        call_args = mock_db_manager.execute_query.call_args[0]
        assert "UPDATE car_info SET" in call_args[0]
        assert "WHERE session_id" in call_args[0]

    @pytest.mark.asyncio
    async def test_execute_file_save_action(self, mock_frame):
        """Test executing file save action."""
        executor = ActionExecutor()

        actions = [
            {
                "type": "save_image",
                "path": "/tmp/test_{camera_id}_{timestamp}.jpg",
                "region": "car",
                "format": "jpeg",
                "quality": 85
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "camera_id": "camera_001",
            "timestamp": "1640995200000",
            "frame_data": mock_frame
        }

        with patch('cv2.imwrite') as mock_imwrite:
            mock_imwrite.return_value = True

            results = await executor.execute_actions(actions, regions, context)

            assert len(results) == 1
            assert results[0].success is True
            assert results[0].action_type == ActionType.SAVE_IMAGE

            # Verify file save call
            mock_imwrite.assert_called_once()
            call_args = mock_imwrite.call_args
            assert "/tmp/test_camera_001_1640995200000.jpg" in call_args[0][0]

    @pytest.mark.asyncio
    async def test_execute_actions_parallel(self, mock_redis_client):
        """Test parallel execution of multiple actions."""
        executor = ActionExecutor(redis_client=mock_redis_client)

        # Multiple Redis actions
        actions = [
            {
                "type": "redis_save_image",
                "region": "car",
                "key": "inference:car:{session_id}",
                "expire_seconds": 600
            },
            {
                "type": "redis_publish",
                "channel": "detections",
                "message": "{camera_id}:car_detected"
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "camera_id": "camera_001",
            "session_id": "session_123",
            "frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
        }

        # Mock Redis operations
        mock_redis_client.set.return_value = True
        mock_redis_client.expire.return_value = True
        mock_redis_client.publish.return_value = 1

        import time
        start_time = time.time()

        results = await executor.execute_actions(actions, regions, context)

        execution_time = time.time() - start_time

        assert len(results) == 2
        assert all(result.success for result in results)

        # Should execute in parallel (faster than sequential)
        assert execution_time < 0.1  # Allow some overhead

    @pytest.mark.asyncio
    async def test_execute_actions_error_handling(self, mock_redis_client):
        """Test error handling in action execution."""
        executor = ActionExecutor(redis_client=mock_redis_client)

        actions = [
            {
                "type": "redis_save_image",
                "region": "car",
                "key": "inference:{session_id}",
                "expire_seconds": 600
            },
            {
                "type": "redis_save_image",  # This one will fail
                "region": "truck",  # Region not detected
                "key": "inference:truck:{session_id}",
                "expire_seconds": 600
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
            # No truck region
        }

        context = {
            "session_id": "session_123",
            "frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
        }

        # Mock Redis operations
        mock_redis_client.set.return_value = True
        mock_redis_client.expire.return_value = True

        results = await executor.execute_actions(actions, regions, context)

        assert len(results) == 2
        assert results[0].success is True  # Car action succeeds
        assert results[1].success is False  # Truck action fails
        assert "Region 'truck' not found" in results[1].error

    @pytest.mark.asyncio
    async def test_execute_actions_timeout(self, mock_redis_client):
        """Test action execution timeout."""
        config = {"action_timeout": 0.001}  # Very short timeout
        executor = ActionExecutor(redis_client=mock_redis_client, config=config)

        def slow_redis_operation(*args, **kwargs):
            import time
            time.sleep(1)  # Longer than timeout
            return True

        mock_redis_client.set.side_effect = slow_redis_operation

        actions = [
            {
                "type": "redis_save_image",
                "region": "car",
                "key": "inference:{session_id}",
                "expire_seconds": 600
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "session_id": "session_123",
            "frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
        }

        results = await executor.execute_actions(actions, regions, context)

        assert len(results) == 1
        assert results[0].success is False
        assert "timeout" in results[0].error.lower()

    @pytest.mark.asyncio
    async def test_execute_redis_publish_action(self, mock_redis_client):
        """Test executing Redis publish action."""
        executor = ActionExecutor(redis_client=mock_redis_client)

        actions = [
            {
                "type": "redis_publish",
                "channel": "detections:{camera_id}",
                "message": {
                    "camera_id": "{camera_id}",
                    "detection_class": "{class}",
                    "confidence": "{confidence}",
                    "timestamp": "{timestamp}"
                }
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "camera_id": "camera_001",
            "class": "car",
            "confidence": 0.9,
            "timestamp": "1640995200000"
        }

        mock_redis_client.publish.return_value = 1

        results = await executor.execute_actions(actions, regions, context)

        assert len(results) == 1
        assert results[0].success is True
        assert results[0].action_type == ActionType.REDIS_PUBLISH

        # Verify publish call
        mock_redis_client.publish.assert_called_once()
        call_args = mock_redis_client.publish.call_args
        assert call_args[0][0] == "detections:camera_001"  # Channel

        # Message should be JSON
        message = call_args[0][1]
        parsed_message = json.loads(message)
        assert parsed_message["camera_id"] == "camera_001"
        assert parsed_message["detection_class"] == "car"

    @pytest.mark.asyncio
    async def test_execute_conditional_action(self):
        """Test executing conditional actions."""
        executor = ActionExecutor()

        actions = [
            {
                "type": "conditional",
                "condition": "{confidence} > 0.8",
                "actions": [
                    {
                        "type": "log",
                        "message": "High confidence detection: {class} ({confidence})"
                    }
                ]
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.95,  # High confidence
                "detection": DetectionResult("car", 0.95, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "class": "car",
            "confidence": 0.95
        }

        with patch('logging.info') as mock_log:
            results = await executor.execute_actions(actions, regions, context)

            assert len(results) == 1
            assert results[0].success is True

            # Should have logged the message
            mock_log.assert_called_once()
            log_message = mock_log.call_args[0][0]
            assert "High confidence detection: car (0.95)" in log_message

    def test_crop_region_from_frame(self, mock_frame):
        """Test cropping region from frame."""
        executor = ActionExecutor()

        detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)

        cropped = executor._crop_region_from_frame(mock_frame, detection.bbox)

        assert cropped.shape == (200, 200, 3)  # 400-200, 300-100

    def test_encode_image_base64(self, mock_frame):
        """Test encoding image to base64."""
        executor = ActionExecutor()

        # Crop a small region
        cropped_frame = mock_frame[200:400, 100:300]  # 200x200 region

        with patch('cv2.imencode') as mock_imencode:
            # Mock successful encoding
            mock_imencode.return_value = (True, np.array([1, 2, 3, 4], dtype=np.uint8))

            encoded = executor._encode_image_base64(cropped_frame, format="jpeg")

            # Should return base64 string
            assert isinstance(encoded, str)
            assert len(encoded) > 0

            # Verify encoding call
            mock_imencode.assert_called_once()
            assert mock_imencode.call_args[0][0] == '.jpg'

    def test_build_insert_query(self):
        """Test building INSERT SQL query."""
        executor = ActionExecutor()

        table = "detections"
        fields = {
            "camera_id": "camera_001",
            "detection_class": "car",
            "confidence": 0.9,
            "created_at": "NOW()"
        }

        query, values = executor._build_insert_query(table, fields)

        assert "INSERT INTO detections" in query
        assert "camera_id, detection_class, confidence, created_at" in query
        assert "VALUES (%s, %s, %s, NOW())" in query
        assert values == ["camera_001", "car", 0.9]

    def test_build_update_query(self):
        """Test building UPDATE SQL query."""
        executor = ActionExecutor()

        table = "car_info"
        fields = {
            "car_brand": "Toyota",
            "car_body_type": "Sedan",
            "updated_at": "NOW()"
        }
        key_field = "session_id"
        key_value = "session_123"

        query, values = executor._build_update_query(table, fields, key_field, key_value)

        assert "UPDATE car_info SET" in query
        assert "car_brand = %s" in query
        assert "car_body_type = %s" in query
        assert "updated_at = NOW()" in query
        assert "WHERE session_id = %s" in query
        assert values == ["Toyota", "Sedan", "session_123"]

    def test_evaluate_condition(self):
        """Test evaluating conditional expressions."""
        executor = ActionExecutor()

        context = {
            "confidence": 0.85,
            "class": "car",
            "area": 40000
        }

        # Simple comparisons
        assert executor._evaluate_condition("{confidence} > 0.8", context) is True
        assert executor._evaluate_condition("{confidence} < 0.8", context) is False
        assert executor._evaluate_condition("{confidence} >= 0.85", context) is True
        assert executor._evaluate_condition("{confidence} == 0.85", context) is True

        # String comparisons
        assert executor._evaluate_condition("{class} == 'car'", context) is True
        assert executor._evaluate_condition("{class} != 'truck'", context) is True

        # Complex conditions
        assert executor._evaluate_condition("{confidence} > 0.8 and {area} > 30000", context) is True
        assert executor._evaluate_condition("{confidence} > 0.9 or {area} > 30000", context) is True
        assert executor._evaluate_condition("{confidence} > 0.9 and {area} < 30000", context) is False

    def test_validate_action_config(self):
        """Test action configuration validation."""
        executor = ActionExecutor()

        # Valid Redis action
        valid_redis = {
            "type": "redis_save_image",
            "region": "car",
            "key": "inference:{session_id}",
            "expire_seconds": 600
        }
        assert executor._validate_action_config(valid_redis) is True

        # Invalid action (missing required fields)
        invalid_action = {
            "type": "redis_save_image"
            # Missing region and key
        }
        with pytest.raises(ActionError):
            executor._validate_action_config(invalid_action)

        # Unknown action type
        unknown_action = {
            "type": "unknown_action_type",
            "some_field": "value"
        }
        with pytest.raises(ActionError):
            executor._validate_action_config(unknown_action)


class TestActionExecutorIntegration:
    """Integration tests for action execution."""

    @pytest.mark.asyncio
    async def test_complete_detection_workflow(self, mock_redis_client, mock_frame):
        """Test complete detection workflow with multiple actions."""
        # Mock database manager
        mock_db_manager = Mock()
        mock_db_manager.execute_query = AsyncMock(return_value=True)

        executor = ActionExecutor(
            redis_client=mock_redis_client,
            db_manager=mock_db_manager
        )

        # Complete action workflow
        actions = [
            # Save cropped image to Redis
            {
                "type": "redis_save_image",
                "region": "car",
                "key": "inference:{camera_id}:{timestamp}:{session_id}:car",
                "expire_seconds": 600
            },
            # Insert initial detection record
            {
                "type": "postgresql_insert",
                "table": "car_detections",
                "fields": {
                    "camera_id": "{camera_id}",
                    "session_id": "{session_id}",
                    "detection_class": "{class}",
                    "confidence": "{confidence}",
                    "bbox_x1": "{bbox.x1}",
                    "bbox_y1": "{bbox.y1}",
                    "bbox_x2": "{bbox.x2}",
                    "bbox_y2": "{bbox.y2}",
                    "created_at": "NOW()"
                }
            },
            # Publish detection event
            {
                "type": "redis_publish",
                "channel": "detections:{camera_id}",
                "message": {
                    "event": "car_detected",
                    "camera_id": "{camera_id}",
                    "session_id": "{session_id}",
                    "timestamp": "{timestamp}"
                }
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.92,
                "detection": DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "camera_id": "camera_001",
            "session_id": "session_123",
            "timestamp": "1640995200000",
            "class": "car",
            "confidence": 0.92,
            "bbox": {"x1": 100, "y1": 200, "x2": 300, "y2": 400},
            "frame_data": mock_frame
        }

        # Mock all Redis operations
        mock_redis_client.set.return_value = True
        mock_redis_client.expire.return_value = True
        mock_redis_client.publish.return_value = 1

        results = await executor.execute_actions(actions, regions, context)

        # All actions should succeed
        assert len(results) == 3
        assert all(result.success for result in results)

        # Verify all operations were called
        mock_redis_client.set.assert_called_once()  # Image save
        mock_redis_client.expire.assert_called_once()  # Set expiry
        mock_redis_client.publish.assert_called_once()  # Publish event
        mock_db_manager.execute_query.assert_called_once()  # Database insert

    @pytest.mark.asyncio
    async def test_branch_dependent_actions(self, mock_database_connection):
        """Test actions that depend on branch results."""
        mock_db_manager = Mock()
        mock_db_manager.execute_query = AsyncMock(return_value=True)

        executor = ActionExecutor(db_manager=mock_db_manager)

        # Action that waits for classification branches
        actions = [
            {
                "type": "postgresql_update_combined",
                "table": "car_info",
                "key_field": "session_id",
                "fields": {
                    "car_brand": "{car_brand_cls.brand}",
                    "car_body_type": "{car_bodytype_cls.body_type}",
                    "car_color": "{car_color_cls.color}",
                    "confidence_brand": "{car_brand_cls.confidence}",
                    "confidence_bodytype": "{car_bodytype_cls.confidence}",
                    "updated_at": "NOW()"
                },
                "waitForBranches": ["car_brand_cls", "car_bodytype_cls", "car_color_cls"]
            }
        ]

        regions = {}

        context = {
            "session_id": "session_123"
        }

        # Simulated branch results
        branch_results = {
            "car_brand_cls": {"brand": "Toyota", "confidence": 0.87},
            "car_bodytype_cls": {"body_type": "Sedan", "confidence": 0.82},
            "car_color_cls": {"color": "Red", "confidence": 0.79}
        }

        results = await executor.execute_actions(actions, regions, context, branch_results)

        assert len(results) == 1
        assert results[0].success is True
        assert results[0].action_type == ActionType.POSTGRESQL_UPDATE

        # Verify database call with all branch data
        mock_db_manager.execute_query.assert_called_once()
        call_args = mock_db_manager.execute_query.call_args
        query = call_args[0][0]
        values = call_args[0][1]

        assert "UPDATE car_info SET" in query
        assert "car_brand = %s" in query
        assert "car_body_type = %s" in query
        assert "car_color = %s" in query
        assert "WHERE session_id = %s" in query

        assert "Toyota" in values
        assert "Sedan" in values
        assert "Red" in values
        assert "session_123" in values
786  tests/unit/pipeline/test_field_mapper.py  Normal file
@@ -0,0 +1,786 @@
"""
|
||||
Unit tests for field mapping and template resolution.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from detector_worker.pipeline.field_mapper import (
|
||||
FieldMapper,
|
||||
MappingContext,
|
||||
TemplateResolver,
|
||||
FieldMappingError,
|
||||
NestedFieldAccessor
|
||||
)
|
||||
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
|
||||
|
||||
|
||||
class TestNestedFieldAccessor:
|
||||
"""Test nested field access functionality."""
|
||||
|
||||
def test_get_nested_value_simple(self):
|
||||
"""Test getting simple nested values."""
|
||||
data = {
|
||||
"user": {
|
||||
"name": "John",
|
||||
"age": 30,
|
||||
"address": {
|
||||
"city": "New York",
|
||||
"zip": "10001"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessor = NestedFieldAccessor()
|
||||
|
||||
assert accessor.get_nested_value(data, "user.name") == "John"
|
||||
assert accessor.get_nested_value(data, "user.age") == 30
|
||||
assert accessor.get_nested_value(data, "user.address.city") == "New York"
|
||||
assert accessor.get_nested_value(data, "user.address.zip") == "10001"
|
||||
|
||||
def test_get_nested_value_array_access(self):
|
||||
"""Test accessing array elements."""
|
||||
data = {
|
||||
"results": [
|
||||
{"score": 0.9, "label": "car"},
|
||||
{"score": 0.8, "label": "truck"}
|
||||
],
|
||||
"bbox": [100, 200, 300, 400]
|
||||
}
|
||||
|
||||
accessor = NestedFieldAccessor()
|
||||
|
||||
assert accessor.get_nested_value(data, "results[0].score") == 0.9
|
||||
assert accessor.get_nested_value(data, "results[0].label") == "car"
|
||||
assert accessor.get_nested_value(data, "results[1].score") == 0.8
|
||||
assert accessor.get_nested_value(data, "bbox[0]") == 100
|
||||
assert accessor.get_nested_value(data, "bbox[3]") == 400
|
||||
|
||||
def test_get_nested_value_nonexistent_path(self):
|
||||
"""Test accessing non-existent paths."""
|
||||
data = {"user": {"name": "John"}}
|
||||
accessor = NestedFieldAccessor()
|
||||
|
||||
assert accessor.get_nested_value(data, "user.nonexistent") is None
|
||||
assert accessor.get_nested_value(data, "nonexistent.field") is None
|
||||
assert accessor.get_nested_value(data, "user.address.city") is None
|
||||
|
||||
def test_get_nested_value_with_default(self):
|
||||
"""Test getting nested values with default fallback."""
|
||||
data = {"user": {"name": "John"}}
|
||||
accessor = NestedFieldAccessor()
|
||||
|
||||
assert accessor.get_nested_value(data, "user.age", default=25) == 25
|
||||
assert accessor.get_nested_value(data, "user.name", default="Unknown") == "John"
|
||||
|
||||
def test_set_nested_value(self):
|
||||
"""Test setting nested values."""
|
||||
data = {}
|
||||
accessor = NestedFieldAccessor()
|
||||
|
||||
accessor.set_nested_value(data, "user.name", "John")
|
||||
assert data["user"]["name"] == "John"
|
||||
|
||||
accessor.set_nested_value(data, "user.address.city", "New York")
|
||||
assert data["user"]["address"]["city"] == "New York"
|
||||
|
||||
accessor.set_nested_value(data, "scores[0]", 0.95)
|
||||
assert data["scores"][0] == 0.95
|
||||
|
||||
def test_set_nested_value_overwrite(self):
|
||||
"""Test overwriting existing nested values."""
|
||||
data = {"user": {"name": "John", "age": 30}}
|
||||
accessor = NestedFieldAccessor()
|
||||
|
||||
accessor.set_nested_value(data, "user.name", "Jane")
|
||||
assert data["user"]["name"] == "Jane"
|
||||
assert data["user"]["age"] == 30 # Should not affect other fields
|
||||
|
||||
|
||||

class TestTemplateResolver:
    """Test template string resolution."""

    def test_resolve_simple_template(self):
        """Test resolving simple template variables."""
        resolver = TemplateResolver()

        template = "Hello {name}, you are {age} years old"
        context = {"name": "John", "age": 30}

        result = resolver.resolve(template, context)
        assert result == "Hello John, you are 30 years old"

    def test_resolve_nested_template(self):
        """Test resolving nested field templates."""
        resolver = TemplateResolver()

        template = "User: {user.name} from {user.address.city}"
        context = {
            "user": {
                "name": "John",
                "address": {"city": "New York", "zip": "10001"}
            }
        }

        result = resolver.resolve(template, context)
        assert result == "User: John from New York"

    def test_resolve_array_template(self):
        """Test resolving array element templates."""
        resolver = TemplateResolver()

        template = "First result: {results[0].label} ({results[0].score})"
        context = {
            "results": [
                {"label": "car", "score": 0.95},
                {"label": "truck", "score": 0.87}
            ]
        }

        result = resolver.resolve(template, context)
        assert result == "First result: car (0.95)"

    def test_resolve_missing_variables(self):
        """Test resolving templates with missing variables."""
        resolver = TemplateResolver()

        template = "Hello {name}, you are {age} years old"
        context = {"name": "John"}  # Missing age

        with pytest.raises(FieldMappingError) as exc_info:
            resolver.resolve(template, context)

        assert "Variable 'age' not found" in str(exc_info.value)

    def test_resolve_with_defaults(self):
        """Test resolving templates with default values."""
        resolver = TemplateResolver(allow_missing=True)

        template = "Hello {name}, you are {age|25} years old"
        context = {"name": "John"}  # Missing age, should use default

        result = resolver.resolve(template, context)
        assert result == "Hello John, you are 25 years old"

    def test_resolve_complex_template(self):
        """Test resolving complex templates with multiple variable types."""
        resolver = TemplateResolver()

        template = "{camera_id}:{timestamp}:{session_id}:{results[0].class}_{bbox[0]}_{bbox[1]}"
        context = {
            "camera_id": "cam001",
            "timestamp": 1640995200000,
            "session_id": "sess123",
            "results": [{"class": "car", "confidence": 0.95}],
            "bbox": [100, 200, 300, 400]
        }

        result = resolver.resolve(template, context)
        assert result == "cam001:1640995200000:sess123:car_100_200"

    def test_resolve_conditional_template(self):
        """Test resolving conditional templates."""
        resolver = TemplateResolver()

        # Simple conditional
        template = "{name} is {age > 18 ? 'adult' : 'minor'}"

        context_adult = {"name": "John", "age": 25}
        result_adult = resolver.resolve(template, context_adult)
        assert result_adult == "John is adult"

        context_minor = {"name": "Jane", "age": 16}
        result_minor = resolver.resolve(template, context_minor)
        assert result_minor == "Jane is minor"

    def test_escape_braces(self):
        """Test escaping braces in templates."""
        resolver = TemplateResolver()

        template = "Literal {{braces}} and variable {name}"
        context = {"name": "John"}

        result = resolver.resolve(template, context)
        assert result == "Literal {braces} and variable John"


class TestMappingContext:
    """Test mapping context data structure."""

    def test_creation(self):
        """Test mapping context creation."""
        detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)

        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            detection=detection,
            timestamp=1640995200000
        )

        assert context.camera_id == "camera_001"
        assert context.display_id == "display_001"
        assert context.session_id == "session_123"
        assert context.detection == detection
        assert context.timestamp == 1640995200000
        assert context.branch_results == {}
        assert context.metadata == {}

    def test_add_branch_result(self):
        """Test adding branch results to context."""
        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123"
        )

        context.add_branch_result("car_brand_cls", {"brand": "Toyota", "confidence": 0.87})
        context.add_branch_result("car_bodytype_cls", {"body_type": "Sedan", "confidence": 0.82})

        assert len(context.branch_results) == 2
        assert context.branch_results["car_brand_cls"]["brand"] == "Toyota"
        assert context.branch_results["car_bodytype_cls"]["body_type"] == "Sedan"

    def test_to_dict(self):
        """Test converting context to dictionary."""
        detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)

        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            detection=detection,
            timestamp=1640995200000
        )

        context.add_branch_result("car_brand_cls", {"brand": "Toyota"})
        context.add_metadata("model_id", "yolo_v8")

        context_dict = context.to_dict()

        assert context_dict["camera_id"] == "camera_001"
        assert context_dict["display_id"] == "display_001"
        assert context_dict["session_id"] == "session_123"
        assert context_dict["timestamp"] == 1640995200000
        assert context_dict["class"] == "car"
        assert context_dict["confidence"] == 0.9
        assert context_dict["track_id"] == 1001
        assert context_dict["bbox"]["x1"] == 100
        assert context_dict["car_brand_cls"]["brand"] == "Toyota"
        assert context_dict["model_id"] == "yolo_v8"

    def test_add_metadata(self):
        """Test adding metadata to context."""
        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123"
        )

        context.add_metadata("model_version", "v2.1")
        context.add_metadata("inference_time", 0.15)

        assert context.metadata["model_version"] == "v2.1"
        assert context.metadata["inference_time"] == 0.15


class TestFieldMapper:
    """Test field mapping functionality."""

    def test_initialization(self):
        """Test field mapper initialization."""
        mapper = FieldMapper()

        assert isinstance(mapper.template_resolver, TemplateResolver)
        assert isinstance(mapper.field_accessor, NestedFieldAccessor)

    def test_map_fields_simple(self):
        """Test simple field mapping."""
        mapper = FieldMapper()

        field_mappings = {
            "camera_id": "{camera_id}",
            "detection_class": "{class}",
            "confidence_score": "{confidence}",
            "track_identifier": "{track_id}"
        }

        detection = DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            detection=detection,
            timestamp=1640995200000
        )

        mapped_fields = mapper.map_fields(field_mappings, context)

        assert mapped_fields["camera_id"] == "camera_001"
        assert mapped_fields["detection_class"] == "car"
        assert mapped_fields["confidence_score"] == 0.92
        assert mapped_fields["track_identifier"] == 1001

    def test_map_fields_with_branch_results(self):
        """Test field mapping with branch results."""
        mapper = FieldMapper()

        field_mappings = {
            "car_brand": "{car_brand_cls.brand}",
            "car_model": "{car_brand_cls.model}",
            "body_type": "{car_bodytype_cls.body_type}",
            "brand_confidence": "{car_brand_cls.confidence}",
            "combined_info": "{car_brand_cls.brand} {car_bodytype_cls.body_type}"
        }

        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123"
        )

        context.add_branch_result("car_brand_cls", {
            "brand": "Toyota",
            "model": "Camry",
            "confidence": 0.87
        })
        context.add_branch_result("car_bodytype_cls", {
            "body_type": "Sedan",
            "confidence": 0.82
        })

        mapped_fields = mapper.map_fields(field_mappings, context)

        assert mapped_fields["car_brand"] == "Toyota"
        assert mapped_fields["car_model"] == "Camry"
        assert mapped_fields["body_type"] == "Sedan"
        assert mapped_fields["brand_confidence"] == 0.87
        assert mapped_fields["combined_info"] == "Toyota Sedan"

    def test_map_fields_bbox_access(self):
        """Test field mapping with bounding box access."""
        mapper = FieldMapper()

        field_mappings = {
            "bbox_x1": "{bbox.x1}",
            "bbox_y1": "{bbox.y1}",
            "bbox_x2": "{bbox.x2}",
            "bbox_y2": "{bbox.y2}",
            "bbox_width": "{bbox.width}",
            "bbox_height": "{bbox.height}",
            "bbox_area": "{bbox.area}",
            "bbox_center_x": "{bbox.center_x}",
            "bbox_center_y": "{bbox.center_y}"
        }

        detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            detection=detection
        )

        mapped_fields = mapper.map_fields(field_mappings, context)

        assert mapped_fields["bbox_x1"] == 100
        assert mapped_fields["bbox_y1"] == 200
        assert mapped_fields["bbox_x2"] == 300
        assert mapped_fields["bbox_y2"] == 400
        assert mapped_fields["bbox_width"] == 200  # 300 - 100
        assert mapped_fields["bbox_height"] == 200  # 400 - 200
        assert mapped_fields["bbox_area"] == 40000  # 200 * 200
        assert mapped_fields["bbox_center_x"] == 200  # (100 + 300) / 2
        assert mapped_fields["bbox_center_y"] == 300  # (200 + 400) / 2

    def test_map_fields_with_sql_functions(self):
        """Test field mapping with SQL function templates."""
        mapper = FieldMapper()

        field_mappings = {
            "created_at": "NOW()",
            "updated_at": "CURRENT_TIMESTAMP",
            "uuid_field": "UUID()",
            "json_data": "JSON_OBJECT('class', '{class}', 'confidence', {confidence})"
        }

        detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            detection=detection
        )

        mapped_fields = mapper.map_fields(field_mappings, context)

        # SQL functions should pass through unchanged
        assert mapped_fields["created_at"] == "NOW()"
        assert mapped_fields["updated_at"] == "CURRENT_TIMESTAMP"
        assert mapped_fields["uuid_field"] == "UUID()"
        assert mapped_fields["json_data"] == "JSON_OBJECT('class', 'car', 'confidence', 0.9)"

    def test_map_fields_missing_branch_data(self):
        """Test field mapping with missing branch data."""
        mapper = FieldMapper()

        field_mappings = {
            "car_brand": "{car_brand_cls.brand}",
            "car_model": "{nonexistent_branch.model}"
        }

        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123"
        )

        # Only add one branch result
        context.add_branch_result("car_brand_cls", {"brand": "Toyota"})

        with pytest.raises(FieldMappingError) as exc_info:
            mapper.map_fields(field_mappings, context)

        assert "nonexistent_branch.model" in str(exc_info.value)

    def test_map_fields_with_defaults(self):
        """Test field mapping with default values."""
        mapper = FieldMapper(allow_missing=True)

        field_mappings = {
            "car_brand": "{car_brand_cls.brand|Unknown}",
            "car_model": "{car_brand_cls.model|N/A}",
            "confidence": "{confidence|0.0}"
        }

        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123"
        )

        # Don't add any branch results
        mapped_fields = mapper.map_fields(field_mappings, context)

        assert mapped_fields["car_brand"] == "Unknown"
        assert mapped_fields["car_model"] == "N/A"
        assert mapped_fields["confidence"] == "0.0"

    def test_map_database_fields(self):
        """Test mapping fields for database operations."""
        mapper = FieldMapper()

        # Database field mapping
        db_field_mappings = {
            "camera_id": "{camera_id}",
            "session_id": "{session_id}",
            "detection_timestamp": "{timestamp}",
            "object_class": "{class}",
            "detection_confidence": "{confidence}",
            "track_id": "{track_id}",
            "bbox_json": "JSON_OBJECT('x1', {bbox.x1}, 'y1', {bbox.y1}, 'x2', {bbox.x2}, 'y2', {bbox.y2})",
            "car_brand": "{car_brand_cls.brand}",
            "car_body_type": "{car_bodytype_cls.body_type}",
            "license_plate": "{license_ocr.text}",
            "created_at": "NOW()",
            "updated_at": "NOW()"
        }

        detection = DetectionResult("car", 0.93, BoundingBox(150, 250, 350, 450), 2001, 1640995300000)
        context = MappingContext(
            camera_id="camera_002",
            display_id="display_002",
            session_id="session_456",
            detection=detection,
            timestamp=1640995300000
        )

        # Add branch results
        context.add_branch_result("car_brand_cls", {"brand": "Honda", "confidence": 0.89})
        context.add_branch_result("car_bodytype_cls", {"body_type": "SUV", "confidence": 0.85})
        context.add_branch_result("license_ocr", {"text": "ABC-123", "confidence": 0.76})

        mapped_fields = mapper.map_fields(db_field_mappings, context)

        assert mapped_fields["camera_id"] == "camera_002"
        assert mapped_fields["session_id"] == "session_456"
        assert mapped_fields["detection_timestamp"] == 1640995300000
        assert mapped_fields["object_class"] == "car"
        assert mapped_fields["detection_confidence"] == 0.93
        assert mapped_fields["track_id"] == 2001
        assert mapped_fields["bbox_json"] == "JSON_OBJECT('x1', 150, 'y1', 250, 'x2', 350, 'y2', 450)"
        assert mapped_fields["car_brand"] == "Honda"
        assert mapped_fields["car_body_type"] == "SUV"
        assert mapped_fields["license_plate"] == "ABC-123"
        assert mapped_fields["created_at"] == "NOW()"
        assert mapped_fields["updated_at"] == "NOW()"

    def test_map_redis_keys(self):
        """Test mapping Redis key templates."""
        mapper = FieldMapper()

        key_templates = [
            "inference:{camera_id}:{timestamp}:{session_id}:car",
            "detection:{display_id}:{track_id}",
            "cropped_image:{camera_id}:{session_id}:{class}",
            "metadata:{session_id}:brands:{car_brand_cls.brand}",
            "tracking:{camera_id}:active_tracks"
        ]

        detection = DetectionResult("car", 0.88, BoundingBox(200, 300, 400, 500), 3001, 1640995400000)
        context = MappingContext(
            camera_id="camera_003",
            display_id="display_003",
            session_id="session_789",
            detection=detection,
            timestamp=1640995400000
        )

        context.add_branch_result("car_brand_cls", {"brand": "Ford"})

        mapped_keys = [mapper.map_template(template, context) for template in key_templates]

        expected_keys = [
            "inference:camera_003:1640995400000:session_789:car",
            "detection:display_003:3001",
            "cropped_image:camera_003:session_789:car",
            "metadata:session_789:brands:Ford",
            "tracking:camera_003:active_tracks"
        ]

        assert mapped_keys == expected_keys

    def test_map_template(self):
        """Test single template mapping."""
        mapper = FieldMapper()

        template = "Camera {camera_id} detected {class} with {confidence:.2f} confidence at {timestamp}"

        detection = DetectionResult("truck", 0.876, BoundingBox(100, 150, 300, 350), 4001, 1640995500000)
        context = MappingContext(
            camera_id="camera_004",
            display_id="display_004",
            session_id="session_101",
            detection=detection,
            timestamp=1640995500000
        )

        result = mapper.map_template(template, context)
        expected = "Camera camera_004 detected truck with 0.88 confidence at 1640995500000"

        assert result == expected

    def test_validate_field_mappings(self):
        """Test field mapping validation."""
        mapper = FieldMapper()

        # Valid mappings
        valid_mappings = {
            "camera_id": "{camera_id}",
            "class": "{class}",
            "confidence": "{confidence}",
            "created_at": "NOW()"
        }

        assert mapper.validate_field_mappings(valid_mappings) is True

        # Invalid mappings (malformed templates)
        invalid_mappings = {
            "camera_id": "{camera_id",  # Missing closing brace
            "class": "class}",  # Missing opening brace
            "confidence": "{nonexistent_field}"  # This might be valid depending on context
        }

        with pytest.raises(FieldMappingError):
            mapper.validate_field_mappings(invalid_mappings)

    def test_create_context_from_detection(self):
        """Test creating mapping context from detection result."""
        mapper = FieldMapper()

        detection = DetectionResult("car", 0.95, BoundingBox(50, 100, 250, 300), 5001, 1640995600000)

        context = mapper.create_context_from_detection(
            detection,
            camera_id="camera_005",
            display_id="display_005",
            session_id="session_202"
        )

        assert context.camera_id == "camera_005"
        assert context.display_id == "display_005"
        assert context.session_id == "session_202"
        assert context.detection == detection
        assert context.timestamp == 1640995600000

    def test_format_sql_value(self):
        """Test SQL value formatting."""
        mapper = FieldMapper()

        # String values should be quoted
        assert mapper.format_sql_value("test_string") == "'test_string'"
        assert mapper.format_sql_value("John's car") == "'John''s car'"  # Escape quotes

        # Numeric values should not be quoted
        assert mapper.format_sql_value(42) == "42"
        assert mapper.format_sql_value(3.14) == "3.14"
        assert mapper.format_sql_value(0.95) == "0.95"

        # Boolean values
        assert mapper.format_sql_value(True) == "TRUE"
        assert mapper.format_sql_value(False) == "FALSE"

        # None/NULL values
        assert mapper.format_sql_value(None) == "NULL"

        # SQL functions should pass through
        assert mapper.format_sql_value("NOW()") == "NOW()"
        assert mapper.format_sql_value("CURRENT_TIMESTAMP") == "CURRENT_TIMESTAMP"

class TestFieldMapperIntegration:
    """Integration tests for field mapping."""

    def test_complete_mapping_workflow(self):
        """Test complete field mapping workflow."""
        mapper = FieldMapper()

        # Simulate complete detection workflow
        detection = DetectionResult("car", 0.91, BoundingBox(120, 180, 320, 380), 6001, 1640995700000)
        context = MappingContext(
            camera_id="camera_006",
            display_id="display_006",
            session_id="session_303",
            detection=detection,
            timestamp=1640995700000
        )

        # Add comprehensive branch results
        context.add_branch_result("car_brand_cls", {
            "brand": "BMW",
            "model": "X5",
            "confidence": 0.84,
            "top3_brands": ["BMW", "Audi", "Mercedes"]
        })

        context.add_branch_result("car_bodytype_cls", {
            "body_type": "SUV",
            "confidence": 0.79,
            "features": ["tall", "4_doors", "roof_rails"]
        })

        context.add_branch_result("car_color_cls", {
            "color": "Black",
            "confidence": 0.73,
            "rgb_values": [20, 25, 30]
        })

        context.add_branch_result("license_ocr", {
            "text": "XYZ-789",
            "confidence": 0.68,
            "region_bbox": [150, 320, 290, 360]
        })

        # Database field mapping
        db_mappings = {
            "camera_id": "{camera_id}",
            "display_id": "{display_id}",
            "session_id": "{session_id}",
            "detection_timestamp": "{timestamp}",
            "object_class": "{class}",
            "detection_confidence": "{confidence}",
            "track_id": "{track_id}",
            "bbox_x1": "{bbox.x1}",
            "bbox_y1": "{bbox.y1}",
            "bbox_x2": "{bbox.x2}",
            "bbox_y2": "{bbox.y2}",
            "bbox_area": "{bbox.area}",
            "car_brand": "{car_brand_cls.brand}",
            "car_model": "{car_brand_cls.model}",
            "car_body_type": "{car_bodytype_cls.body_type}",
            "car_color": "{car_color_cls.color}",
            "license_plate": "{license_ocr.text}",
            "brand_confidence": "{car_brand_cls.confidence}",
            "bodytype_confidence": "{car_bodytype_cls.confidence}",
            "color_confidence": "{car_color_cls.confidence}",
            "license_confidence": "{license_ocr.confidence}",
            "detection_summary": "{car_brand_cls.brand} {car_bodytype_cls.body_type} ({car_color_cls.color})",
            "created_at": "NOW()",
            "updated_at": "NOW()"
        }

        mapped_db_fields = mapper.map_fields(db_mappings, context)

        # Verify all mappings
        assert mapped_db_fields["camera_id"] == "camera_006"
        assert mapped_db_fields["session_id"] == "session_303"
        assert mapped_db_fields["object_class"] == "car"
        assert mapped_db_fields["detection_confidence"] == 0.91
        assert mapped_db_fields["track_id"] == 6001
        assert mapped_db_fields["bbox_area"] == 40000  # 200 * 200
        assert mapped_db_fields["car_brand"] == "BMW"
        assert mapped_db_fields["car_model"] == "X5"
        assert mapped_db_fields["car_body_type"] == "SUV"
        assert mapped_db_fields["car_color"] == "Black"
        assert mapped_db_fields["license_plate"] == "XYZ-789"
        assert mapped_db_fields["detection_summary"] == "BMW SUV (Black)"

        # Redis key mapping
        redis_key_templates = [
            "detection:{camera_id}:{session_id}:main",
            "cropped:{camera_id}:{session_id}:car_image",
            "metadata:{session_id}:brand:{car_brand_cls.brand}",
            "tracking:{camera_id}:track_{track_id}",
            "classification:{session_id}:results"
        ]

        mapped_redis_keys = [
            mapper.map_template(template, context)
            for template in redis_key_templates
        ]

        expected_redis_keys = [
            "detection:camera_006:session_303:main",
            "cropped:camera_006:session_303:car_image",
            "metadata:session_303:brand:BMW",
            "tracking:camera_006:track_6001",
            "classification:session_303:results"
        ]

        assert mapped_redis_keys == expected_redis_keys

    def test_error_handling_and_recovery(self):
        """Test error handling and recovery in field mapping."""
        mapper = FieldMapper(allow_missing=True)

        # Context with missing detection
        context = MappingContext(
            camera_id="camera_007",
            display_id="display_007",
            session_id="session_404"
        )

        # Partial branch results
        context.add_branch_result("car_brand_cls", {"brand": "Unknown"})
        # Missing car_bodytype_cls branch

        # Field mappings with some missing data
        mappings = {
            "camera_id": "{camera_id}",
            "detection_class": "{class|Unknown}",
            "confidence": "{confidence|0.0}",
            "car_brand": "{car_brand_cls.brand|N/A}",
            "car_body_type": "{car_bodytype_cls.body_type|Unknown}",
            "car_model": "{car_brand_cls.model|N/A}"
        }

        mapped_fields = mapper.map_fields(mappings, context)

        assert mapped_fields["camera_id"] == "camera_007"
        assert mapped_fields["detection_class"] == "Unknown"
        assert mapped_fields["confidence"] == "0.0"
        assert mapped_fields["car_brand"] == "Unknown"
        assert mapped_fields["car_body_type"] == "Unknown"
        assert mapped_fields["car_model"] == "N/A"
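

# --- Hedged sketch (not part of the original test suite) ---
# The two integration tests above exercise template resolution for
# placeholders such as "{car_brand_cls.brand}" and the fallback form
# "{name|default}" used when allow_missing=True. The helper below shows
# one way such a resolver could work against a flat dict of values;
# `resolve_placeholder` and `values` are illustrative names, not the real
# FieldMapper API.
import re


def resolve_placeholder(template, values):
    """Illustrative only: replace {key} / {key|default} tokens from a dict."""
    def _sub(match):
        token = match.group(1)
        key, _, default = token.partition("|")
        return str(values.get(key, default))

    return re.sub(r"\{([^{}]+)\}", _sub, template)


# Assumed behaviour, mirroring the assertions above:
# resolve_placeholder("{car_brand_cls.brand|N/A}", {"car_brand_cls.brand": "Unknown"})
# -> "Unknown"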
921
tests/unit/pipeline/test_pipeline_executor.py
Normal file
@@ -0,0 +1,921 @@
"""
|
||||
Unit tests for pipeline execution functionality.
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, MagicMock, patch, AsyncMock
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import json
|
||||
|
||||
from detector_worker.pipeline.pipeline_executor import (
|
||||
PipelineExecutor,
|
||||
PipelineContext,
|
||||
PipelineResult,
|
||||
BranchResult,
|
||||
ExecutionMode
|
||||
)
|
||||
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
|
||||
from detector_worker.core.exceptions import PipelineError, ModelError, ActionError
|
||||
|
||||
|
||||
class TestPipelineContext:
    """Test pipeline context data structure."""

    def test_creation(self):
        """Test pipeline context creation."""
        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=np.zeros((480, 640, 3), dtype=np.uint8)
        )

        assert context.camera_id == "camera_001"
        assert context.display_id == "display_001"
        assert context.session_id == "session_123"
        assert context.timestamp == 1640995200000
        assert context.frame_data.shape == (480, 640, 3)
        assert context.metadata == {}
        assert context.crop_region is None

    def test_creation_with_crop_region(self):
        """Test context creation with crop region."""
        crop_region = (100, 200, 300, 400)
        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=np.zeros((480, 640, 3), dtype=np.uint8),
            crop_region=crop_region
        )

        assert context.crop_region == crop_region

    def test_add_metadata(self):
        """Test adding metadata to context."""
        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=np.zeros((480, 640, 3), dtype=np.uint8)
        )

        context.add_metadata("model_id", "yolo_v8")
        context.add_metadata("confidence_threshold", 0.8)

        assert context.metadata["model_id"] == "yolo_v8"
        assert context.metadata["confidence_threshold"] == 0.8

    def test_get_cropped_frame(self):
        """Test getting cropped frame."""
        frame = np.ones((480, 640, 3), dtype=np.uint8) * 255
        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=frame,
            crop_region=(100, 200, 300, 400)
        )

        cropped = context.get_cropped_frame()

        assert cropped.shape == (200, 200, 3)  # 400-200, 300-100
        assert np.all(cropped == 255)

    def test_get_cropped_frame_no_crop(self):
        """Test getting frame when no crop region specified."""
        frame = np.ones((480, 640, 3), dtype=np.uint8) * 255
        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=frame
        )

        cropped = context.get_cropped_frame()

        assert np.array_equal(cropped, frame)


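# --- Hedged sketch (not part of the original test suite) ---
# The two cropping tests above imply that PipelineContext.get_cropped_frame
# slices the frame with an (x1, y1, x2, y2) crop_region and returns the full
# frame when no region is set. A minimal NumPy version consistent with the
# expected (200, 200, 3) shape is sketched below; the real implementation may
# also clamp or validate the region.
def _sketch_get_cropped_frame(frame_data, crop_region=None):
    """Illustrative only: crop frame_data to crop_region = (x1, y1, x2, y2)."""
    if crop_region is None:
        return frame_data
    x1, y1, x2, y2 = crop_region
    return frame_data[y1:y2, x1:x2]

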
class TestBranchResult:
    """Test branch execution result."""

    def test_creation_success(self):
        """Test successful branch result creation."""
        detections = [
            DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
        ]

        result = BranchResult(
            branch_id="car_brand_cls",
            success=True,
            detections=detections,
            metadata={"brand": "Toyota"},
            execution_time=0.15
        )

        assert result.branch_id == "car_brand_cls"
        assert result.success is True
        assert len(result.detections) == 1
        assert result.metadata["brand"] == "Toyota"
        assert result.execution_time == 0.15
        assert result.error is None

    def test_creation_failure(self):
        """Test failed branch result creation."""
        result = BranchResult(
            branch_id="car_brand_cls",
            success=False,
            error="Model inference failed",
            execution_time=0.05
        )

        assert result.branch_id == "car_brand_cls"
        assert result.success is False
        assert result.detections == []
        assert result.metadata == {}
        assert result.error == "Model inference failed"


class TestPipelineResult:
    """Test pipeline execution result."""

    def test_creation_success(self):
        """Test successful pipeline result creation."""
        main_detections = [
            DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
        ]

        branch_results = {
            "car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
            "car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12)
        }

        result = PipelineResult(
            success=True,
            detections=main_detections,
            branch_results=branch_results,
            total_execution_time=0.5
        )

        assert result.success is True
        assert len(result.detections) == 1
        assert len(result.branch_results) == 2
        assert result.total_execution_time == 0.5
        assert result.error is None

    def test_creation_failure(self):
        """Test failed pipeline result creation."""
        result = PipelineResult(
            success=False,
            error="Pipeline execution failed",
            total_execution_time=0.1
        )

        assert result.success is False
        assert result.detections == []
        assert result.branch_results == {}
        assert result.error == "Pipeline execution failed"

    def test_get_combined_results(self):
        """Test getting combined results from all branches."""
        main_detections = [
            DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
        ]

        branch_results = {
            "car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
            "car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12)
        }

        result = PipelineResult(
            success=True,
            detections=main_detections,
            branch_results=branch_results,
            total_execution_time=0.5
        )

        combined = result.get_combined_results()

        assert "brand" in combined
        assert "body_type" in combined
        assert combined["brand"] == "Toyota"
        assert combined["body_type"] == "Sedan"


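# --- Hedged sketch (not part of the original test suite) ---
# test_get_combined_results expects PipelineResult.get_combined_results to
# flatten the metadata of all branch results into one dict. A minimal merge
# consistent with the assertions is shown below; how the real class handles
# key collisions or failed branches is not specified here.
def _sketch_get_combined_results(branch_results):
    """Illustrative only: merge branch metadata dicts into a single dict."""
    combined = {}
    for branch in branch_results.values():
        if branch.success:
            combined.update(branch.metadata)
    return combined

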
class TestPipelineExecutor:
    """Test pipeline execution functionality."""

    def test_initialization(self):
        """Test pipeline executor initialization."""
        executor = PipelineExecutor()

        assert isinstance(executor.thread_pool, ThreadPoolExecutor)
        assert executor.max_workers == 4
        assert executor.execution_mode == ExecutionMode.PARALLEL
        assert executor.timeout == 30.0

    def test_initialization_custom_config(self):
        """Test initialization with custom configuration."""
        config = {
            "max_workers": 8,
            "execution_mode": "sequential",
            "timeout": 60.0
        }

        executor = PipelineExecutor(config)

        assert executor.max_workers == 8
        assert executor.execution_mode == ExecutionMode.SEQUENTIAL
        assert executor.timeout == 60.0

    @pytest.mark.asyncio
    async def test_execute_pipeline_simple(self, mock_yolo_model, mock_frame):
        """Test simple pipeline execution."""
        import torch  # the mocked boxes below use torch tensors; other tests import it locally

        # Mock pipeline configuration
        pipeline_config = {
            "modelId": "car_detection_v1",
            "modelFile": "car_detection.pt",
            "expectedClasses": ["car"],
            "triggerClasses": ["car"],
            "minConfidence": 0.8,
            "branches": [],
            "actions": []
        }

        # Mock detection result
        mock_result = Mock()
        mock_result.boxes = Mock()
        mock_result.boxes.data = torch.tensor([
            [100, 200, 300, 400, 0.9, 0]
        ])
        mock_result.boxes.id = torch.tensor([1001])

        mock_yolo_model.track.return_value = [mock_result]

        executor = PipelineExecutor()

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            mock_model_manager.return_value.get_model.return_value = mock_yolo_model

            result = await executor.execute_pipeline(pipeline_config, context)

            assert result.success is True
            assert len(result.detections) == 1
            assert result.detections[0].class_name == "0"  # Default class name
            assert result.detections[0].confidence == 0.9

    @pytest.mark.asyncio
    async def test_execute_pipeline_with_branches(self, mock_yolo_model, mock_frame):
        """Test pipeline execution with classification branches."""
        import torch

        # Mock main detection
        mock_detection_result = Mock()
        mock_detection_result.boxes = Mock()
        mock_detection_result.boxes.data = torch.tensor([
            [100, 200, 300, 400, 0.9, 0]  # car detection
        ])
        mock_detection_result.boxes.id = torch.tensor([1001])

        # Mock classification results
        mock_brand_result = Mock()
        mock_brand_result.probs = Mock()
        mock_brand_result.probs.top1 = 2  # Toyota
        mock_brand_result.probs.top1conf = 0.85

        mock_bodytype_result = Mock()
        mock_bodytype_result.probs = Mock()
        mock_bodytype_result.probs.top1 = 1  # Sedan
        mock_bodytype_result.probs.top1conf = 0.78

        mock_yolo_model.track.return_value = [mock_detection_result]
        mock_yolo_model.predict.return_value = [mock_brand_result]

        mock_brand_model = Mock()
        mock_brand_model.predict.return_value = [mock_brand_result]
        mock_brand_model.names = {0: "Honda", 1: "Ford", 2: "Toyota"}

        mock_bodytype_model = Mock()
        mock_bodytype_model.predict.return_value = [mock_bodytype_result]
        mock_bodytype_model.names = {0: "SUV", 1: "Sedan", 2: "Hatchback"}

        # Pipeline configuration with branches
        pipeline_config = {
            "modelId": "car_detection_v1",
            "modelFile": "car_detection.pt",
            "expectedClasses": ["car"],
            "triggerClasses": ["car"],
            "minConfidence": 0.8,
            "branches": [
                {
                    "modelId": "car_brand_cls",
                    "modelFile": "car_brand.pt",
                    "triggerClasses": ["car"],
                    "minConfidence": 0.7,
                    "parallel": True,
                    "crop": True,
                    "cropClass": "car"
                },
                {
                    "modelId": "car_bodytype_cls",
                    "modelFile": "car_bodytype.pt",
                    "triggerClasses": ["car"],
                    "minConfidence": 0.7,
                    "parallel": True,
                    "crop": True,
                    "cropClass": "car"
                }
            ],
            "actions": []
        }

        executor = PipelineExecutor()

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            def get_model_side_effect(model_id, camera_id):
                if model_id == "car_detection_v1":
                    return mock_yolo_model
                elif model_id == "car_brand_cls":
                    return mock_brand_model
                elif model_id == "car_bodytype_cls":
                    return mock_bodytype_model
                return None

            mock_model_manager.return_value.get_model.side_effect = get_model_side_effect

            result = await executor.execute_pipeline(pipeline_config, context)

            assert result.success is True
            assert len(result.detections) == 1
            assert len(result.branch_results) == 2

            # Check branch results
            assert "car_brand_cls" in result.branch_results
            assert "car_bodytype_cls" in result.branch_results

            brand_result = result.branch_results["car_brand_cls"]
            assert brand_result.success is True
            assert brand_result.metadata.get("brand") == "Toyota"

            bodytype_result = result.branch_results["car_bodytype_cls"]
            assert bodytype_result.success is True
            assert bodytype_result.metadata.get("body_type") == "Sedan"

    @pytest.mark.asyncio
    async def test_execute_pipeline_sequential_mode(self, mock_yolo_model, mock_frame):
        """Test pipeline execution in sequential mode."""
        import torch

        config = {"execution_mode": "sequential"}
        executor = PipelineExecutor(config)

        # Mock detection result
        mock_result = Mock()
        mock_result.boxes = Mock()
        mock_result.boxes.data = torch.tensor([
            [100, 200, 300, 400, 0.9, 0]
        ])
        mock_result.boxes.id = torch.tensor([1001])

        mock_yolo_model.track.return_value = [mock_result]

        pipeline_config = {
            "modelId": "car_detection_v1",
            "modelFile": "car_detection.pt",
            "expectedClasses": ["car"],
            "triggerClasses": ["car"],
            "minConfidence": 0.8,
            "branches": [
                {
                    "modelId": "car_brand_cls",
                    "modelFile": "car_brand.pt",
                    "triggerClasses": ["car"],
                    "minConfidence": 0.7,
                    "parallel": False  # Sequential execution
                }
            ],
            "actions": []
        }

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            mock_model_manager.return_value.get_model.return_value = mock_yolo_model

            result = await executor.execute_pipeline(pipeline_config, context)

            assert result.success is True
            assert executor.execution_mode == ExecutionMode.SEQUENTIAL

    @pytest.mark.asyncio
    async def test_execute_pipeline_with_actions(self, mock_yolo_model, mock_frame):
        """Test pipeline execution with actions."""
        import torch

        # Mock detection result
        mock_result = Mock()
        mock_result.boxes = Mock()
        mock_result.boxes.data = torch.tensor([
            [100, 200, 300, 400, 0.9, 0]
        ])
        mock_result.boxes.id = torch.tensor([1001])

        mock_yolo_model.track.return_value = [mock_result]

        # Pipeline configuration with actions
        pipeline_config = {
            "modelId": "car_detection_v1",
            "modelFile": "car_detection.pt",
            "expectedClasses": ["car"],
            "triggerClasses": ["car"],
            "minConfidence": 0.8,
            "branches": [],
            "actions": [
                {
                    "type": "redis_save_image",
                    "region": "car",
                    "key": "inference:{display_id}:{timestamp}:{session_id}",
                    "expire_seconds": 600
                },
                {
                    "type": "postgresql_insert",
                    "table": "detections",
                    "fields": {
                        "camera_id": "{camera_id}",
                        "detection_class": "{class}",
                        "confidence": "{confidence}"
                    }
                }
            ]
        }

        executor = PipelineExecutor()

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager, \
             patch('detector_worker.pipeline.action_executor.ActionExecutor') as mock_action_executor:

            mock_model_manager.return_value.get_model.return_value = mock_yolo_model
            mock_action_executor.return_value.execute_actions = AsyncMock(return_value=True)

            result = await executor.execute_pipeline(pipeline_config, context)

            assert result.success is True
            # Actions should be executed
            mock_action_executor.return_value.execute_actions.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_pipeline_model_error(self, mock_frame):
        """Test pipeline execution with model error."""
        pipeline_config = {
            "modelId": "car_detection_v1",
            "modelFile": "car_detection.pt",
            "expectedClasses": ["car"],
            "triggerClasses": ["car"],
            "minConfidence": 0.8,
            "branches": [],
            "actions": []
        }

        executor = PipelineExecutor()

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            # Model manager raises error
            mock_model_manager.return_value.get_model.side_effect = ModelError("Model not found")

            result = await executor.execute_pipeline(pipeline_config, context)

            assert result.success is False
            assert "Model not found" in result.error

    @pytest.mark.asyncio
    async def test_execute_pipeline_timeout(self, mock_yolo_model, mock_frame):
        """Test pipeline execution timeout."""
        import torch

        # Configure short timeout
        config = {"timeout": 0.001}  # Very short timeout
        executor = PipelineExecutor(config)

        # Mock slow model inference
        def slow_inference(*args, **kwargs):
            import time
            time.sleep(1)  # Longer than timeout
            mock_result = Mock()
            mock_result.boxes = None
            return [mock_result]

        mock_yolo_model.track.side_effect = slow_inference

        pipeline_config = {
            "modelId": "car_detection_v1",
            "modelFile": "car_detection.pt",
            "expectedClasses": ["car"],
            "triggerClasses": ["car"],
            "minConfidence": 0.8,
            "branches": [],
            "actions": []
        }

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            mock_model_manager.return_value.get_model.return_value = mock_yolo_model

            result = await executor.execute_pipeline(pipeline_config, context)

            assert result.success is False
            assert "timeout" in result.error.lower()

    @pytest.mark.asyncio
    async def test_execute_branch_parallel(self, mock_frame):
        """Test parallel branch execution."""
        import torch

        # Mock classification model
        mock_brand_model = Mock()
        mock_result = Mock()
        mock_result.probs = Mock()
        mock_result.probs.top1 = 1
        mock_result.probs.top1conf = 0.85
        mock_brand_model.predict.return_value = [mock_result]
        mock_brand_model.names = {0: "Honda", 1: "Toyota", 2: "Ford"}

        executor = PipelineExecutor()

        # Branch configuration
        branch_config = {
            "modelId": "car_brand_cls",
            "modelFile": "car_brand.pt",
            "triggerClasses": ["car"],
            "minConfidence": 0.7,
            "parallel": True,
            "crop": True,
            "cropClass": "car"
        }

        # Mock detected regions
        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            mock_model_manager.return_value.get_model.return_value = mock_brand_model

            result = await executor._execute_branch(branch_config, regions, context)

            assert result.success is True
            assert result.branch_id == "car_brand_cls"
            assert result.metadata.get("brand") == "Toyota"
            assert result.execution_time > 0

    @pytest.mark.asyncio
    async def test_execute_branch_no_trigger_class(self, mock_frame):
        """Test branch execution when trigger class not detected."""
        executor = PipelineExecutor()

        branch_config = {
            "modelId": "car_brand_cls",
            "modelFile": "car_brand.pt",
            "triggerClasses": ["car"],
            "minConfidence": 0.7
        }

        # No car detected
        regions = {
            "truck": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("truck", 0.9, BoundingBox(100, 200, 300, 400), 1002)
            }
        }

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        result = await executor._execute_branch(branch_config, regions, context)

        assert result.success is False
        assert "trigger class not detected" in result.error.lower()

    def test_wait_for_branches(self):
        """Test waiting for specific branches to complete."""
        executor = PipelineExecutor()

        # Mock completed branch results
        branch_results = {
            "car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
            "car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12),
            "license_ocr": BranchResult("license_ocr", True, [], {"license": "ABC123"}, 0.2)
        }

        # Wait for specific branches
        wait_for = ["car_brand_cls", "car_bodytype_cls"]
        completed = executor._wait_for_branches(branch_results, wait_for, timeout=1.0)

        assert completed is True

        # Wait for non-existent branch (should timeout)
        wait_for_missing = ["car_brand_cls", "nonexistent_branch"]
        completed = executor._wait_for_branches(branch_results, wait_for_missing, timeout=0.1)

        assert completed is False

    def test_validate_pipeline_config(self):
        """Test pipeline configuration validation."""
        executor = PipelineExecutor()

        # Valid configuration
        valid_config = {
            "modelId": "car_detection_v1",
            "modelFile": "car_detection.pt",
            "expectedClasses": ["car"],
            "triggerClasses": ["car"],
            "minConfidence": 0.8
        }

        assert executor._validate_pipeline_config(valid_config) is True

        # Invalid configuration (missing required fields)
        invalid_config = {
            "modelFile": "car_detection.pt"
            # Missing modelId
        }

        with pytest.raises(PipelineError):
            executor._validate_pipeline_config(invalid_config)

    def test_crop_frame_for_detection(self, mock_frame):
        """Test frame cropping for detection."""
        executor = PipelineExecutor()

        detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)

        cropped = executor._crop_frame_for_detection(mock_frame, detection)

        assert cropped.shape == (200, 200, 3)  # 400-200, 300-100

    def test_crop_frame_invalid_bounds(self, mock_frame):
        """Test frame cropping with invalid bounds."""
        executor = PipelineExecutor()

        # Detection outside frame bounds
        detection = DetectionResult("car", 0.9, BoundingBox(-100, -200, 50, 100), 1001)

        cropped = executor._crop_frame_for_detection(mock_frame, detection)

        # Should handle bounds gracefully
        assert cropped.shape[0] > 0
        assert cropped.shape[1] > 0


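# --- Hedged sketch (not part of the original test suite) ---
# test_crop_frame_invalid_bounds expects _crop_frame_for_detection to clamp
# out-of-range boxes rather than raise or return an empty array. One way to do
# that, assuming a BoundingBox with x1/y1/x2/y2 attributes, is shown below;
# the exact clamping policy of the real executor is not documented here.
def _sketch_crop_frame(frame, bbox):
    """Illustrative only: crop `frame` to `bbox`, clamped to the frame size."""
    height, width = frame.shape[:2]
    x1 = max(0, min(int(bbox.x1), width - 1))
    y1 = max(0, min(int(bbox.y1), height - 1))
    x2 = max(x1 + 1, min(int(bbox.x2), width))
    y2 = max(y1 + 1, min(int(bbox.y2), height))
    return frame[y1:y2, x1:x2]

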
class TestPipelineExecutorPerformance:
    """Test pipeline executor performance and optimization."""

    @pytest.mark.asyncio
    async def test_parallel_branch_execution_performance(self, mock_frame):
        """Test that parallel execution is faster than sequential."""
        import time
        import torch

        def slow_inference(*args, **kwargs):
            time.sleep(0.1)  # Simulate slow inference
            mock_result = Mock()
            mock_result.probs = Mock()
            mock_result.probs.top1 = 1
            mock_result.probs.top1conf = 0.85
            return [mock_result]

        mock_model = Mock()
        mock_model.predict.side_effect = slow_inference
        mock_model.names = {0: "Class0", 1: "Class1"}

        # Test parallel execution
        parallel_executor = PipelineExecutor({"execution_mode": "parallel", "max_workers": 2})

        branch_configs = [
            {
                "modelId": f"branch_{i}",
                "modelFile": f"branch_{i}.pt",
                "triggerClasses": ["car"],
                "minConfidence": 0.7,
                "parallel": True
            }
            for i in range(3)  # 3 branches
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            mock_model_manager.return_value.get_model.return_value = mock_model

            start_time = time.time()

            # Execute branches in parallel
            tasks = [
                parallel_executor._execute_branch(config, regions, context)
                for config in branch_configs
            ]

            results = await asyncio.gather(*tasks)

            parallel_time = time.time() - start_time

            # Parallel execution should be faster than 3 * 0.1 seconds
            assert parallel_time < 0.25  # Allow some overhead
            assert len(results) == 3
            assert all(result.success for result in results)

    def test_thread_pool_management(self):
        """Test thread pool creation and management."""
        # Test different worker counts
        for workers in [1, 2, 4, 8]:
            executor = PipelineExecutor({"max_workers": workers})
            assert executor.max_workers == workers
            assert executor.thread_pool._max_workers == workers

    def test_memory_management_large_frames(self):
        """Test memory management with large frames."""
        executor = PipelineExecutor()

        # Create large frame
        large_frame = np.ones((1080, 1920, 3), dtype=np.uint8) * 128

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=large_frame,
            crop_region=(500, 400, 1000, 800)
        )

        # Get cropped frame
        cropped = context.get_cropped_frame()

        # Should reduce memory usage
        assert cropped.shape == (400, 500, 3)  # Much smaller than original
        assert cropped.nbytes < large_frame.nbytes


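# --- Hedged sketch (not part of the original test suite) ---
# The error-isolation test below relies on each branch running as an
# independent coroutine so that one failing model cannot poison the others.
# A common way to get that behaviour is to guard each branch with its own
# try/except and gather the results; `run_branches` and `run_one` are
# illustrative names, not the executor's real API, and the BranchResult
# call assumes its detections/metadata fields default to empty values.
async def run_branches(branch_configs, run_one):
    """Illustrative only: run branch coroutines concurrently, isolating errors."""
    async def guarded(config):
        try:
            return await run_one(config)
        except Exception as exc:  # capture any branch failure as a failed result
            return BranchResult(config["modelId"], False, error=str(exc))

    return await asyncio.gather(*(guarded(cfg) for cfg in branch_configs))

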
class TestPipelineExecutorErrorHandling:
    """Test comprehensive error handling."""

    @pytest.mark.asyncio
    async def test_branch_execution_error_isolation(self, mock_frame):
        """Test that errors in one branch don't affect others."""
        executor = PipelineExecutor()

        # Mock models - one fails, one succeeds
        failing_model = Mock()
        failing_model.predict.side_effect = Exception("Model crashed")

        success_model = Mock()
        mock_result = Mock()
        mock_result.probs = Mock()
        mock_result.probs.top1 = 1
        mock_result.probs.top1conf = 0.85
        success_model.predict.return_value = [mock_result]
        success_model.names = {0: "Class0", 1: "Class1"}

        branch_configs = [
            {
                "modelId": "failing_branch",
                "modelFile": "failing.pt",
                "triggerClasses": ["car"],
                "minConfidence": 0.7,
                "parallel": True
            },
            {
                "modelId": "success_branch",
                "modelFile": "success.pt",
                "triggerClasses": ["car"],
                "minConfidence": 0.7,
                "parallel": True
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        def get_model_side_effect(model_id, camera_id):
            if model_id == "failing_branch":
                return failing_model
            elif model_id == "success_branch":
                return success_model
            return None

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            mock_model_manager.return_value.get_model.side_effect = get_model_side_effect

            # Execute branches
            tasks = [
                executor._execute_branch(config, regions, context)
                for config in branch_configs
            ]

            results = await asyncio.gather(*tasks, return_exceptions=True)

            # One should fail, one should succeed
            failing_result = next(r for r in results if isinstance(r, BranchResult) and r.branch_id == "failing_branch")
            success_result = next(r for r in results if isinstance(r, BranchResult) and r.branch_id == "success_branch")

            assert failing_result.success is False
            assert "Model crashed" in failing_result.error

            assert success_result.success is True
            assert success_result.error is None