Refactor: PHASE 8: Testing & Integration

This commit is contained in:
ziesorx 2025-09-12 18:55:23 +07:00
parent af34f4fd08
commit 9e8c6804a7
32 changed files with 17128 additions and 0 deletions

View file

@ -0,0 +1,959 @@
"""
Unit tests for action execution functionality.
"""
import pytest
import asyncio
import json
import base64
import numpy as np
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from datetime import datetime, timedelta
from detector_worker.pipeline.action_executor import (
ActionExecutor,
ActionResult,
ActionType,
RedisAction,
PostgreSQLAction,
FileAction
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import ActionError, RedisError, DatabaseError
class TestActionResult:
"""Test action execution result."""
def test_creation_success(self):
"""Test successful action result creation."""
result = ActionResult(
action_type=ActionType.REDIS_SAVE,
success=True,
execution_time=0.05,
metadata={"key": "saved_image_key", "expiry": 600}
)
assert result.action_type == ActionType.REDIS_SAVE
assert result.success is True
assert result.execution_time == 0.05
assert result.metadata["key"] == "saved_image_key"
assert result.error is None
def test_creation_failure(self):
"""Test failed action result creation."""
result = ActionResult(
action_type=ActionType.POSTGRESQL_INSERT,
success=False,
error="Database connection failed",
execution_time=0.02
)
assert result.action_type == ActionType.POSTGRESQL_INSERT
assert result.success is False
assert result.error == "Database connection failed"
assert result.metadata == {}
class TestRedisAction:
"""Test Redis action implementations."""
def test_creation(self):
"""Test Redis action creation."""
action_config = {
"type": "redis_save_image",
"region": "car",
"key": "inference:{display_id}:{timestamp}:{session_id}",
"expire_seconds": 600
}
action = RedisAction(action_config)
assert action.action_type == ActionType.REDIS_SAVE
assert action.region == "car"
assert action.key_template == "inference:{display_id}:{timestamp}:{session_id}"
assert action.expire_seconds == 600
def test_resolve_key_template(self):
"""Test key template resolution."""
action_config = {
"type": "redis_save_image",
"region": "car",
"key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
"expire_seconds": 600
}
action = RedisAction(action_config)
context = {
"display_id": "display_001",
"timestamp": "1640995200000",
"session_id": "session_123",
"filename": "detection.jpg"
}
resolved_key = action.resolve_key(context)
expected_key = "inference:display_001:1640995200000:session_123:detection.jpg"
assert resolved_key == expected_key
def test_resolve_key_missing_variable(self):
"""Test key resolution with missing variable."""
action_config = {
"type": "redis_save_image",
"region": "car",
"key": "inference:{display_id}:{missing_var}",
"expire_seconds": 600
}
action = RedisAction(action_config)
context = {"display_id": "display_001"}
with pytest.raises(ActionError):
action.resolve_key(context)
class TestPostgreSQLAction:
"""Test PostgreSQL action implementations."""
def test_creation_insert(self):
"""Test PostgreSQL insert action creation."""
action_config = {
"type": "postgresql_insert",
"table": "detections",
"fields": {
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"bbox_x1": "{bbox.x1}",
"created_at": "NOW()"
}
}
action = PostgreSQLAction(action_config)
assert action.action_type == ActionType.POSTGRESQL_INSERT
assert action.table == "detections"
assert len(action.fields) == 6
assert action.key_field is None
def test_creation_update(self):
"""Test PostgreSQL update action creation."""
action_config = {
"type": "postgresql_update_combined",
"table": "car_info",
"key_field": "session_id",
"fields": {
"car_brand": "{car_brand_cls.brand}",
"car_body_type": "{car_bodytype_cls.body_type}",
"updated_at": "NOW()"
},
"waitForBranches": ["car_brand_cls", "car_bodytype_cls"]
}
action = PostgreSQLAction(action_config)
assert action.action_type == ActionType.POSTGRESQL_UPDATE
assert action.table == "car_info"
assert action.key_field == "session_id"
assert action.wait_for_branches == ["car_brand_cls", "car_bodytype_cls"]
def test_resolve_field_values(self):
"""Test field value resolution."""
action_config = {
"type": "postgresql_insert",
"table": "detections",
"fields": {
"camera_id": "{camera_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"brand": "{car_brand_cls.brand}"
}
}
action = PostgreSQLAction(action_config)
context = {
"camera_id": "camera_001",
"class": "car",
"confidence": 0.85
}
branch_results = {
"car_brand_cls": {"brand": "Toyota", "confidence": 0.78}
}
resolved_fields = action.resolve_field_values(context, branch_results)
assert resolved_fields["camera_id"] == "camera_001"
assert resolved_fields["detection_class"] == "car"
assert resolved_fields["confidence"] == 0.85
assert resolved_fields["brand"] == "Toyota"
class TestFileAction:
"""Test file action implementations."""
def test_creation(self):
"""Test file action creation."""
action_config = {
"type": "save_image",
"path": "/tmp/detections/{camera_id}_{timestamp}.jpg",
"region": "car",
"format": "jpeg",
"quality": 85
}
action = FileAction(action_config)
assert action.action_type == ActionType.SAVE_IMAGE
assert action.path_template == "/tmp/detections/{camera_id}_{timestamp}.jpg"
assert action.region == "car"
assert action.format == "jpeg"
assert action.quality == 85
def test_resolve_path_template(self):
"""Test path template resolution."""
action_config = {
"type": "save_image",
"path": "/tmp/detections/{camera_id}/{date}/{timestamp}.jpg"
}
action = FileAction(action_config)
context = {
"camera_id": "camera_001",
"timestamp": "1640995200000",
"date": "2022-01-01"
}
resolved_path = action.resolve_path(context)
expected_path = "/tmp/detections/camera_001/2022-01-01/1640995200000.jpg"
assert resolved_path == expected_path
class TestActionExecutor:
"""Test action execution functionality."""
def test_initialization(self):
"""Test action executor initialization."""
executor = ActionExecutor()
assert executor.redis_client is None
assert executor.db_manager is None
assert executor.max_concurrent_actions == 10
assert executor.action_timeout == 30.0
def test_initialization_with_clients(self, mock_redis_client, mock_database_connection):
"""Test initialization with client instances."""
executor = ActionExecutor(
redis_client=mock_redis_client,
db_manager=mock_database_connection
)
assert executor.redis_client is mock_redis_client
assert executor.db_manager is mock_database_connection
@pytest.mark.asyncio
async def test_execute_actions_empty_list(self):
"""Test executing empty action list."""
executor = ActionExecutor()
context = {
"camera_id": "camera_001",
"session_id": "session_123"
}
results = await executor.execute_actions([], {}, context)
assert results == []
@pytest.mark.asyncio
async def test_execute_redis_save_action(self, mock_redis_client, mock_frame):
"""Test executing Redis save image action."""
executor = ActionExecutor(redis_client=mock_redis_client)
actions = [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{camera_id}:{session_id}",
"expire_seconds": 600
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"session_id": "session_123",
"frame_data": mock_frame
}
# Mock successful Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.REDIS_SAVE
# Verify Redis calls
mock_redis_client.set.assert_called_once()
mock_redis_client.expire.assert_called_once()
@pytest.mark.asyncio
async def test_execute_postgresql_insert_action(self, mock_database_connection):
"""Test executing PostgreSQL insert action."""
# Mock database manager
mock_db_manager = Mock()
mock_db_manager.execute_query = AsyncMock(return_value=True)
executor = ActionExecutor(db_manager=mock_db_manager)
actions = [
{
"type": "postgresql_insert",
"table": "detections",
"fields": {
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"created_at": "NOW()"
}
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"session_id": "session_123",
"class": "car",
"confidence": 0.9
}
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.POSTGRESQL_INSERT
# Verify database call
mock_db_manager.execute_query.assert_called_once()
call_args = mock_db_manager.execute_query.call_args[0]
assert "INSERT INTO detections" in call_args[0]
@pytest.mark.asyncio
async def test_execute_postgresql_update_action(self, mock_database_connection):
"""Test executing PostgreSQL update action."""
mock_db_manager = Mock()
mock_db_manager.execute_query = AsyncMock(return_value=True)
executor = ActionExecutor(db_manager=mock_db_manager)
actions = [
{
"type": "postgresql_update_combined",
"table": "car_info",
"key_field": "session_id",
"fields": {
"car_brand": "{car_brand_cls.brand}",
"car_body_type": "{car_bodytype_cls.body_type}",
"updated_at": "NOW()"
},
"waitForBranches": ["car_brand_cls", "car_bodytype_cls"]
}
]
regions = {}
context = {
"session_id": "session_123"
}
branch_results = {
"car_brand_cls": {"brand": "Toyota"},
"car_bodytype_cls": {"body_type": "Sedan"}
}
results = await executor.execute_actions(actions, regions, context, branch_results)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.POSTGRESQL_UPDATE
# Verify database call
mock_db_manager.execute_query.assert_called_once()
call_args = mock_db_manager.execute_query.call_args[0]
assert "UPDATE car_info SET" in call_args[0]
assert "WHERE session_id" in call_args[0]
@pytest.mark.asyncio
async def test_execute_file_save_action(self, mock_frame):
"""Test executing file save action."""
executor = ActionExecutor()
actions = [
{
"type": "save_image",
"path": "/tmp/test_{camera_id}_{timestamp}.jpg",
"region": "car",
"format": "jpeg",
"quality": 85
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"timestamp": "1640995200000",
"frame_data": mock_frame
}
with patch('cv2.imwrite') as mock_imwrite:
mock_imwrite.return_value = True
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.SAVE_IMAGE
# Verify file save call
mock_imwrite.assert_called_once()
call_args = mock_imwrite.call_args
assert "/tmp/test_camera_001_1640995200000.jpg" in call_args[0][0]
@pytest.mark.asyncio
async def test_execute_actions_parallel(self, mock_redis_client):
"""Test parallel execution of multiple actions."""
executor = ActionExecutor(redis_client=mock_redis_client)
# Multiple Redis actions
actions = [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:car:{session_id}",
"expire_seconds": 600
},
{
"type": "redis_publish",
"channel": "detections",
"message": "{camera_id}:car_detected"
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"session_id": "session_123",
"frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
}
# Mock Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
mock_redis_client.publish.return_value = 1
import time
start_time = time.time()
results = await executor.execute_actions(actions, regions, context)
execution_time = time.time() - start_time
assert len(results) == 2
assert all(result.success for result in results)
# Should execute in parallel (faster than sequential)
assert execution_time < 0.1 # Allow some overhead
@pytest.mark.asyncio
async def test_execute_actions_error_handling(self, mock_redis_client):
"""Test error handling in action execution."""
executor = ActionExecutor(redis_client=mock_redis_client)
actions = [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{session_id}",
"expire_seconds": 600
},
{
"type": "redis_save_image", # This one will fail
"region": "truck", # Region not detected
"key": "inference:truck:{session_id}",
"expire_seconds": 600
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
# No truck region
}
context = {
"session_id": "session_123",
"frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
}
# Mock Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 2
assert results[0].success is True # Car action succeeds
assert results[1].success is False # Truck action fails
assert "Region 'truck' not found" in results[1].error
@pytest.mark.asyncio
async def test_execute_actions_timeout(self, mock_redis_client):
"""Test action execution timeout."""
config = {"action_timeout": 0.001} # Very short timeout
executor = ActionExecutor(redis_client=mock_redis_client, config=config)
def slow_redis_operation(*args, **kwargs):
import time
time.sleep(1) # Longer than timeout
return True
mock_redis_client.set.side_effect = slow_redis_operation
actions = [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{session_id}",
"expire_seconds": 600
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"session_id": "session_123",
"frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
}
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is False
assert "timeout" in results[0].error.lower()
@pytest.mark.asyncio
async def test_execute_redis_publish_action(self, mock_redis_client):
"""Test executing Redis publish action."""
executor = ActionExecutor(redis_client=mock_redis_client)
actions = [
{
"type": "redis_publish",
"channel": "detections:{camera_id}",
"message": {
"camera_id": "{camera_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"timestamp": "{timestamp}"
}
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"class": "car",
"confidence": 0.9,
"timestamp": "1640995200000"
}
mock_redis_client.publish.return_value = 1
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.REDIS_PUBLISH
# Verify publish call
mock_redis_client.publish.assert_called_once()
call_args = mock_redis_client.publish.call_args
assert call_args[0][0] == "detections:camera_001" # Channel
# Message should be JSON
message = call_args[0][1]
parsed_message = json.loads(message)
assert parsed_message["camera_id"] == "camera_001"
assert parsed_message["detection_class"] == "car"
@pytest.mark.asyncio
async def test_execute_conditional_action(self):
"""Test executing conditional actions."""
executor = ActionExecutor()
actions = [
{
"type": "conditional",
"condition": "{confidence} > 0.8",
"actions": [
{
"type": "log",
"message": "High confidence detection: {class} ({confidence})"
}
]
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.95, # High confidence
"detection": DetectionResult("car", 0.95, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"class": "car",
"confidence": 0.95
}
with patch('logging.info') as mock_log:
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
# Should have logged the message
mock_log.assert_called_once()
log_message = mock_log.call_args[0][0]
assert "High confidence detection: car (0.95)" in log_message
def test_crop_region_from_frame(self, mock_frame):
"""Test cropping region from frame."""
executor = ActionExecutor()
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
cropped = executor._crop_region_from_frame(mock_frame, detection.bbox)
assert cropped.shape == (200, 200, 3) # 400-200, 300-100
def test_encode_image_base64(self, mock_frame):
"""Test encoding image to base64."""
executor = ActionExecutor()
# Crop a small region
cropped_frame = mock_frame[200:400, 100:300] # 200x200 region
with patch('cv2.imencode') as mock_imencode:
# Mock successful encoding
mock_imencode.return_value = (True, np.array([1, 2, 3, 4], dtype=np.uint8))
encoded = executor._encode_image_base64(cropped_frame, format="jpeg")
# Should return base64 string
assert isinstance(encoded, str)
assert len(encoded) > 0
# Verify encoding call
mock_imencode.assert_called_once()
assert mock_imencode.call_args[0][0] == '.jpg'
def test_build_insert_query(self):
"""Test building INSERT SQL query."""
executor = ActionExecutor()
table = "detections"
fields = {
"camera_id": "camera_001",
"detection_class": "car",
"confidence": 0.9,
"created_at": "NOW()"
}
query, values = executor._build_insert_query(table, fields)
assert "INSERT INTO detections" in query
assert "camera_id, detection_class, confidence, created_at" in query
assert "VALUES (%s, %s, %s, NOW())" in query
assert values == ["camera_001", "car", 0.9]
def test_build_update_query(self):
"""Test building UPDATE SQL query."""
executor = ActionExecutor()
table = "car_info"
fields = {
"car_brand": "Toyota",
"car_body_type": "Sedan",
"updated_at": "NOW()"
}
key_field = "session_id"
key_value = "session_123"
query, values = executor._build_update_query(table, fields, key_field, key_value)
assert "UPDATE car_info SET" in query
assert "car_brand = %s" in query
assert "car_body_type = %s" in query
assert "updated_at = NOW()" in query
assert "WHERE session_id = %s" in query
assert values == ["Toyota", "Sedan", "session_123"]
def test_evaluate_condition(self):
"""Test evaluating conditional expressions."""
executor = ActionExecutor()
context = {
"confidence": 0.85,
"class": "car",
"area": 40000
}
# Simple comparisons
assert executor._evaluate_condition("{confidence} > 0.8", context) is True
assert executor._evaluate_condition("{confidence} < 0.8", context) is False
assert executor._evaluate_condition("{confidence} >= 0.85", context) is True
assert executor._evaluate_condition("{confidence} == 0.85", context) is True
# String comparisons
assert executor._evaluate_condition("{class} == 'car'", context) is True
assert executor._evaluate_condition("{class} != 'truck'", context) is True
# Complex conditions
assert executor._evaluate_condition("{confidence} > 0.8 and {area} > 30000", context) is True
assert executor._evaluate_condition("{confidence} > 0.9 or {area} > 30000", context) is True
assert executor._evaluate_condition("{confidence} > 0.9 and {area} < 30000", context) is False
def test_validate_action_config(self):
"""Test action configuration validation."""
executor = ActionExecutor()
# Valid Redis action
valid_redis = {
"type": "redis_save_image",
"region": "car",
"key": "inference:{session_id}",
"expire_seconds": 600
}
assert executor._validate_action_config(valid_redis) is True
# Invalid action (missing required fields)
invalid_action = {
"type": "redis_save_image"
# Missing region and key
}
with pytest.raises(ActionError):
executor._validate_action_config(invalid_action)
# Unknown action type
unknown_action = {
"type": "unknown_action_type",
"some_field": "value"
}
with pytest.raises(ActionError):
executor._validate_action_config(unknown_action)
class TestActionExecutorIntegration:
"""Integration tests for action execution."""
@pytest.mark.asyncio
async def test_complete_detection_workflow(self, mock_redis_client, mock_frame):
"""Test complete detection workflow with multiple actions."""
# Mock database manager
mock_db_manager = Mock()
mock_db_manager.execute_query = AsyncMock(return_value=True)
executor = ActionExecutor(
redis_client=mock_redis_client,
db_manager=mock_db_manager
)
# Complete action workflow
actions = [
# Save cropped image to Redis
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{camera_id}:{timestamp}:{session_id}:car",
"expire_seconds": 600
},
# Insert initial detection record
{
"type": "postgresql_insert",
"table": "car_detections",
"fields": {
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"bbox_x1": "{bbox.x1}",
"bbox_y1": "{bbox.y1}",
"bbox_x2": "{bbox.x2}",
"bbox_y2": "{bbox.y2}",
"created_at": "NOW()"
}
},
# Publish detection event
{
"type": "redis_publish",
"channel": "detections:{camera_id}",
"message": {
"event": "car_detected",
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"timestamp": "{timestamp}"
}
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.92,
"detection": DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"session_id": "session_123",
"timestamp": "1640995200000",
"class": "car",
"confidence": 0.92,
"bbox": {"x1": 100, "y1": 200, "x2": 300, "y2": 400},
"frame_data": mock_frame
}
# Mock all Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
mock_redis_client.publish.return_value = 1
results = await executor.execute_actions(actions, regions, context)
# All actions should succeed
assert len(results) == 3
assert all(result.success for result in results)
# Verify all operations were called
mock_redis_client.set.assert_called_once() # Image save
mock_redis_client.expire.assert_called_once() # Set expiry
mock_redis_client.publish.assert_called_once() # Publish event
mock_db_manager.execute_query.assert_called_once() # Database insert
@pytest.mark.asyncio
async def test_branch_dependent_actions(self, mock_database_connection):
"""Test actions that depend on branch results."""
mock_db_manager = Mock()
mock_db_manager.execute_query = AsyncMock(return_value=True)
executor = ActionExecutor(db_manager=mock_db_manager)
# Action that waits for classification branches
actions = [
{
"type": "postgresql_update_combined",
"table": "car_info",
"key_field": "session_id",
"fields": {
"car_brand": "{car_brand_cls.brand}",
"car_body_type": "{car_bodytype_cls.body_type}",
"car_color": "{car_color_cls.color}",
"confidence_brand": "{car_brand_cls.confidence}",
"confidence_bodytype": "{car_bodytype_cls.confidence}",
"updated_at": "NOW()"
},
"waitForBranches": ["car_brand_cls", "car_bodytype_cls", "car_color_cls"]
}
]
regions = {}
context = {
"session_id": "session_123"
}
# Simulated branch results
branch_results = {
"car_brand_cls": {"brand": "Toyota", "confidence": 0.87},
"car_bodytype_cls": {"body_type": "Sedan", "confidence": 0.82},
"car_color_cls": {"color": "Red", "confidence": 0.79}
}
results = await executor.execute_actions(actions, regions, context, branch_results)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.POSTGRESQL_UPDATE
# Verify database call with all branch data
mock_db_manager.execute_query.assert_called_once()
call_args = mock_db_manager.execute_query.call_args
query = call_args[0][0]
values = call_args[0][1]
assert "UPDATE car_info SET" in query
assert "car_brand = %s" in query
assert "car_body_type = %s" in query
assert "car_color = %s" in query
assert "WHERE session_id = %s" in query
assert "Toyota" in values
assert "Sedan" in values
assert "Red" in values
assert "session_123" in values

View file

@ -0,0 +1,786 @@
"""
Unit tests for field mapping and template resolution.
"""
import pytest
from unittest.mock import Mock, patch
from datetime import datetime
import json
from detector_worker.pipeline.field_mapper import (
FieldMapper,
MappingContext,
TemplateResolver,
FieldMappingError,
NestedFieldAccessor
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
class TestNestedFieldAccessor:
"""Test nested field access functionality."""
def test_get_nested_value_simple(self):
"""Test getting simple nested values."""
data = {
"user": {
"name": "John",
"age": 30,
"address": {
"city": "New York",
"zip": "10001"
}
}
}
accessor = NestedFieldAccessor()
assert accessor.get_nested_value(data, "user.name") == "John"
assert accessor.get_nested_value(data, "user.age") == 30
assert accessor.get_nested_value(data, "user.address.city") == "New York"
assert accessor.get_nested_value(data, "user.address.zip") == "10001"
def test_get_nested_value_array_access(self):
"""Test accessing array elements."""
data = {
"results": [
{"score": 0.9, "label": "car"},
{"score": 0.8, "label": "truck"}
],
"bbox": [100, 200, 300, 400]
}
accessor = NestedFieldAccessor()
assert accessor.get_nested_value(data, "results[0].score") == 0.9
assert accessor.get_nested_value(data, "results[0].label") == "car"
assert accessor.get_nested_value(data, "results[1].score") == 0.8
assert accessor.get_nested_value(data, "bbox[0]") == 100
assert accessor.get_nested_value(data, "bbox[3]") == 400
def test_get_nested_value_nonexistent_path(self):
"""Test accessing non-existent paths."""
data = {"user": {"name": "John"}}
accessor = NestedFieldAccessor()
assert accessor.get_nested_value(data, "user.nonexistent") is None
assert accessor.get_nested_value(data, "nonexistent.field") is None
assert accessor.get_nested_value(data, "user.address.city") is None
def test_get_nested_value_with_default(self):
"""Test getting nested values with default fallback."""
data = {"user": {"name": "John"}}
accessor = NestedFieldAccessor()
assert accessor.get_nested_value(data, "user.age", default=25) == 25
assert accessor.get_nested_value(data, "user.name", default="Unknown") == "John"
def test_set_nested_value(self):
"""Test setting nested values."""
data = {}
accessor = NestedFieldAccessor()
accessor.set_nested_value(data, "user.name", "John")
assert data["user"]["name"] == "John"
accessor.set_nested_value(data, "user.address.city", "New York")
assert data["user"]["address"]["city"] == "New York"
accessor.set_nested_value(data, "scores[0]", 0.95)
assert data["scores"][0] == 0.95
def test_set_nested_value_overwrite(self):
"""Test overwriting existing nested values."""
data = {"user": {"name": "John", "age": 30}}
accessor = NestedFieldAccessor()
accessor.set_nested_value(data, "user.name", "Jane")
assert data["user"]["name"] == "Jane"
assert data["user"]["age"] == 30 # Should not affect other fields
class TestTemplateResolver:
"""Test template string resolution."""
def test_resolve_simple_template(self):
"""Test resolving simple template variables."""
resolver = TemplateResolver()
template = "Hello {name}, you are {age} years old"
context = {"name": "John", "age": 30}
result = resolver.resolve(template, context)
assert result == "Hello John, you are 30 years old"
def test_resolve_nested_template(self):
"""Test resolving nested field templates."""
resolver = TemplateResolver()
template = "User: {user.name} from {user.address.city}"
context = {
"user": {
"name": "John",
"address": {"city": "New York", "zip": "10001"}
}
}
result = resolver.resolve(template, context)
assert result == "User: John from New York"
def test_resolve_array_template(self):
"""Test resolving array element templates."""
resolver = TemplateResolver()
template = "First result: {results[0].label} ({results[0].score})"
context = {
"results": [
{"label": "car", "score": 0.95},
{"label": "truck", "score": 0.87}
]
}
result = resolver.resolve(template, context)
assert result == "First result: car (0.95)"
def test_resolve_missing_variables(self):
"""Test resolving templates with missing variables."""
resolver = TemplateResolver()
template = "Hello {name}, you are {age} years old"
context = {"name": "John"} # Missing age
with pytest.raises(FieldMappingError) as exc_info:
resolver.resolve(template, context)
assert "Variable 'age' not found" in str(exc_info.value)
def test_resolve_with_defaults(self):
"""Test resolving templates with default values."""
resolver = TemplateResolver(allow_missing=True)
template = "Hello {name}, you are {age|25} years old"
context = {"name": "John"} # Missing age, should use default
result = resolver.resolve(template, context)
assert result == "Hello John, you are 25 years old"
def test_resolve_complex_template(self):
"""Test resolving complex templates with multiple variable types."""
resolver = TemplateResolver()
template = "{camera_id}:{timestamp}:{session_id}:{results[0].class}_{bbox[0]}_{bbox[1]}"
context = {
"camera_id": "cam001",
"timestamp": 1640995200000,
"session_id": "sess123",
"results": [{"class": "car", "confidence": 0.95}],
"bbox": [100, 200, 300, 400]
}
result = resolver.resolve(template, context)
assert result == "cam001:1640995200000:sess123:car_100_200"
def test_resolve_conditional_template(self):
"""Test resolving conditional templates."""
resolver = TemplateResolver()
# Simple conditional
template = "{name} is {age > 18 ? 'adult' : 'minor'}"
context_adult = {"name": "John", "age": 25}
result_adult = resolver.resolve(template, context_adult)
assert result_adult == "John is adult"
context_minor = {"name": "Jane", "age": 16}
result_minor = resolver.resolve(template, context_minor)
assert result_minor == "Jane is minor"
def test_escape_braces(self):
"""Test escaping braces in templates."""
resolver = TemplateResolver()
template = "Literal {{braces}} and variable {name}"
context = {"name": "John"}
result = resolver.resolve(template, context)
assert result == "Literal {braces} and variable John"
class TestMappingContext:
"""Test mapping context data structure."""
def test_creation(self):
"""Test mapping context creation."""
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection,
timestamp=1640995200000
)
assert context.camera_id == "camera_001"
assert context.display_id == "display_001"
assert context.session_id == "session_123"
assert context.detection == detection
assert context.timestamp == 1640995200000
assert context.branch_results == {}
assert context.metadata == {}
def test_add_branch_result(self):
"""Test adding branch results to context."""
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
context.add_branch_result("car_brand_cls", {"brand": "Toyota", "confidence": 0.87})
context.add_branch_result("car_bodytype_cls", {"body_type": "Sedan", "confidence": 0.82})
assert len(context.branch_results) == 2
assert context.branch_results["car_brand_cls"]["brand"] == "Toyota"
assert context.branch_results["car_bodytype_cls"]["body_type"] == "Sedan"
def test_to_dict(self):
"""Test converting context to dictionary."""
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection,
timestamp=1640995200000
)
context.add_branch_result("car_brand_cls", {"brand": "Toyota"})
context.add_metadata("model_id", "yolo_v8")
context_dict = context.to_dict()
assert context_dict["camera_id"] == "camera_001"
assert context_dict["display_id"] == "display_001"
assert context_dict["session_id"] == "session_123"
assert context_dict["timestamp"] == 1640995200000
assert context_dict["class"] == "car"
assert context_dict["confidence"] == 0.9
assert context_dict["track_id"] == 1001
assert context_dict["bbox"]["x1"] == 100
assert context_dict["car_brand_cls"]["brand"] == "Toyota"
assert context_dict["model_id"] == "yolo_v8"
def test_add_metadata(self):
"""Test adding metadata to context."""
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
context.add_metadata("model_version", "v2.1")
context.add_metadata("inference_time", 0.15)
assert context.metadata["model_version"] == "v2.1"
assert context.metadata["inference_time"] == 0.15
class TestFieldMapper:
"""Test field mapping functionality."""
def test_initialization(self):
"""Test field mapper initialization."""
mapper = FieldMapper()
assert isinstance(mapper.template_resolver, TemplateResolver)
assert isinstance(mapper.field_accessor, NestedFieldAccessor)
def test_map_fields_simple(self):
"""Test simple field mapping."""
mapper = FieldMapper()
field_mappings = {
"camera_id": "{camera_id}",
"detection_class": "{class}",
"confidence_score": "{confidence}",
"track_identifier": "{track_id}"
}
detection = DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection,
timestamp=1640995200000
)
mapped_fields = mapper.map_fields(field_mappings, context)
assert mapped_fields["camera_id"] == "camera_001"
assert mapped_fields["detection_class"] == "car"
assert mapped_fields["confidence_score"] == 0.92
assert mapped_fields["track_identifier"] == 1001
def test_map_fields_with_branch_results(self):
"""Test field mapping with branch results."""
mapper = FieldMapper()
field_mappings = {
"car_brand": "{car_brand_cls.brand}",
"car_model": "{car_brand_cls.model}",
"body_type": "{car_bodytype_cls.body_type}",
"brand_confidence": "{car_brand_cls.confidence}",
"combined_info": "{car_brand_cls.brand} {car_bodytype_cls.body_type}"
}
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
context.add_branch_result("car_brand_cls", {
"brand": "Toyota",
"model": "Camry",
"confidence": 0.87
})
context.add_branch_result("car_bodytype_cls", {
"body_type": "Sedan",
"confidence": 0.82
})
mapped_fields = mapper.map_fields(field_mappings, context)
assert mapped_fields["car_brand"] == "Toyota"
assert mapped_fields["car_model"] == "Camry"
assert mapped_fields["body_type"] == "Sedan"
assert mapped_fields["brand_confidence"] == 0.87
assert mapped_fields["combined_info"] == "Toyota Sedan"
def test_map_fields_bbox_access(self):
"""Test field mapping with bounding box access."""
mapper = FieldMapper()
field_mappings = {
"bbox_x1": "{bbox.x1}",
"bbox_y1": "{bbox.y1}",
"bbox_x2": "{bbox.x2}",
"bbox_y2": "{bbox.y2}",
"bbox_width": "{bbox.width}",
"bbox_height": "{bbox.height}",
"bbox_area": "{bbox.area}",
"bbox_center_x": "{bbox.center_x}",
"bbox_center_y": "{bbox.center_y}"
}
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection
)
mapped_fields = mapper.map_fields(field_mappings, context)
assert mapped_fields["bbox_x1"] == 100
assert mapped_fields["bbox_y1"] == 200
assert mapped_fields["bbox_x2"] == 300
assert mapped_fields["bbox_y2"] == 400
assert mapped_fields["bbox_width"] == 200 # 300 - 100
assert mapped_fields["bbox_height"] == 200 # 400 - 200
assert mapped_fields["bbox_area"] == 40000 # 200 * 200
assert mapped_fields["bbox_center_x"] == 200 # (100 + 300) / 2
assert mapped_fields["bbox_center_y"] == 300 # (200 + 400) / 2
def test_map_fields_with_sql_functions(self):
"""Test field mapping with SQL function templates."""
mapper = FieldMapper()
field_mappings = {
"created_at": "NOW()",
"updated_at": "CURRENT_TIMESTAMP",
"uuid_field": "UUID()",
"json_data": "JSON_OBJECT('class', '{class}', 'confidence', {confidence})"
}
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection
)
mapped_fields = mapper.map_fields(field_mappings, context)
# SQL functions should pass through unchanged
assert mapped_fields["created_at"] == "NOW()"
assert mapped_fields["updated_at"] == "CURRENT_TIMESTAMP"
assert mapped_fields["uuid_field"] == "UUID()"
assert mapped_fields["json_data"] == "JSON_OBJECT('class', 'car', 'confidence', 0.9)"
def test_map_fields_missing_branch_data(self):
"""Test field mapping with missing branch data."""
mapper = FieldMapper()
field_mappings = {
"car_brand": "{car_brand_cls.brand}",
"car_model": "{nonexistent_branch.model}"
}
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
# Only add one branch result
context.add_branch_result("car_brand_cls", {"brand": "Toyota"})
with pytest.raises(FieldMappingError) as exc_info:
mapper.map_fields(field_mappings, context)
assert "nonexistent_branch.model" in str(exc_info.value)
def test_map_fields_with_defaults(self):
"""Test field mapping with default values."""
mapper = FieldMapper(allow_missing=True)
field_mappings = {
"car_brand": "{car_brand_cls.brand|Unknown}",
"car_model": "{car_brand_cls.model|N/A}",
"confidence": "{confidence|0.0}"
}
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
# Don't add any branch results
mapped_fields = mapper.map_fields(field_mappings, context)
assert mapped_fields["car_brand"] == "Unknown"
assert mapped_fields["car_model"] == "N/A"
assert mapped_fields["confidence"] == "0.0"
def test_map_database_fields(self):
"""Test mapping fields for database operations."""
mapper = FieldMapper()
# Database field mapping
db_field_mappings = {
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"detection_timestamp": "{timestamp}",
"object_class": "{class}",
"detection_confidence": "{confidence}",
"track_id": "{track_id}",
"bbox_json": "JSON_OBJECT('x1', {bbox.x1}, 'y1', {bbox.y1}, 'x2', {bbox.x2}, 'y2', {bbox.y2})",
"car_brand": "{car_brand_cls.brand}",
"car_body_type": "{car_bodytype_cls.body_type}",
"license_plate": "{license_ocr.text}",
"created_at": "NOW()",
"updated_at": "NOW()"
}
detection = DetectionResult("car", 0.93, BoundingBox(150, 250, 350, 450), 2001, 1640995300000)
context = MappingContext(
camera_id="camera_002",
display_id="display_002",
session_id="session_456",
detection=detection,
timestamp=1640995300000
)
# Add branch results
context.add_branch_result("car_brand_cls", {"brand": "Honda", "confidence": 0.89})
context.add_branch_result("car_bodytype_cls", {"body_type": "SUV", "confidence": 0.85})
context.add_branch_result("license_ocr", {"text": "ABC-123", "confidence": 0.76})
mapped_fields = mapper.map_fields(db_field_mappings, context)
assert mapped_fields["camera_id"] == "camera_002"
assert mapped_fields["session_id"] == "session_456"
assert mapped_fields["detection_timestamp"] == 1640995300000
assert mapped_fields["object_class"] == "car"
assert mapped_fields["detection_confidence"] == 0.93
assert mapped_fields["track_id"] == 2001
assert mapped_fields["bbox_json"] == "JSON_OBJECT('x1', 150, 'y1', 250, 'x2', 350, 'y2', 450)"
assert mapped_fields["car_brand"] == "Honda"
assert mapped_fields["car_body_type"] == "SUV"
assert mapped_fields["license_plate"] == "ABC-123"
assert mapped_fields["created_at"] == "NOW()"
assert mapped_fields["updated_at"] == "NOW()"
def test_map_redis_keys(self):
"""Test mapping Redis key templates."""
mapper = FieldMapper()
key_templates = [
"inference:{camera_id}:{timestamp}:{session_id}:car",
"detection:{display_id}:{track_id}",
"cropped_image:{camera_id}:{session_id}:{class}",
"metadata:{session_id}:brands:{car_brand_cls.brand}",
"tracking:{camera_id}:active_tracks"
]
detection = DetectionResult("car", 0.88, BoundingBox(200, 300, 400, 500), 3001, 1640995400000)
context = MappingContext(
camera_id="camera_003",
display_id="display_003",
session_id="session_789",
detection=detection,
timestamp=1640995400000
)
context.add_branch_result("car_brand_cls", {"brand": "Ford"})
mapped_keys = [mapper.map_template(template, context) for template in key_templates]
expected_keys = [
"inference:camera_003:1640995400000:session_789:car",
"detection:display_003:3001",
"cropped_image:camera_003:session_789:car",
"metadata:session_789:brands:Ford",
"tracking:camera_003:active_tracks"
]
assert mapped_keys == expected_keys
def test_map_template(self):
"""Test single template mapping."""
mapper = FieldMapper()
template = "Camera {camera_id} detected {class} with {confidence:.2f} confidence at {timestamp}"
detection = DetectionResult("truck", 0.876, BoundingBox(100, 150, 300, 350), 4001, 1640995500000)
context = MappingContext(
camera_id="camera_004",
display_id="display_004",
session_id="session_101",
detection=detection,
timestamp=1640995500000
)
result = mapper.map_template(template, context)
expected = "Camera camera_004 detected truck with 0.88 confidence at 1640995500000"
assert result == expected
def test_validate_field_mappings(self):
"""Test field mapping validation."""
mapper = FieldMapper()
# Valid mappings
valid_mappings = {
"camera_id": "{camera_id}",
"class": "{class}",
"confidence": "{confidence}",
"created_at": "NOW()"
}
assert mapper.validate_field_mappings(valid_mappings) is True
# Invalid mappings (malformed templates)
invalid_mappings = {
"camera_id": "{camera_id", # Missing closing brace
"class": "class}", # Missing opening brace
"confidence": "{nonexistent_field}" # This might be valid depending on context
}
with pytest.raises(FieldMappingError):
mapper.validate_field_mappings(invalid_mappings)
def test_create_context_from_detection(self):
"""Test creating mapping context from detection result."""
mapper = FieldMapper()
detection = DetectionResult("car", 0.95, BoundingBox(50, 100, 250, 300), 5001, 1640995600000)
context = mapper.create_context_from_detection(
detection,
camera_id="camera_005",
display_id="display_005",
session_id="session_202"
)
assert context.camera_id == "camera_005"
assert context.display_id == "display_005"
assert context.session_id == "session_202"
assert context.detection == detection
assert context.timestamp == 1640995600000
def test_format_sql_value(self):
"""Test SQL value formatting."""
mapper = FieldMapper()
# String values should be quoted
assert mapper.format_sql_value("test_string") == "'test_string'"
assert mapper.format_sql_value("John's car") == "'John''s car'" # Escape quotes
# Numeric values should not be quoted
assert mapper.format_sql_value(42) == "42"
assert mapper.format_sql_value(3.14) == "3.14"
assert mapper.format_sql_value(0.95) == "0.95"
# Boolean values
assert mapper.format_sql_value(True) == "TRUE"
assert mapper.format_sql_value(False) == "FALSE"
# None/NULL values
assert mapper.format_sql_value(None) == "NULL"
# SQL functions should pass through
assert mapper.format_sql_value("NOW()") == "NOW()"
assert mapper.format_sql_value("CURRENT_TIMESTAMP") == "CURRENT_TIMESTAMP"
class TestFieldMapperIntegration:
"""Integration tests for field mapping."""
def test_complete_mapping_workflow(self):
"""Test complete field mapping workflow."""
mapper = FieldMapper()
# Simulate complete detection workflow
detection = DetectionResult("car", 0.91, BoundingBox(120, 180, 320, 380), 6001, 1640995700000)
context = MappingContext(
camera_id="camera_006",
display_id="display_006",
session_id="session_303",
detection=detection,
timestamp=1640995700000
)
# Add comprehensive branch results
context.add_branch_result("car_brand_cls", {
"brand": "BMW",
"model": "X5",
"confidence": 0.84,
"top3_brands": ["BMW", "Audi", "Mercedes"]
})
context.add_branch_result("car_bodytype_cls", {
"body_type": "SUV",
"confidence": 0.79,
"features": ["tall", "4_doors", "roof_rails"]
})
context.add_branch_result("car_color_cls", {
"color": "Black",
"confidence": 0.73,
"rgb_values": [20, 25, 30]
})
context.add_branch_result("license_ocr", {
"text": "XYZ-789",
"confidence": 0.68,
"region_bbox": [150, 320, 290, 360]
})
# Database field mapping
db_mappings = {
"camera_id": "{camera_id}",
"display_id": "{display_id}",
"session_id": "{session_id}",
"detection_timestamp": "{timestamp}",
"object_class": "{class}",
"detection_confidence": "{confidence}",
"track_id": "{track_id}",
"bbox_x1": "{bbox.x1}",
"bbox_y1": "{bbox.y1}",
"bbox_x2": "{bbox.x2}",
"bbox_y2": "{bbox.y2}",
"bbox_area": "{bbox.area}",
"car_brand": "{car_brand_cls.brand}",
"car_model": "{car_brand_cls.model}",
"car_body_type": "{car_bodytype_cls.body_type}",
"car_color": "{car_color_cls.color}",
"license_plate": "{license_ocr.text}",
"brand_confidence": "{car_brand_cls.confidence}",
"bodytype_confidence": "{car_bodytype_cls.confidence}",
"color_confidence": "{car_color_cls.confidence}",
"license_confidence": "{license_ocr.confidence}",
"detection_summary": "{car_brand_cls.brand} {car_bodytype_cls.body_type} ({car_color_cls.color})",
"created_at": "NOW()",
"updated_at": "NOW()"
}
mapped_db_fields = mapper.map_fields(db_mappings, context)
# Verify all mappings
assert mapped_db_fields["camera_id"] == "camera_006"
assert mapped_db_fields["session_id"] == "session_303"
assert mapped_db_fields["object_class"] == "car"
assert mapped_db_fields["detection_confidence"] == 0.91
assert mapped_db_fields["track_id"] == 6001
assert mapped_db_fields["bbox_area"] == 40000 # 200 * 200
assert mapped_db_fields["car_brand"] == "BMW"
assert mapped_db_fields["car_model"] == "X5"
assert mapped_db_fields["car_body_type"] == "SUV"
assert mapped_db_fields["car_color"] == "Black"
assert mapped_db_fields["license_plate"] == "XYZ-789"
assert mapped_db_fields["detection_summary"] == "BMW SUV (Black)"
# Redis key mapping
redis_key_templates = [
"detection:{camera_id}:{session_id}:main",
"cropped:{camera_id}:{session_id}:car_image",
"metadata:{session_id}:brand:{car_brand_cls.brand}",
"tracking:{camera_id}:track_{track_id}",
"classification:{session_id}:results"
]
mapped_redis_keys = [
mapper.map_template(template, context)
for template in redis_key_templates
]
expected_redis_keys = [
"detection:camera_006:session_303:main",
"cropped:camera_006:session_303:car_image",
"metadata:session_303:brand:BMW",
"tracking:camera_006:track_6001",
"classification:session_303:results"
]
assert mapped_redis_keys == expected_redis_keys
def test_error_handling_and_recovery(self):
"""Test error handling and recovery in field mapping."""
mapper = FieldMapper(allow_missing=True)
# Context with missing detection
context = MappingContext(
camera_id="camera_007",
display_id="display_007",
session_id="session_404"
)
# Partial branch results
context.add_branch_result("car_brand_cls", {"brand": "Unknown"})
# Missing car_bodytype_cls branch
# Field mappings with some missing data
mappings = {
"camera_id": "{camera_id}",
"detection_class": "{class|Unknown}",
"confidence": "{confidence|0.0}",
"car_brand": "{car_brand_cls.brand|N/A}",
"car_body_type": "{car_bodytype_cls.body_type|Unknown}",
"car_model": "{car_brand_cls.model|N/A}"
}
mapped_fields = mapper.map_fields(mappings, context)
assert mapped_fields["camera_id"] == "camera_007"
assert mapped_fields["detection_class"] == "Unknown"
assert mapped_fields["confidence"] == "0.0"
assert mapped_fields["car_brand"] == "Unknown"
assert mapped_fields["car_body_type"] == "Unknown"
assert mapped_fields["car_model"] == "N/A"

View file

@ -0,0 +1,921 @@
"""
Unit tests for pipeline execution functionality.
"""
import pytest
import asyncio
import numpy as np
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from concurrent.futures import ThreadPoolExecutor
import json
from detector_worker.pipeline.pipeline_executor import (
PipelineExecutor,
PipelineContext,
PipelineResult,
BranchResult,
ExecutionMode
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import PipelineError, ModelError, ActionError
class TestPipelineContext:
"""Test pipeline context data structure."""
def test_creation(self):
"""Test pipeline context creation."""
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=np.zeros((480, 640, 3), dtype=np.uint8)
)
assert context.camera_id == "camera_001"
assert context.display_id == "display_001"
assert context.session_id == "session_123"
assert context.timestamp == 1640995200000
assert context.frame_data.shape == (480, 640, 3)
assert context.metadata == {}
assert context.crop_region is None
def test_creation_with_crop_region(self):
"""Test context creation with crop region."""
crop_region = (100, 200, 300, 400)
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=np.zeros((480, 640, 3), dtype=np.uint8),
crop_region=crop_region
)
assert context.crop_region == crop_region
def test_add_metadata(self):
"""Test adding metadata to context."""
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=np.zeros((480, 640, 3), dtype=np.uint8)
)
context.add_metadata("model_id", "yolo_v8")
context.add_metadata("confidence_threshold", 0.8)
assert context.metadata["model_id"] == "yolo_v8"
assert context.metadata["confidence_threshold"] == 0.8
def test_get_cropped_frame(self):
"""Test getting cropped frame."""
frame = np.ones((480, 640, 3), dtype=np.uint8) * 255
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=frame,
crop_region=(100, 200, 300, 400)
)
cropped = context.get_cropped_frame()
assert cropped.shape == (200, 200, 3) # 400-200, 300-100
assert np.all(cropped == 255)
def test_get_cropped_frame_no_crop(self):
"""Test getting frame when no crop region specified."""
frame = np.ones((480, 640, 3), dtype=np.uint8) * 255
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=frame
)
cropped = context.get_cropped_frame()
assert np.array_equal(cropped, frame)
class TestBranchResult:
"""Test branch execution result."""
def test_creation_success(self):
"""Test successful branch result creation."""
detections = [
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
]
result = BranchResult(
branch_id="car_brand_cls",
success=True,
detections=detections,
metadata={"brand": "Toyota"},
execution_time=0.15
)
assert result.branch_id == "car_brand_cls"
assert result.success is True
assert len(result.detections) == 1
assert result.metadata["brand"] == "Toyota"
assert result.execution_time == 0.15
assert result.error is None
def test_creation_failure(self):
"""Test failed branch result creation."""
result = BranchResult(
branch_id="car_brand_cls",
success=False,
error="Model inference failed",
execution_time=0.05
)
assert result.branch_id == "car_brand_cls"
assert result.success is False
assert result.detections == []
assert result.metadata == {}
assert result.error == "Model inference failed"
class TestPipelineResult:
"""Test pipeline execution result."""
def test_creation_success(self):
"""Test successful pipeline result creation."""
main_detections = [
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
]
branch_results = {
"car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
"car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12)
}
result = PipelineResult(
success=True,
detections=main_detections,
branch_results=branch_results,
total_execution_time=0.5
)
assert result.success is True
assert len(result.detections) == 1
assert len(result.branch_results) == 2
assert result.total_execution_time == 0.5
assert result.error is None
def test_creation_failure(self):
"""Test failed pipeline result creation."""
result = PipelineResult(
success=False,
error="Pipeline execution failed",
total_execution_time=0.1
)
assert result.success is False
assert result.detections == []
assert result.branch_results == {}
assert result.error == "Pipeline execution failed"
def test_get_combined_results(self):
"""Test getting combined results from all branches."""
main_detections = [
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
]
branch_results = {
"car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
"car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12)
}
result = PipelineResult(
success=True,
detections=main_detections,
branch_results=branch_results,
total_execution_time=0.5
)
combined = result.get_combined_results()
assert "brand" in combined
assert "body_type" in combined
assert combined["brand"] == "Toyota"
assert combined["body_type"] == "Sedan"
class TestPipelineExecutor:
"""Test pipeline execution functionality."""
def test_initialization(self):
"""Test pipeline executor initialization."""
executor = PipelineExecutor()
assert isinstance(executor.thread_pool, ThreadPoolExecutor)
assert executor.max_workers == 4
assert executor.execution_mode == ExecutionMode.PARALLEL
assert executor.timeout == 30.0
def test_initialization_custom_config(self):
"""Test initialization with custom configuration."""
config = {
"max_workers": 8,
"execution_mode": "sequential",
"timeout": 60.0
}
executor = PipelineExecutor(config)
assert executor.max_workers == 8
assert executor.execution_mode == ExecutionMode.SEQUENTIAL
assert executor.timeout == 60.0
@pytest.mark.asyncio
async def test_execute_pipeline_simple(self, mock_yolo_model, mock_frame):
"""Test simple pipeline execution."""
# Mock pipeline configuration
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": []
}
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
assert len(result.detections) == 1
assert result.detections[0].class_name == "0" # Default class name
assert result.detections[0].confidence == 0.9
@pytest.mark.asyncio
async def test_execute_pipeline_with_branches(self, mock_yolo_model, mock_frame):
"""Test pipeline execution with classification branches."""
import torch
# Mock main detection
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0] # car detection
])
mock_detection_result.boxes.id = torch.tensor([1001])
# Mock classification results
mock_brand_result = Mock()
mock_brand_result.probs = Mock()
mock_brand_result.probs.top1 = 2 # Toyota
mock_brand_result.probs.top1conf = 0.85
mock_bodytype_result = Mock()
mock_bodytype_result.probs = Mock()
mock_bodytype_result.probs.top1 = 1 # Sedan
mock_bodytype_result.probs.top1conf = 0.78
mock_yolo_model.track.return_value = [mock_detection_result]
mock_yolo_model.predict.return_value = [mock_brand_result]
mock_brand_model = Mock()
mock_brand_model.predict.return_value = [mock_brand_result]
mock_brand_model.names = {0: "Honda", 1: "Ford", 2: "Toyota"}
mock_bodytype_model = Mock()
mock_bodytype_model.predict.return_value = [mock_bodytype_result]
mock_bodytype_model.names = {0: "SUV", 1: "Sedan", 2: "Hatchback"}
# Pipeline configuration with branches
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [
{
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True,
"crop": True,
"cropClass": "car"
},
{
"modelId": "car_bodytype_cls",
"modelFile": "car_bodytype.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True,
"crop": True,
"cropClass": "car"
}
],
"actions": []
}
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
def get_model_side_effect(model_id, camera_id):
if model_id == "car_detection_v1":
return mock_yolo_model
elif model_id == "car_brand_cls":
return mock_brand_model
elif model_id == "car_bodytype_cls":
return mock_bodytype_model
return None
mock_model_manager.return_value.get_model.side_effect = get_model_side_effect
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
assert len(result.detections) == 1
assert len(result.branch_results) == 2
# Check branch results
assert "car_brand_cls" in result.branch_results
assert "car_bodytype_cls" in result.branch_results
brand_result = result.branch_results["car_brand_cls"]
assert brand_result.success is True
assert brand_result.metadata.get("brand") == "Toyota"
bodytype_result = result.branch_results["car_bodytype_cls"]
assert bodytype_result.success is True
assert bodytype_result.metadata.get("body_type") == "Sedan"
@pytest.mark.asyncio
async def test_execute_pipeline_sequential_mode(self, mock_yolo_model, mock_frame):
"""Test pipeline execution in sequential mode."""
import torch
config = {"execution_mode": "sequential"}
executor = PipelineExecutor(config)
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [
{
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": False # Sequential execution
}
],
"actions": []
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
assert executor.execution_mode == ExecutionMode.SEQUENTIAL
@pytest.mark.asyncio
async def test_execute_pipeline_with_actions(self, mock_yolo_model, mock_frame):
"""Test pipeline execution with actions."""
import torch
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
# Pipeline configuration with actions
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{display_id}:{timestamp}:{session_id}",
"expire_seconds": 600
},
{
"type": "postgresql_insert",
"table": "detections",
"fields": {
"camera_id": "{camera_id}",
"detection_class": "{class}",
"confidence": "{confidence}"
}
}
]
}
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager, \
patch('detector_worker.pipeline.action_executor.ActionExecutor') as mock_action_executor:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
mock_action_executor.return_value.execute_actions = AsyncMock(return_value=True)
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
# Actions should be executed
mock_action_executor.return_value.execute_actions.assert_called_once()
@pytest.mark.asyncio
async def test_execute_pipeline_model_error(self, mock_frame):
"""Test pipeline execution with model error."""
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": []
}
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
# Model manager raises error
mock_model_manager.return_value.get_model.side_effect = ModelError("Model not found")
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is False
assert "Model not found" in result.error
@pytest.mark.asyncio
async def test_execute_pipeline_timeout(self, mock_yolo_model, mock_frame):
"""Test pipeline execution timeout."""
import torch
# Configure short timeout
config = {"timeout": 0.001} # Very short timeout
executor = PipelineExecutor(config)
# Mock slow model inference
def slow_inference(*args, **kwargs):
import time
time.sleep(1) # Longer than timeout
mock_result = Mock()
mock_result.boxes = None
return [mock_result]
mock_yolo_model.track.side_effect = slow_inference
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": []
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is False
assert "timeout" in result.error.lower()
@pytest.mark.asyncio
async def test_execute_branch_parallel(self, mock_frame):
"""Test parallel branch execution."""
import torch
# Mock classification model
mock_brand_model = Mock()
mock_result = Mock()
mock_result.probs = Mock()
mock_result.probs.top1 = 1
mock_result.probs.top1conf = 0.85
mock_brand_model.predict.return_value = [mock_result]
mock_brand_model.names = {0: "Honda", 1: "Toyota", 2: "Ford"}
executor = PipelineExecutor()
# Branch configuration
branch_config = {
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True,
"crop": True,
"cropClass": "car"
}
# Mock detected regions
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_brand_model
result = await executor._execute_branch(branch_config, regions, context)
assert result.success is True
assert result.branch_id == "car_brand_cls"
assert result.metadata.get("brand") == "Toyota"
assert result.execution_time > 0
@pytest.mark.asyncio
async def test_execute_branch_no_trigger_class(self, mock_frame):
"""Test branch execution when trigger class not detected."""
executor = PipelineExecutor()
branch_config = {
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7
}
# No car detected
regions = {
"truck": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("truck", 0.9, BoundingBox(100, 200, 300, 400), 1002)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
result = await executor._execute_branch(branch_config, regions, context)
assert result.success is False
assert "trigger class not detected" in result.error.lower()
def test_wait_for_branches(self):
"""Test waiting for specific branches to complete."""
executor = PipelineExecutor()
# Mock completed branch results
branch_results = {
"car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
"car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12),
"license_ocr": BranchResult("license_ocr", True, [], {"license": "ABC123"}, 0.2)
}
# Wait for specific branches
wait_for = ["car_brand_cls", "car_bodytype_cls"]
completed = executor._wait_for_branches(branch_results, wait_for, timeout=1.0)
assert completed is True
# Wait for non-existent branch (should timeout)
wait_for_missing = ["car_brand_cls", "nonexistent_branch"]
completed = executor._wait_for_branches(branch_results, wait_for_missing, timeout=0.1)
assert completed is False
def test_validate_pipeline_config(self):
"""Test pipeline configuration validation."""
executor = PipelineExecutor()
# Valid configuration
valid_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8
}
assert executor._validate_pipeline_config(valid_config) is True
# Invalid configuration (missing required fields)
invalid_config = {
"modelFile": "car_detection.pt"
# Missing modelId
}
with pytest.raises(PipelineError):
executor._validate_pipeline_config(invalid_config)
def test_crop_frame_for_detection(self, mock_frame):
"""Test frame cropping for detection."""
executor = PipelineExecutor()
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
cropped = executor._crop_frame_for_detection(mock_frame, detection)
assert cropped.shape == (200, 200, 3) # 400-200, 300-100
def test_crop_frame_invalid_bounds(self, mock_frame):
"""Test frame cropping with invalid bounds."""
executor = PipelineExecutor()
# Detection outside frame bounds
detection = DetectionResult("car", 0.9, BoundingBox(-100, -200, 50, 100), 1001)
cropped = executor._crop_frame_for_detection(mock_frame, detection)
# Should handle bounds gracefully
assert cropped.shape[0] > 0
assert cropped.shape[1] > 0
class TestPipelineExecutorPerformance:
"""Test pipeline executor performance and optimization."""
@pytest.mark.asyncio
async def test_parallel_branch_execution_performance(self, mock_frame):
"""Test that parallel execution is faster than sequential."""
import time
import torch
def slow_inference(*args, **kwargs):
time.sleep(0.1) # Simulate slow inference
mock_result = Mock()
mock_result.probs = Mock()
mock_result.probs.top1 = 1
mock_result.probs.top1conf = 0.85
return [mock_result]
mock_model = Mock()
mock_model.predict.side_effect = slow_inference
mock_model.names = {0: "Class0", 1: "Class1"}
# Test parallel execution
parallel_executor = PipelineExecutor({"execution_mode": "parallel", "max_workers": 2})
branch_configs = [
{
"modelId": f"branch_{i}",
"modelFile": f"branch_{i}.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True
}
for i in range(3) # 3 branches
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_model
start_time = time.time()
# Execute branches in parallel
tasks = [
parallel_executor._execute_branch(config, regions, context)
for config in branch_configs
]
results = await asyncio.gather(*tasks)
parallel_time = time.time() - start_time
# Parallel execution should be faster than 3 * 0.1 seconds
assert parallel_time < 0.25 # Allow some overhead
assert len(results) == 3
assert all(result.success for result in results)
def test_thread_pool_management(self):
"""Test thread pool creation and management."""
# Test different worker counts
for workers in [1, 2, 4, 8]:
executor = PipelineExecutor({"max_workers": workers})
assert executor.max_workers == workers
assert executor.thread_pool._max_workers == workers
def test_memory_management_large_frames(self):
"""Test memory management with large frames."""
executor = PipelineExecutor()
# Create large frame
large_frame = np.ones((1080, 1920, 3), dtype=np.uint8) * 128
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=large_frame,
crop_region=(500, 400, 1000, 800)
)
# Get cropped frame
cropped = context.get_cropped_frame()
# Should reduce memory usage
assert cropped.shape == (400, 500, 3) # Much smaller than original
assert cropped.nbytes < large_frame.nbytes
class TestPipelineExecutorErrorHandling:
"""Test comprehensive error handling."""
@pytest.mark.asyncio
async def test_branch_execution_error_isolation(self, mock_frame):
"""Test that errors in one branch don't affect others."""
executor = PipelineExecutor()
# Mock models - one fails, one succeeds
failing_model = Mock()
failing_model.predict.side_effect = Exception("Model crashed")
success_model = Mock()
mock_result = Mock()
mock_result.probs = Mock()
mock_result.probs.top1 = 1
mock_result.probs.top1conf = 0.85
success_model.predict.return_value = [mock_result]
success_model.names = {0: "Class0", 1: "Class1"}
branch_configs = [
{
"modelId": "failing_branch",
"modelFile": "failing.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True
},
{
"modelId": "success_branch",
"modelFile": "success.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
def get_model_side_effect(model_id, camera_id):
if model_id == "failing_branch":
return failing_model
elif model_id == "success_branch":
return success_model
return None
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.side_effect = get_model_side_effect
# Execute branches
tasks = [
executor._execute_branch(config, regions, context)
for config in branch_configs
]
results = await asyncio.gather(*tasks, return_exceptions=True)
# One should fail, one should succeed
failing_result = next(r for r in results if isinstance(r, BranchResult) and r.branch_id == "failing_branch")
success_result = next(r for r in results if isinstance(r, BranchResult) and r.branch_id == "success_branch")
assert failing_result.success is False
assert "Model crashed" in failing_result.error
assert success_result.success is True
assert success_result.error is None