Refactor: PHASE 8: Testing & Integration

This commit is contained in:
ziesorx 2025-09-12 18:55:23 +07:00
parent af34f4fd08
commit 9e8c6804a7
32 changed files with 17128 additions and 0 deletions

View file

@ -0,0 +1,921 @@
"""
Unit tests for pipeline execution functionality.
"""
import pytest
import asyncio
import numpy as np
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from concurrent.futures import ThreadPoolExecutor
import json
from detector_worker.pipeline.pipeline_executor import (
PipelineExecutor,
PipelineContext,
PipelineResult,
BranchResult,
ExecutionMode
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import PipelineError, ModelError, ActionError
class TestPipelineContext:
"""Test pipeline context data structure."""
def test_creation(self):
"""Test pipeline context creation."""
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=np.zeros((480, 640, 3), dtype=np.uint8)
)
assert context.camera_id == "camera_001"
assert context.display_id == "display_001"
assert context.session_id == "session_123"
assert context.timestamp == 1640995200000
assert context.frame_data.shape == (480, 640, 3)
assert context.metadata == {}
assert context.crop_region is None
def test_creation_with_crop_region(self):
"""Test context creation with crop region."""
crop_region = (100, 200, 300, 400)
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=np.zeros((480, 640, 3), dtype=np.uint8),
crop_region=crop_region
)
assert context.crop_region == crop_region
def test_add_metadata(self):
"""Test adding metadata to context."""
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=np.zeros((480, 640, 3), dtype=np.uint8)
)
context.add_metadata("model_id", "yolo_v8")
context.add_metadata("confidence_threshold", 0.8)
assert context.metadata["model_id"] == "yolo_v8"
assert context.metadata["confidence_threshold"] == 0.8
def test_get_cropped_frame(self):
"""Test getting cropped frame."""
frame = np.ones((480, 640, 3), dtype=np.uint8) * 255
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=frame,
crop_region=(100, 200, 300, 400)
)
cropped = context.get_cropped_frame()
assert cropped.shape == (200, 200, 3) # 400-200, 300-100
assert np.all(cropped == 255)
def test_get_cropped_frame_no_crop(self):
"""Test getting frame when no crop region specified."""
frame = np.ones((480, 640, 3), dtype=np.uint8) * 255
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=frame
)
cropped = context.get_cropped_frame()
assert np.array_equal(cropped, frame)
class TestBranchResult:
"""Test branch execution result."""
def test_creation_success(self):
"""Test successful branch result creation."""
detections = [
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
]
result = BranchResult(
branch_id="car_brand_cls",
success=True,
detections=detections,
metadata={"brand": "Toyota"},
execution_time=0.15
)
assert result.branch_id == "car_brand_cls"
assert result.success is True
assert len(result.detections) == 1
assert result.metadata["brand"] == "Toyota"
assert result.execution_time == 0.15
assert result.error is None
def test_creation_failure(self):
"""Test failed branch result creation."""
result = BranchResult(
branch_id="car_brand_cls",
success=False,
error="Model inference failed",
execution_time=0.05
)
assert result.branch_id == "car_brand_cls"
assert result.success is False
assert result.detections == []
assert result.metadata == {}
assert result.error == "Model inference failed"
class TestPipelineResult:
"""Test pipeline execution result."""
def test_creation_success(self):
"""Test successful pipeline result creation."""
main_detections = [
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
]
branch_results = {
"car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
"car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12)
}
result = PipelineResult(
success=True,
detections=main_detections,
branch_results=branch_results,
total_execution_time=0.5
)
assert result.success is True
assert len(result.detections) == 1
assert len(result.branch_results) == 2
assert result.total_execution_time == 0.5
assert result.error is None
def test_creation_failure(self):
"""Test failed pipeline result creation."""
result = PipelineResult(
success=False,
error="Pipeline execution failed",
total_execution_time=0.1
)
assert result.success is False
assert result.detections == []
assert result.branch_results == {}
assert result.error == "Pipeline execution failed"
def test_get_combined_results(self):
"""Test getting combined results from all branches."""
main_detections = [
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
]
branch_results = {
"car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
"car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12)
}
result = PipelineResult(
success=True,
detections=main_detections,
branch_results=branch_results,
total_execution_time=0.5
)
combined = result.get_combined_results()
assert "brand" in combined
assert "body_type" in combined
assert combined["brand"] == "Toyota"
assert combined["body_type"] == "Sedan"
class TestPipelineExecutor:
"""Test pipeline execution functionality."""
def test_initialization(self):
"""Test pipeline executor initialization."""
executor = PipelineExecutor()
assert isinstance(executor.thread_pool, ThreadPoolExecutor)
assert executor.max_workers == 4
assert executor.execution_mode == ExecutionMode.PARALLEL
assert executor.timeout == 30.0
def test_initialization_custom_config(self):
"""Test initialization with custom configuration."""
config = {
"max_workers": 8,
"execution_mode": "sequential",
"timeout": 60.0
}
executor = PipelineExecutor(config)
assert executor.max_workers == 8
assert executor.execution_mode == ExecutionMode.SEQUENTIAL
assert executor.timeout == 60.0
@pytest.mark.asyncio
async def test_execute_pipeline_simple(self, mock_yolo_model, mock_frame):
"""Test simple pipeline execution."""
# Mock pipeline configuration
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": []
}
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
assert len(result.detections) == 1
assert result.detections[0].class_name == "0" # Default class name
assert result.detections[0].confidence == 0.9
@pytest.mark.asyncio
async def test_execute_pipeline_with_branches(self, mock_yolo_model, mock_frame):
"""Test pipeline execution with classification branches."""
import torch
# Mock main detection
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0] # car detection
])
mock_detection_result.boxes.id = torch.tensor([1001])
# Mock classification results
mock_brand_result = Mock()
mock_brand_result.probs = Mock()
mock_brand_result.probs.top1 = 2 # Toyota
mock_brand_result.probs.top1conf = 0.85
mock_bodytype_result = Mock()
mock_bodytype_result.probs = Mock()
mock_bodytype_result.probs.top1 = 1 # Sedan
mock_bodytype_result.probs.top1conf = 0.78
mock_yolo_model.track.return_value = [mock_detection_result]
mock_yolo_model.predict.return_value = [mock_brand_result]
mock_brand_model = Mock()
mock_brand_model.predict.return_value = [mock_brand_result]
mock_brand_model.names = {0: "Honda", 1: "Ford", 2: "Toyota"}
mock_bodytype_model = Mock()
mock_bodytype_model.predict.return_value = [mock_bodytype_result]
mock_bodytype_model.names = {0: "SUV", 1: "Sedan", 2: "Hatchback"}
# Pipeline configuration with branches
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [
{
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True,
"crop": True,
"cropClass": "car"
},
{
"modelId": "car_bodytype_cls",
"modelFile": "car_bodytype.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True,
"crop": True,
"cropClass": "car"
}
],
"actions": []
}
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
def get_model_side_effect(model_id, camera_id):
if model_id == "car_detection_v1":
return mock_yolo_model
elif model_id == "car_brand_cls":
return mock_brand_model
elif model_id == "car_bodytype_cls":
return mock_bodytype_model
return None
mock_model_manager.return_value.get_model.side_effect = get_model_side_effect
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
assert len(result.detections) == 1
assert len(result.branch_results) == 2
# Check branch results
assert "car_brand_cls" in result.branch_results
assert "car_bodytype_cls" in result.branch_results
brand_result = result.branch_results["car_brand_cls"]
assert brand_result.success is True
assert brand_result.metadata.get("brand") == "Toyota"
bodytype_result = result.branch_results["car_bodytype_cls"]
assert bodytype_result.success is True
assert bodytype_result.metadata.get("body_type") == "Sedan"
@pytest.mark.asyncio
async def test_execute_pipeline_sequential_mode(self, mock_yolo_model, mock_frame):
"""Test pipeline execution in sequential mode."""
import torch
config = {"execution_mode": "sequential"}
executor = PipelineExecutor(config)
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [
{
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": False # Sequential execution
}
],
"actions": []
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
assert executor.execution_mode == ExecutionMode.SEQUENTIAL
@pytest.mark.asyncio
async def test_execute_pipeline_with_actions(self, mock_yolo_model, mock_frame):
"""Test pipeline execution with actions."""
import torch
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
# Pipeline configuration with actions
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{display_id}:{timestamp}:{session_id}",
"expire_seconds": 600
},
{
"type": "postgresql_insert",
"table": "detections",
"fields": {
"camera_id": "{camera_id}",
"detection_class": "{class}",
"confidence": "{confidence}"
}
}
]
}
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager, \
patch('detector_worker.pipeline.action_executor.ActionExecutor') as mock_action_executor:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
mock_action_executor.return_value.execute_actions = AsyncMock(return_value=True)
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
# Actions should be executed
mock_action_executor.return_value.execute_actions.assert_called_once()
@pytest.mark.asyncio
async def test_execute_pipeline_model_error(self, mock_frame):
"""Test pipeline execution with model error."""
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": []
}
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
# Model manager raises error
mock_model_manager.return_value.get_model.side_effect = ModelError("Model not found")
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is False
assert "Model not found" in result.error
@pytest.mark.asyncio
async def test_execute_pipeline_timeout(self, mock_yolo_model, mock_frame):
"""Test pipeline execution timeout."""
import torch
# Configure short timeout
config = {"timeout": 0.001} # Very short timeout
executor = PipelineExecutor(config)
# Mock slow model inference
def slow_inference(*args, **kwargs):
import time
time.sleep(1) # Longer than timeout
mock_result = Mock()
mock_result.boxes = None
return [mock_result]
mock_yolo_model.track.side_effect = slow_inference
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": []
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is False
assert "timeout" in result.error.lower()
@pytest.mark.asyncio
async def test_execute_branch_parallel(self, mock_frame):
"""Test parallel branch execution."""
import torch
# Mock classification model
mock_brand_model = Mock()
mock_result = Mock()
mock_result.probs = Mock()
mock_result.probs.top1 = 1
mock_result.probs.top1conf = 0.85
mock_brand_model.predict.return_value = [mock_result]
mock_brand_model.names = {0: "Honda", 1: "Toyota", 2: "Ford"}
executor = PipelineExecutor()
# Branch configuration
branch_config = {
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True,
"crop": True,
"cropClass": "car"
}
# Mock detected regions
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_brand_model
result = await executor._execute_branch(branch_config, regions, context)
assert result.success is True
assert result.branch_id == "car_brand_cls"
assert result.metadata.get("brand") == "Toyota"
assert result.execution_time > 0
@pytest.mark.asyncio
async def test_execute_branch_no_trigger_class(self, mock_frame):
"""Test branch execution when trigger class not detected."""
executor = PipelineExecutor()
branch_config = {
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7
}
# No car detected
regions = {
"truck": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("truck", 0.9, BoundingBox(100, 200, 300, 400), 1002)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
result = await executor._execute_branch(branch_config, regions, context)
assert result.success is False
assert "trigger class not detected" in result.error.lower()
def test_wait_for_branches(self):
"""Test waiting for specific branches to complete."""
executor = PipelineExecutor()
# Mock completed branch results
branch_results = {
"car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
"car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12),
"license_ocr": BranchResult("license_ocr", True, [], {"license": "ABC123"}, 0.2)
}
# Wait for specific branches
wait_for = ["car_brand_cls", "car_bodytype_cls"]
completed = executor._wait_for_branches(branch_results, wait_for, timeout=1.0)
assert completed is True
# Wait for non-existent branch (should timeout)
wait_for_missing = ["car_brand_cls", "nonexistent_branch"]
completed = executor._wait_for_branches(branch_results, wait_for_missing, timeout=0.1)
assert completed is False
def test_validate_pipeline_config(self):
"""Test pipeline configuration validation."""
executor = PipelineExecutor()
# Valid configuration
valid_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8
}
assert executor._validate_pipeline_config(valid_config) is True
# Invalid configuration (missing required fields)
invalid_config = {
"modelFile": "car_detection.pt"
# Missing modelId
}
with pytest.raises(PipelineError):
executor._validate_pipeline_config(invalid_config)
def test_crop_frame_for_detection(self, mock_frame):
"""Test frame cropping for detection."""
executor = PipelineExecutor()
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
cropped = executor._crop_frame_for_detection(mock_frame, detection)
assert cropped.shape == (200, 200, 3) # 400-200, 300-100
def test_crop_frame_invalid_bounds(self, mock_frame):
"""Test frame cropping with invalid bounds."""
executor = PipelineExecutor()
# Detection outside frame bounds
detection = DetectionResult("car", 0.9, BoundingBox(-100, -200, 50, 100), 1001)
cropped = executor._crop_frame_for_detection(mock_frame, detection)
# Should handle bounds gracefully
assert cropped.shape[0] > 0
assert cropped.shape[1] > 0
class TestPipelineExecutorPerformance:
"""Test pipeline executor performance and optimization."""
@pytest.mark.asyncio
async def test_parallel_branch_execution_performance(self, mock_frame):
"""Test that parallel execution is faster than sequential."""
import time
import torch
def slow_inference(*args, **kwargs):
time.sleep(0.1) # Simulate slow inference
mock_result = Mock()
mock_result.probs = Mock()
mock_result.probs.top1 = 1
mock_result.probs.top1conf = 0.85
return [mock_result]
mock_model = Mock()
mock_model.predict.side_effect = slow_inference
mock_model.names = {0: "Class0", 1: "Class1"}
# Test parallel execution
parallel_executor = PipelineExecutor({"execution_mode": "parallel", "max_workers": 2})
branch_configs = [
{
"modelId": f"branch_{i}",
"modelFile": f"branch_{i}.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True
}
for i in range(3) # 3 branches
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_model
start_time = time.time()
# Execute branches in parallel
tasks = [
parallel_executor._execute_branch(config, regions, context)
for config in branch_configs
]
results = await asyncio.gather(*tasks)
parallel_time = time.time() - start_time
# Parallel execution should be faster than 3 * 0.1 seconds
assert parallel_time < 0.25 # Allow some overhead
assert len(results) == 3
assert all(result.success for result in results)
def test_thread_pool_management(self):
"""Test thread pool creation and management."""
# Test different worker counts
for workers in [1, 2, 4, 8]:
executor = PipelineExecutor({"max_workers": workers})
assert executor.max_workers == workers
assert executor.thread_pool._max_workers == workers
def test_memory_management_large_frames(self):
"""Test memory management with large frames."""
executor = PipelineExecutor()
# Create large frame
large_frame = np.ones((1080, 1920, 3), dtype=np.uint8) * 128
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=large_frame,
crop_region=(500, 400, 1000, 800)
)
# Get cropped frame
cropped = context.get_cropped_frame()
# Should reduce memory usage
assert cropped.shape == (400, 500, 3) # Much smaller than original
assert cropped.nbytes < large_frame.nbytes
class TestPipelineExecutorErrorHandling:
"""Test comprehensive error handling."""
@pytest.mark.asyncio
async def test_branch_execution_error_isolation(self, mock_frame):
"""Test that errors in one branch don't affect others."""
executor = PipelineExecutor()
# Mock models - one fails, one succeeds
failing_model = Mock()
failing_model.predict.side_effect = Exception("Model crashed")
success_model = Mock()
mock_result = Mock()
mock_result.probs = Mock()
mock_result.probs.top1 = 1
mock_result.probs.top1conf = 0.85
success_model.predict.return_value = [mock_result]
success_model.names = {0: "Class0", 1: "Class1"}
branch_configs = [
{
"modelId": "failing_branch",
"modelFile": "failing.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True
},
{
"modelId": "success_branch",
"modelFile": "success.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
def get_model_side_effect(model_id, camera_id):
if model_id == "failing_branch":
return failing_model
elif model_id == "success_branch":
return success_model
return None
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.side_effect = get_model_side_effect
# Execute branches
tasks = [
executor._execute_branch(config, regions, context)
for config in branch_configs
]
results = await asyncio.gather(*tasks, return_exceptions=True)
# One should fail, one should succeed
failing_result = next(r for r in results if isinstance(r, BranchResult) and r.branch_id == "failing_branch")
success_result = next(r for r in results if isinstance(r, BranchResult) and r.branch_id == "success_branch")
assert failing_result.success is False
assert "Model crashed" in failing_result.error
assert success_result.success is True
assert success_result.error is None