python-detector-worker/tests/unit/storage/test_database_manager.py
2025-09-12 18:55:23 +07:00

976 lines
No EOL
33 KiB
Python

"""
Unit tests for database management functionality.
"""
import pytest
import asyncio
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from datetime import datetime, timedelta
import psycopg2
import uuid
from detector_worker.storage.database_manager import (
DatabaseManager,
DatabaseConfig,
DatabaseConnection,
QueryBuilder,
TransactionManager,
DatabaseError,
ConnectionPoolError
)
from detector_worker.core.exceptions import ConfigurationError
class TestDatabaseConfig:
    """Test database configuration."""

    def test_creation_minimal(self):
        """Test creating database config with minimal parameters."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        assert config.host == "localhost"
        assert config.port == 5432  # Default port
        assert config.database == "test_db"
        assert config.username == "test_user"
        assert config.password == "test_pass"
        assert config.schema == "public"  # Default schema
        assert config.enabled is True

    def test_creation_full(self):
        """Test creating database config with all parameters.

        Every keyword passed to the constructor is read back, so a parameter
        that is silently dropped or stored under another name fails here.
        """
        config = DatabaseConfig(
            host="db.example.com",
            port=5433,
            database="production_db",
            username="prod_user",
            password="secure_pass",
            schema="gas_station_1",
            enabled=True,
            pool_min_conn=2,
            pool_max_conn=20,
            pool_timeout=30.0,
            connection_timeout=10.0,
            ssl_mode="require"
        )

        assert config.host == "db.example.com"
        assert config.port == 5433
        assert config.database == "production_db"
        assert config.username == "prod_user"
        assert config.password == "secure_pass"
        assert config.schema == "gas_station_1"
        assert config.enabled is True
        assert config.pool_min_conn == 2
        assert config.pool_max_conn == 20
        assert config.pool_timeout == 30.0
        assert config.connection_timeout == 10.0
        assert config.ssl_mode == "require"

    def test_get_connection_string(self):
        """Test generating connection string."""
        config = DatabaseConfig(
            host="localhost",
            port=5432,
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        conn_string = config.get_connection_string()

        # NOTE(review): libpq's documented conninfo key word for the database
        # name is ``dbname``; ``database`` is not one of them, so this string
        # may be rejected if it is ever handed to psycopg2.connect() as a DSN.
        # Confirm how get_connection_string() is consumed before relying on it.
        expected = "host=localhost port=5432 database=test_db user=test_user password=test_pass"
        assert conn_string == expected

    def test_get_connection_string_with_ssl(self):
        """Test generating connection string with SSL."""
        config = DatabaseConfig(
            host="db.example.com",
            database="secure_db",
            username="user",
            password="pass",
            ssl_mode="require"
        )

        conn_string = config.get_connection_string()

        assert "sslmode=require" in conn_string

    def test_from_dict(self):
        """Test creating config from a dictionary; unknown keys are ignored."""
        config_dict = {
            "host": "test-host",
            "port": 5433,
            "database": "test-db",
            "username": "test-user",
            "password": "test-pass",
            "schema": "test_schema",
            "pool_max_conn": 15,
            "unknown_field": "ignored"
        }

        config = DatabaseConfig.from_dict(config_dict)

        assert config.host == "test-host"
        assert config.port == 5433
        assert config.database == "test-db"
        assert config.username == "test-user"
        assert config.password == "test-pass"
        assert config.schema == "test_schema"
        assert config.pool_max_conn == 15
        # The unrecognised key must not leak onto the config object.
        assert not hasattr(config, "unknown_field")
class TestQueryBuilder:
    """Test SQL query building functionality."""

    def test_build_select_query(self):
        """Test building SELECT queries."""
        builder = QueryBuilder("test_schema")

        query, params = builder.build_select_query(
            table="users",
            columns=["id", "name", "email"],
            where={"status": "active", "age": 25},
            order_by="created_at DESC",
            limit=10
        )

        expected_query = (
            "SELECT id, name, email FROM test_schema.users "
            "WHERE status = %s AND age = %s "
            "ORDER BY created_at DESC LIMIT 10"
        )
        assert query == expected_query
        assert params == ["active", 25]

    def test_build_select_all_columns(self):
        """Test building SELECT * query."""
        builder = QueryBuilder("public")

        query, params = builder.build_select_query("products")

        expected_query = "SELECT * FROM public.products"
        assert query == expected_query
        assert params == []

    def test_build_insert_query(self):
        """Test building INSERT queries; ``NOW()`` is emitted as literal SQL."""
        builder = QueryBuilder("inventory")

        data = {
            "product_name": "Widget",
            "price": 19.99,
            "quantity": 100,
            "created_at": "NOW()"
        }

        query, params = builder.build_insert_query("products", data)

        expected_query = (
            "INSERT INTO inventory.products (product_name, price, quantity, created_at) "
            "VALUES (%s, %s, %s, NOW()) RETURNING id"
        )
        assert query == expected_query
        assert params == ["Widget", 19.99, 100]

    def test_build_update_query(self):
        """Test building UPDATE queries."""
        builder = QueryBuilder("sales")

        data = {
            "status": "shipped",
            "shipped_date": "NOW()",
            "tracking_number": "ABC123"
        }
        where_conditions = {"order_id": 12345}

        query, params = builder.build_update_query("orders", data, where_conditions)

        expected_query = (
            "UPDATE sales.orders SET status = %s, shipped_date = NOW(), tracking_number = %s "
            "WHERE order_id = %s"
        )
        assert query == expected_query
        assert params == ["shipped", "ABC123", 12345]

    def test_build_delete_query(self):
        """Test building DELETE queries."""
        builder = QueryBuilder("logs")

        # NOTE(review): the ``created_at`` value is a raw SQL fragment that
        # the builder is expected to splice into the query verbatim (no
        # placeholder, no param). This pins a contract in which value strings
        # can carry SQL, so such values must never come from untrusted input.
        where_conditions = {
            "level": "DEBUG",
            "created_at": "< NOW() - INTERVAL '7 days'"
        }

        query, params = builder.build_delete_query("application_logs", where_conditions)

        expected_query = (
            "DELETE FROM logs.application_logs "
            "WHERE level = %s AND created_at < NOW() - INTERVAL '7 days'"
        )
        assert query == expected_query
        assert params == ["DEBUG"]

    def test_build_create_table_query(self):
        """Test building CREATE TABLE queries; every column must be emitted."""
        builder = QueryBuilder("gas_station_1")

        columns = {
            "id": "SERIAL PRIMARY KEY",
            "session_id": "VARCHAR(255) UNIQUE NOT NULL",
            "camera_id": "VARCHAR(255) NOT NULL",
            "detection_class": "VARCHAR(100)",
            "confidence": "DECIMAL(4,3)",
            "bbox_data": "JSON",
            "created_at": "TIMESTAMP DEFAULT NOW()",
            "updated_at": "TIMESTAMP DEFAULT NOW()"
        }

        query = builder.build_create_table_query("detections", columns)

        # One entry per column definition, so a silently dropped column fails.
        expected_parts = [
            "CREATE TABLE IF NOT EXISTS gas_station_1.detections",
            "id SERIAL PRIMARY KEY",
            "session_id VARCHAR(255) UNIQUE NOT NULL",
            "camera_id VARCHAR(255) NOT NULL",
            "detection_class VARCHAR(100)",
            "confidence DECIMAL(4,3)",
            "bbox_data JSON",
            "created_at TIMESTAMP DEFAULT NOW()",
            "updated_at TIMESTAMP DEFAULT NOW()"
        ]

        for part in expected_parts:
            assert part in query

    def test_escape_identifier(self):
        """Test SQL identifier escaping."""
        builder = QueryBuilder("test")

        assert builder.escape_identifier("table") == '"table"'
        assert builder.escape_identifier("column_name") == '"column_name"'
        assert builder.escape_identifier("user-table") == '"user-table"'
        # NOTE(review): no case covers an identifier that itself contains a
        # double quote (PostgreSQL requires doubling it: a"b -> "a""b").
        # Consider adding one once the implementation is confirmed to do so.

    def test_format_value_for_sql(self):
        """Test SQL value formatting."""
        builder = QueryBuilder("test")

        # Regular values should use placeholder
        assert builder.format_value_for_sql("string") == ("%s", "string")
        assert builder.format_value_for_sql(42) == ("%s", 42)
        assert builder.format_value_for_sql(3.14) == ("%s", 3.14)

        # SQL functions should be literal
        assert builder.format_value_for_sql("NOW()") == ("NOW()", None)
        assert builder.format_value_for_sql("CURRENT_TIMESTAMP") == ("CURRENT_TIMESTAMP", None)
        assert builder.format_value_for_sql("UUID()") == ("UUID()", None)
class TestDatabaseConnection:
    """Test database connection management."""

    @pytest.fixture
    def db_config(self):
        """Minimal local config shared by every test in this class."""
        return DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

    def test_creation(self, db_config, mock_database_connection):
        """Test connection creation."""
        conn = DatabaseConnection(db_config, mock_database_connection)

        assert conn.config == db_config
        assert conn.connection == mock_database_connection
        assert conn.is_connected is True

    def test_execute_query(self, db_config, mock_database_connection):
        """Test query execution returning all rows."""
        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchall.return_value = [
            (1, "John", "john@example.com"),
            (2, "Jane", "jane@example.com")
        ]
        mock_cursor.rowcount = 2

        conn = DatabaseConnection(db_config, mock_database_connection)

        query = "SELECT id, name, email FROM users WHERE status = %s"
        params = ["active"]
        result = conn.execute_query(query, params)

        assert result == [
            (1, "John", "john@example.com"),
            (2, "Jane", "jane@example.com")
        ]
        mock_cursor.execute.assert_called_once_with(query, params)
        mock_cursor.fetchall.assert_called_once()

    def test_execute_query_single_result(self, db_config, mock_database_connection):
        """Test query execution with ``fetch_one=True`` returning a single row."""
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchone.return_value = (1, "John", "john@example.com")

        conn = DatabaseConnection(db_config, mock_database_connection)

        query = "SELECT * FROM users WHERE id = %s"
        result = conn.execute_query(query, [1], fetch_one=True)

        assert result == (1, "John", "john@example.com")
        mock_cursor.execute.assert_called_once_with(query, [1])
        mock_cursor.fetchone.assert_called_once()
        # fetch_one must use the single-row path only.
        mock_cursor.fetchall.assert_not_called()

    def test_execute_query_no_fetch(self, db_config, mock_database_connection):
        """Test query execution without fetching results (returns row count)."""
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 1

        conn = DatabaseConnection(db_config, mock_database_connection)

        result = conn.execute_query(
            "INSERT INTO users (name) VALUES (%s)",
            ["John"],
            fetch_results=False
        )

        assert result == 1  # Row count
        mock_cursor.execute.assert_called_once()
        mock_cursor.fetchall.assert_not_called()
        mock_cursor.fetchone.assert_not_called()

    def test_execute_query_error(self, db_config, mock_database_connection):
        """Test that driver errors surface as DatabaseError."""
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.execute.side_effect = psycopg2.Error("Database error")

        conn = DatabaseConnection(db_config, mock_database_connection)

        with pytest.raises(DatabaseError) as exc_info:
            conn.execute_query("SELECT * FROM invalid_table")

        # Checked after the block: statements placed after the raising call
        # inside ``pytest.raises`` would never execute.
        assert "Database error" in str(exc_info.value)

    def test_commit_transaction(self, db_config, mock_database_connection):
        """Test transaction commit."""
        conn = DatabaseConnection(db_config, mock_database_connection)

        conn.commit()

        mock_database_connection.commit.assert_called_once()

    def test_rollback_transaction(self, db_config, mock_database_connection):
        """Test transaction rollback."""
        conn = DatabaseConnection(db_config, mock_database_connection)

        conn.rollback()

        mock_database_connection.rollback.assert_called_once()

    def test_close_connection(self, db_config, mock_database_connection):
        """Test connection closing."""
        conn = DatabaseConnection(db_config, mock_database_connection)

        conn.close()

        assert conn.is_connected is False
        mock_database_connection.close.assert_called_once()
class TestTransactionManager:
    """Check that TransactionManager commits on clean exit and rolls back on error."""

    @pytest.fixture
    def wrapped_connection(self, mock_database_connection):
        """DatabaseConnection around the mocked driver connection."""
        settings = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )
        return DatabaseConnection(settings, mock_database_connection)

    def test_transaction_context_success(self, wrapped_connection, mock_database_connection):
        """A block that finishes normally is committed once and never rolled back."""
        tx = TransactionManager(wrapped_connection)

        with tx:
            # Two writes grouped into one transaction
            for person in ("John", "Jane"):
                wrapped_connection.execute_query(
                    "INSERT INTO users (name) VALUES (%s)", [person]
                )

        mock_database_connection.commit.assert_called_once()
        mock_database_connection.rollback.assert_not_called()

    def test_transaction_context_error(self, wrapped_connection, mock_database_connection):
        """An exception inside the block triggers a rollback and propagates."""
        tx = TransactionManager(wrapped_connection)

        with pytest.raises(DatabaseError), tx:
            wrapped_connection.execute_query(
                "INSERT INTO users (name) VALUES (%s)", ["John"]
            )
            # Simulated failure part-way through the transaction
            raise DatabaseError("Something went wrong")

        mock_database_connection.rollback.assert_called_once()
        mock_database_connection.commit.assert_not_called()
class TestDatabaseManager:
    """Test main database manager functionality."""

    @staticmethod
    def _make_config(**overrides):
        """Return the baseline local DatabaseConfig used by these tests.

        Keyword ``overrides`` replace or extend the baseline parameters
        (e.g. ``schema="gas_station_1"`` or ``host="nonexistent-host"``).
        """
        params = {
            "host": "localhost",
            "database": "test_db",
            "username": "test_user",
            "password": "test_pass",
        }
        params.update(overrides)
        return DatabaseConfig(**params)

    @classmethod
    def _make_connected_manager(cls, raw_connection, **overrides):
        """Return a DatabaseManager whose connection wraps ``raw_connection``.

        ``connect()`` is bypassed by assigning ``manager.connection`` directly,
        so tests exercise query paths against the mocked driver connection.
        """
        config = cls._make_config(**overrides)
        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, raw_connection)
        return manager

    def test_initialization(self):
        """Test database manager initialization."""
        config = self._make_config(schema="gas_station_1")

        manager = DatabaseManager(config)

        assert manager.config == config
        assert isinstance(manager.query_builder, QueryBuilder)
        assert manager.query_builder.schema == "gas_station_1"
        assert manager.connection is None

    @pytest.mark.asyncio
    async def test_connect_success(self):
        """Test successful database connection."""
        manager = DatabaseManager(self._make_config())

        with patch('psycopg2.connect') as mock_connect:
            mock_connect.return_value = Mock()

            await manager.connect()

            assert manager.connection is not None
            assert manager.is_connected is True
            mock_connect.assert_called_once()

    @pytest.mark.asyncio
    async def test_connect_failure(self):
        """Test database connection failure."""
        manager = DatabaseManager(self._make_config(host="nonexistent-host"))

        with patch('psycopg2.connect') as mock_connect:
            mock_connect.side_effect = psycopg2.Error("Connection failed")

            with pytest.raises(DatabaseError) as exc_info:
                await manager.connect()

        # Checked after ``pytest.raises``: code after the raising call inside
        # that block would never run.
        assert "Connection failed" in str(exc_info.value)
        assert manager.is_connected is False

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """Test database disconnection."""
        mock_connection = Mock()
        manager = self._make_connected_manager(mock_connection)

        await manager.disconnect()

        assert manager.connection is None
        mock_connection.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_query(self, mock_database_connection):
        """Test query execution through manager."""
        manager = self._make_connected_manager(mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchall.return_value = [(1, "Test"), (2, "Data")]

        result = await manager.execute_query("SELECT * FROM test_table")

        assert result == [(1, "Test"), (2, "Data")]
        mock_cursor.execute.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_query_not_connected(self):
        """Test query execution when not connected."""
        manager = DatabaseManager(self._make_config())

        with pytest.raises(DatabaseError) as exc_info:
            await manager.execute_query("SELECT * FROM test_table")

        assert "not connected" in str(exc_info.value).lower()

    @pytest.mark.asyncio
    async def test_insert_record(self, mock_database_connection):
        """Test inserting a record."""
        manager = self._make_connected_manager(
            mock_database_connection, schema="gas_station_1"
        )

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchone.return_value = (123,)  # Returned ID

        data = {
            "session_id": "session_123",
            "camera_id": "camera_001",
            "detection_class": "car",
            "confidence": 0.95,
            "created_at": "NOW()"
        }

        record_id = await manager.insert_record("car_detections", data)

        assert record_id == 123
        mock_cursor.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_update_record(self, mock_database_connection):
        """Test updating a record."""
        manager = self._make_connected_manager(
            mock_database_connection, schema="gas_station_1"
        )

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 1

        data = {
            "car_brand": "Toyota",
            "car_body_type": "Sedan",
            "updated_at": "NOW()"
        }
        where_conditions = {"session_id": "session_123"}

        rows_affected = await manager.update_record("car_info", data, where_conditions)

        assert rows_affected == 1
        mock_cursor.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_records(self, mock_database_connection):
        """Test deleting records."""
        manager = self._make_connected_manager(mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 3

        where_conditions = {
            "created_at": "< NOW() - INTERVAL '30 days'",
            "processed": True
        }

        rows_deleted = await manager.delete_records("old_detections", where_conditions)

        assert rows_deleted == 3
        mock_cursor.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_table(self, mock_database_connection):
        """Test creating a table."""
        manager = self._make_connected_manager(
            mock_database_connection, schema="gas_station_1"
        )

        columns = {
            "id": "SERIAL PRIMARY KEY",
            "session_id": "VARCHAR(255) UNIQUE NOT NULL",
            "camera_id": "VARCHAR(255) NOT NULL",
            "detection_data": "JSON",
            "created_at": "TIMESTAMP DEFAULT NOW()"
        }

        await manager.create_table("test_detections", columns)

        mock_database_connection.cursor.return_value.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_table_exists(self, mock_database_connection):
        """Test checking if table exists."""
        manager = self._make_connected_manager(
            mock_database_connection, schema="gas_station_1"
        )

        # Mock cursor behavior - table exists
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchone.return_value = (1,)

        exists = await manager.table_exists("car_detections")

        assert exists is True
        mock_cursor.execute.assert_called_once()

        # Mock cursor behavior - table doesn't exist
        mock_cursor.fetchone.return_value = None

        exists = await manager.table_exists("nonexistent_table")

        assert exists is False
        # The second lookup must also hit the database (no stale cached answer).
        assert mock_cursor.execute.call_count == 2

    @pytest.mark.asyncio
    async def test_transaction_context(self, mock_database_connection):
        """Test transaction context manager."""
        manager = self._make_connected_manager(mock_database_connection)

        async with manager.transaction():
            await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
            await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])

        # Should commit on successful completion
        mock_database_connection.commit.assert_called()

    @pytest.mark.asyncio
    async def test_get_table_schema(self, mock_database_connection):
        """Test getting table schema information."""
        manager = self._make_connected_manager(
            mock_database_connection, schema="gas_station_1"
        )

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchall.return_value = [
            ("id", "integer", "NOT NULL"),
            ("session_id", "character varying", "NOT NULL"),
            ("created_at", "timestamp without time zone", "DEFAULT now()")
        ]

        schema = await manager.get_table_schema("car_detections")

        assert len(schema) == 3
        assert schema[0] == ("id", "integer", "NOT NULL")
        assert schema[1] == ("session_id", "character varying", "NOT NULL")
        assert schema[2] == ("created_at", "timestamp without time zone", "DEFAULT now()")

    @pytest.mark.asyncio
    async def test_bulk_insert(self, mock_database_connection):
        """Test bulk insert operation."""
        manager = self._make_connected_manager(mock_database_connection)

        records = [
            {"name": "John", "email": "john@example.com"},
            {"name": "Jane", "email": "jane@example.com"},
            {"name": "Bob", "email": "bob@example.com"}
        ]

        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 3

        rows_inserted = await manager.bulk_insert("users", records)

        assert rows_inserted == 3
        mock_cursor.executemany.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    def test_get_connection_stats(self, mock_database_connection):
        """Test getting connection statistics.

        ``get_connection_stats()`` is called without ``await``, so it is
        synchronous and this does not need to be an asyncio test.
        """
        manager = self._make_connected_manager(mock_database_connection)

        stats = manager.get_connection_stats()

        assert "connected" in stats
        assert "host" in stats
        assert "database" in stats
        assert "schema" in stats
        assert stats["connected"] is True
        assert stats["host"] == "localhost"
        assert stats["database"] == "test_db"
class TestDatabaseManagerIntegration:
    """Integration tests for database manager."""

    @staticmethod
    def _make_config():
        """Return the baseline local DatabaseConfig used by the simpler tests."""
        return DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

    @pytest.mark.asyncio
    async def test_complete_car_detection_workflow(self, mock_database_connection):
        """Test complete car detection database workflow (insert, update, query)."""
        config = DatabaseConfig(
            host="localhost",
            database="gas_station_db",
            username="detector_user",
            password="detector_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behaviors for different operations
        mock_cursor = mock_database_connection.cursor.return_value

        # 1. Create initial detection record
        mock_cursor.fetchone.return_value = (456,)  # Returned ID

        detection_data = {
            "session_id": str(uuid.uuid4()),
            "camera_id": "camera_001",
            "display_id": "display_001",
            "detection_class": "car",
            "confidence": 0.92,
            "bbox_x1": 100,
            "bbox_y1": 200,
            "bbox_x2": 300,
            "bbox_y2": 400,
            "track_id": 1001,
            "created_at": "NOW()"
        }

        detection_id = await manager.insert_record("car_detections", detection_data)
        assert detection_id == 456

        # 2. Update with classification results
        mock_cursor.rowcount = 1

        classification_data = {
            "car_brand": "Toyota",
            "car_model": "Camry",
            "car_body_type": "Sedan",
            "car_color": "Blue",
            "brand_confidence": 0.87,
            "bodytype_confidence": 0.82,
            "color_confidence": 0.79,
            "updated_at": "NOW()"
        }

        where_conditions = {"session_id": detection_data["session_id"]}
        rows_updated = await manager.update_record("car_detections", classification_data, where_conditions)
        assert rows_updated == 1

        # 3. Query final results
        mock_cursor.fetchall.return_value = [
            (456, detection_data["session_id"], "camera_001", "car", 0.92, "Toyota", "Sedan")
        ]

        results = await manager.execute_query(
            "SELECT id, session_id, camera_id, detection_class, confidence, car_brand, car_body_type "
            "FROM gas_station_1.car_detections WHERE session_id = %s",
            [detection_data["session_id"]]
        )

        assert len(results) == 1
        assert results[0][0] == 456  # ID
        assert results[0][3] == "car"  # detection_class
        assert results[0][5] == "Toyota"  # car_brand

        # Three statements executed; only the insert and update commit.
        assert mock_cursor.execute.call_count == 3
        assert mock_database_connection.commit.call_count == 2

    @pytest.mark.asyncio
    async def test_error_handling_and_recovery(self, mock_database_connection):
        """Test that a failure inside a transaction rolls it back."""
        config = self._make_config()
        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Test transaction rollback on error
        mock_cursor = mock_database_connection.cursor.return_value

        with pytest.raises(DatabaseError):
            async with manager.transaction():
                # First operation succeeds
                await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])

                # Second operation fails at the driver level
                mock_cursor.execute.side_effect = psycopg2.Error("Constraint violation")
                await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])

        # Both statements reached the cursor (the failing call is still recorded).
        assert mock_cursor.execute.call_count == 2
        # Should have rolled back
        mock_database_connection.rollback.assert_called_once()
        mock_database_connection.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_connection_recovery(self):
        """Test that connect() can be retried after a failed attempt.

        The manager is not expected to reconnect on its own here: the test
        calls connect() twice, the first attempt failing and the second
        succeeding.
        """
        manager = DatabaseManager(self._make_config())

        with patch('psycopg2.connect') as mock_connect:
            # First connection attempt fails, second succeeds
            mock_connect.side_effect = [
                psycopg2.Error("Connection refused"),
                Mock()
            ]

            # First attempt should fail and leave the manager disconnected
            with pytest.raises(DatabaseError):
                await manager.connect()
            assert manager.is_connected is False

            # Second attempt should succeed
            await manager.connect()

            assert manager.is_connected is True
            assert mock_connect.call_count == 2