Refactor: PHASE 8: Testing & Integration
This commit is contained in:
parent
af34f4fd08
commit
9e8c6804a7
32 changed files with 17128 additions and 0 deletions
976
tests/unit/storage/test_database_manager.py
Normal file
976
tests/unit/storage/test_database_manager.py
Normal file
|
@ -0,0 +1,976 @@
|
|||
"""
|
||||
Unit tests for database management functionality.
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, MagicMock, patch, AsyncMock
|
||||
from datetime import datetime, timedelta
|
||||
import psycopg2
|
||||
import uuid
|
||||
|
||||
from detector_worker.storage.database_manager import (
|
||||
DatabaseManager,
|
||||
DatabaseConfig,
|
||||
DatabaseConnection,
|
||||
QueryBuilder,
|
||||
TransactionManager,
|
||||
DatabaseError,
|
||||
ConnectionPoolError
|
||||
)
|
||||
from detector_worker.core.exceptions import ConfigurationError
|
||||
|
||||
|
||||
class TestDatabaseConfig:
|
||||
"""Test database configuration."""
|
||||
|
||||
def test_creation_minimal(self):
|
||||
"""Test creating database config with minimal parameters."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
assert config.host == "localhost"
|
||||
assert config.port == 5432 # Default port
|
||||
assert config.database == "test_db"
|
||||
assert config.username == "test_user"
|
||||
assert config.password == "test_pass"
|
||||
assert config.schema == "public" # Default schema
|
||||
assert config.enabled is True
|
||||
|
||||
def test_creation_full(self):
|
||||
"""Test creating database config with all parameters."""
|
||||
config = DatabaseConfig(
|
||||
host="db.example.com",
|
||||
port=5433,
|
||||
database="production_db",
|
||||
username="prod_user",
|
||||
password="secure_pass",
|
||||
schema="gas_station_1",
|
||||
enabled=True,
|
||||
pool_min_conn=2,
|
||||
pool_max_conn=20,
|
||||
pool_timeout=30.0,
|
||||
connection_timeout=10.0,
|
||||
ssl_mode="require"
|
||||
)
|
||||
|
||||
assert config.host == "db.example.com"
|
||||
assert config.port == 5433
|
||||
assert config.database == "production_db"
|
||||
assert config.schema == "gas_station_1"
|
||||
assert config.pool_min_conn == 2
|
||||
assert config.pool_max_conn == 20
|
||||
assert config.ssl_mode == "require"
|
||||
|
||||
def test_get_connection_string(self):
|
||||
"""Test generating connection string."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
conn_string = config.get_connection_string()
|
||||
|
||||
expected = "host=localhost port=5432 database=test_db user=test_user password=test_pass"
|
||||
assert conn_string == expected
|
||||
|
||||
def test_get_connection_string_with_ssl(self):
|
||||
"""Test generating connection string with SSL."""
|
||||
config = DatabaseConfig(
|
||||
host="db.example.com",
|
||||
database="secure_db",
|
||||
username="user",
|
||||
password="pass",
|
||||
ssl_mode="require"
|
||||
)
|
||||
|
||||
conn_string = config.get_connection_string()
|
||||
|
||||
assert "sslmode=require" in conn_string
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test creating config from dictionary."""
|
||||
config_dict = {
|
||||
"host": "test-host",
|
||||
"port": 5433,
|
||||
"database": "test-db",
|
||||
"username": "test-user",
|
||||
"password": "test-pass",
|
||||
"schema": "test_schema",
|
||||
"pool_max_conn": 15,
|
||||
"unknown_field": "ignored"
|
||||
}
|
||||
|
||||
config = DatabaseConfig.from_dict(config_dict)
|
||||
|
||||
assert config.host == "test-host"
|
||||
assert config.port == 5433
|
||||
assert config.database == "test-db"
|
||||
assert config.schema == "test_schema"
|
||||
assert config.pool_max_conn == 15
|
||||
|
||||
|
||||
class TestQueryBuilder:
|
||||
"""Test SQL query building functionality."""
|
||||
|
||||
def test_build_select_query(self):
|
||||
"""Test building SELECT queries."""
|
||||
builder = QueryBuilder("test_schema")
|
||||
|
||||
query, params = builder.build_select_query(
|
||||
table="users",
|
||||
columns=["id", "name", "email"],
|
||||
where={"status": "active", "age": 25},
|
||||
order_by="created_at DESC",
|
||||
limit=10
|
||||
)
|
||||
|
||||
expected_query = (
|
||||
"SELECT id, name, email FROM test_schema.users "
|
||||
"WHERE status = %s AND age = %s "
|
||||
"ORDER BY created_at DESC LIMIT 10"
|
||||
)
|
||||
|
||||
assert query == expected_query
|
||||
assert params == ["active", 25]
|
||||
|
||||
def test_build_select_all_columns(self):
|
||||
"""Test building SELECT * query."""
|
||||
builder = QueryBuilder("public")
|
||||
|
||||
query, params = builder.build_select_query("products")
|
||||
|
||||
expected_query = "SELECT * FROM public.products"
|
||||
assert query == expected_query
|
||||
assert params == []
|
||||
|
||||
def test_build_insert_query(self):
|
||||
"""Test building INSERT queries."""
|
||||
builder = QueryBuilder("inventory")
|
||||
|
||||
data = {
|
||||
"product_name": "Widget",
|
||||
"price": 19.99,
|
||||
"quantity": 100,
|
||||
"created_at": "NOW()"
|
||||
}
|
||||
|
||||
query, params = builder.build_insert_query("products", data)
|
||||
|
||||
expected_query = (
|
||||
"INSERT INTO inventory.products (product_name, price, quantity, created_at) "
|
||||
"VALUES (%s, %s, %s, NOW()) RETURNING id"
|
||||
)
|
||||
|
||||
assert query == expected_query
|
||||
assert params == ["Widget", 19.99, 100]
|
||||
|
||||
def test_build_update_query(self):
|
||||
"""Test building UPDATE queries."""
|
||||
builder = QueryBuilder("sales")
|
||||
|
||||
data = {
|
||||
"status": "shipped",
|
||||
"shipped_date": "NOW()",
|
||||
"tracking_number": "ABC123"
|
||||
}
|
||||
|
||||
where_conditions = {"order_id": 12345}
|
||||
|
||||
query, params = builder.build_update_query("orders", data, where_conditions)
|
||||
|
||||
expected_query = (
|
||||
"UPDATE sales.orders SET status = %s, shipped_date = NOW(), tracking_number = %s "
|
||||
"WHERE order_id = %s"
|
||||
)
|
||||
|
||||
assert query == expected_query
|
||||
assert params == ["shipped", "ABC123", 12345]
|
||||
|
||||
def test_build_delete_query(self):
|
||||
"""Test building DELETE queries."""
|
||||
builder = QueryBuilder("logs")
|
||||
|
||||
where_conditions = {
|
||||
"level": "DEBUG",
|
||||
"created_at": "< NOW() - INTERVAL '7 days'"
|
||||
}
|
||||
|
||||
query, params = builder.build_delete_query("application_logs", where_conditions)
|
||||
|
||||
expected_query = (
|
||||
"DELETE FROM logs.application_logs "
|
||||
"WHERE level = %s AND created_at < NOW() - INTERVAL '7 days'"
|
||||
)
|
||||
|
||||
assert query == expected_query
|
||||
assert params == ["DEBUG"]
|
||||
|
||||
def test_build_create_table_query(self):
|
||||
"""Test building CREATE TABLE queries."""
|
||||
builder = QueryBuilder("gas_station_1")
|
||||
|
||||
columns = {
|
||||
"id": "SERIAL PRIMARY KEY",
|
||||
"session_id": "VARCHAR(255) UNIQUE NOT NULL",
|
||||
"camera_id": "VARCHAR(255) NOT NULL",
|
||||
"detection_class": "VARCHAR(100)",
|
||||
"confidence": "DECIMAL(4,3)",
|
||||
"bbox_data": "JSON",
|
||||
"created_at": "TIMESTAMP DEFAULT NOW()",
|
||||
"updated_at": "TIMESTAMP DEFAULT NOW()"
|
||||
}
|
||||
|
||||
query = builder.build_create_table_query("detections", columns)
|
||||
|
||||
expected_parts = [
|
||||
"CREATE TABLE IF NOT EXISTS gas_station_1.detections",
|
||||
"id SERIAL PRIMARY KEY",
|
||||
"session_id VARCHAR(255) UNIQUE NOT NULL",
|
||||
"camera_id VARCHAR(255) NOT NULL",
|
||||
"bbox_data JSON",
|
||||
"created_at TIMESTAMP DEFAULT NOW()"
|
||||
]
|
||||
|
||||
for part in expected_parts:
|
||||
assert part in query
|
||||
|
||||
def test_escape_identifier(self):
|
||||
"""Test SQL identifier escaping."""
|
||||
builder = QueryBuilder("test")
|
||||
|
||||
assert builder.escape_identifier("table") == '"table"'
|
||||
assert builder.escape_identifier("column_name") == '"column_name"'
|
||||
assert builder.escape_identifier("user-table") == '"user-table"'
|
||||
|
||||
def test_format_value_for_sql(self):
|
||||
"""Test SQL value formatting."""
|
||||
builder = QueryBuilder("test")
|
||||
|
||||
# Regular values should use placeholder
|
||||
assert builder.format_value_for_sql("string") == ("%s", "string")
|
||||
assert builder.format_value_for_sql(42) == ("%s", 42)
|
||||
assert builder.format_value_for_sql(3.14) == ("%s", 3.14)
|
||||
|
||||
# SQL functions should be literal
|
||||
assert builder.format_value_for_sql("NOW()") == ("NOW()", None)
|
||||
assert builder.format_value_for_sql("CURRENT_TIMESTAMP") == ("CURRENT_TIMESTAMP", None)
|
||||
assert builder.format_value_for_sql("UUID()") == ("UUID()", None)
|
||||
|
||||
|
||||
class TestDatabaseConnection:
|
||||
"""Test database connection management."""
|
||||
|
||||
def test_creation(self, mock_database_connection):
|
||||
"""Test connection creation."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
conn = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
assert conn.config == config
|
||||
assert conn.connection == mock_database_connection
|
||||
assert conn.is_connected is True
|
||||
|
||||
def test_execute_query(self, mock_database_connection):
|
||||
"""Test query execution."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
# Mock cursor behavior
|
||||
mock_cursor = mock_database_connection.cursor.return_value
|
||||
mock_cursor.fetchall.return_value = [
|
||||
(1, "John", "john@example.com"),
|
||||
(2, "Jane", "jane@example.com")
|
||||
]
|
||||
mock_cursor.rowcount = 2
|
||||
|
||||
conn = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
query = "SELECT id, name, email FROM users WHERE status = %s"
|
||||
params = ["active"]
|
||||
|
||||
result = conn.execute_query(query, params)
|
||||
|
||||
assert result == [
|
||||
(1, "John", "john@example.com"),
|
||||
(2, "Jane", "jane@example.com")
|
||||
]
|
||||
|
||||
mock_cursor.execute.assert_called_once_with(query, params)
|
||||
mock_cursor.fetchall.assert_called_once()
|
||||
|
||||
def test_execute_query_single_result(self, mock_database_connection):
|
||||
"""Test query execution with single result."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
mock_cursor = mock_database_connection.cursor.return_value
|
||||
mock_cursor.fetchone.return_value = (1, "John", "john@example.com")
|
||||
|
||||
conn = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
result = conn.execute_query("SELECT * FROM users WHERE id = %s", [1], fetch_one=True)
|
||||
|
||||
assert result == (1, "John", "john@example.com")
|
||||
mock_cursor.fetchone.assert_called_once()
|
||||
|
||||
def test_execute_query_no_fetch(self, mock_database_connection):
|
||||
"""Test query execution without fetching results."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
mock_cursor = mock_database_connection.cursor.return_value
|
||||
mock_cursor.rowcount = 1
|
||||
|
||||
conn = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
result = conn.execute_query(
|
||||
"INSERT INTO users (name) VALUES (%s)",
|
||||
["John"],
|
||||
fetch_results=False
|
||||
)
|
||||
|
||||
assert result == 1 # Row count
|
||||
mock_cursor.execute.assert_called_once()
|
||||
mock_cursor.fetchall.assert_not_called()
|
||||
mock_cursor.fetchone.assert_not_called()
|
||||
|
||||
def test_execute_query_error(self, mock_database_connection):
|
||||
"""Test query execution error handling."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
mock_cursor = mock_database_connection.cursor.return_value
|
||||
mock_cursor.execute.side_effect = psycopg2.Error("Database error")
|
||||
|
||||
conn = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
with pytest.raises(DatabaseError) as exc_info:
|
||||
conn.execute_query("SELECT * FROM invalid_table")
|
||||
|
||||
assert "Database error" in str(exc_info.value)
|
||||
|
||||
def test_commit_transaction(self, mock_database_connection):
|
||||
"""Test transaction commit."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
conn = DatabaseConnection(config, mock_database_connection)
|
||||
conn.commit()
|
||||
|
||||
mock_database_connection.commit.assert_called_once()
|
||||
|
||||
def test_rollback_transaction(self, mock_database_connection):
|
||||
"""Test transaction rollback."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
conn = DatabaseConnection(config, mock_database_connection)
|
||||
conn.rollback()
|
||||
|
||||
mock_database_connection.rollback.assert_called_once()
|
||||
|
||||
def test_close_connection(self, mock_database_connection):
|
||||
"""Test connection closing."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
conn = DatabaseConnection(config, mock_database_connection)
|
||||
conn.close()
|
||||
|
||||
assert conn.is_connected is False
|
||||
mock_database_connection.close.assert_called_once()
|
||||
|
||||
|
||||
class TestTransactionManager:
|
||||
"""Test transaction management."""
|
||||
|
||||
def test_transaction_context_success(self, mock_database_connection):
|
||||
"""Test successful transaction context."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
conn = DatabaseConnection(config, mock_database_connection)
|
||||
tx_manager = TransactionManager(conn)
|
||||
|
||||
with tx_manager:
|
||||
# Simulate some database operations
|
||||
conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
|
||||
conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])
|
||||
|
||||
# Should commit on successful exit
|
||||
mock_database_connection.commit.assert_called_once()
|
||||
mock_database_connection.rollback.assert_not_called()
|
||||
|
||||
def test_transaction_context_error(self, mock_database_connection):
|
||||
"""Test transaction context with error."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
conn = DatabaseConnection(config, mock_database_connection)
|
||||
tx_manager = TransactionManager(conn)
|
||||
|
||||
with pytest.raises(DatabaseError):
|
||||
with tx_manager:
|
||||
conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
|
||||
# Simulate an error
|
||||
raise DatabaseError("Something went wrong")
|
||||
|
||||
# Should rollback on error
|
||||
mock_database_connection.rollback.assert_called_once()
|
||||
mock_database_connection.commit.assert_not_called()
|
||||
|
||||
|
||||
class TestDatabaseManager:
|
||||
"""Test main database manager functionality."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test database manager initialization."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass",
|
||||
schema="gas_station_1"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
|
||||
assert manager.config == config
|
||||
assert isinstance(manager.query_builder, QueryBuilder)
|
||||
assert manager.query_builder.schema == "gas_station_1"
|
||||
assert manager.connection is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_success(self):
|
||||
"""Test successful database connection."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
|
||||
with patch('psycopg2.connect') as mock_connect:
|
||||
mock_connection = Mock()
|
||||
mock_connect.return_value = mock_connection
|
||||
|
||||
await manager.connect()
|
||||
|
||||
assert manager.connection is not None
|
||||
assert manager.is_connected is True
|
||||
mock_connect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_failure(self):
|
||||
"""Test database connection failure."""
|
||||
config = DatabaseConfig(
|
||||
host="nonexistent-host",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
|
||||
with patch('psycopg2.connect') as mock_connect:
|
||||
mock_connect.side_effect = psycopg2.Error("Connection failed")
|
||||
|
||||
with pytest.raises(DatabaseError) as exc_info:
|
||||
await manager.connect()
|
||||
|
||||
assert "Connection failed" in str(exc_info.value)
|
||||
assert manager.is_connected is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect(self):
|
||||
"""Test database disconnection."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
|
||||
# Mock connection
|
||||
mock_connection = Mock()
|
||||
manager.connection = DatabaseConnection(config, mock_connection)
|
||||
|
||||
await manager.disconnect()
|
||||
|
||||
assert manager.connection is None
|
||||
mock_connection.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query(self, mock_database_connection):
|
||||
"""Test query execution through manager."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
manager.connection = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
# Mock cursor behavior
|
||||
mock_cursor = mock_database_connection.cursor.return_value
|
||||
mock_cursor.fetchall.return_value = [(1, "Test"), (2, "Data")]
|
||||
|
||||
result = await manager.execute_query("SELECT * FROM test_table")
|
||||
|
||||
assert result == [(1, "Test"), (2, "Data")]
|
||||
mock_cursor.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query_not_connected(self):
|
||||
"""Test query execution when not connected."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
|
||||
with pytest.raises(DatabaseError) as exc_info:
|
||||
await manager.execute_query("SELECT * FROM test_table")
|
||||
|
||||
assert "not connected" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insert_record(self, mock_database_connection):
|
||||
"""Test inserting a record."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass",
|
||||
schema="gas_station_1"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
manager.connection = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
# Mock cursor behavior
|
||||
mock_cursor = mock_database_connection.cursor.return_value
|
||||
mock_cursor.fetchone.return_value = (123,) # Returned ID
|
||||
|
||||
data = {
|
||||
"session_id": "session_123",
|
||||
"camera_id": "camera_001",
|
||||
"detection_class": "car",
|
||||
"confidence": 0.95,
|
||||
"created_at": "NOW()"
|
||||
}
|
||||
|
||||
record_id = await manager.insert_record("car_detections", data)
|
||||
|
||||
assert record_id == 123
|
||||
mock_cursor.execute.assert_called_once()
|
||||
mock_database_connection.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_record(self, mock_database_connection):
|
||||
"""Test updating a record."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass",
|
||||
schema="gas_station_1"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
manager.connection = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
# Mock cursor behavior
|
||||
mock_cursor = mock_database_connection.cursor.return_value
|
||||
mock_cursor.rowcount = 1
|
||||
|
||||
data = {
|
||||
"car_brand": "Toyota",
|
||||
"car_body_type": "Sedan",
|
||||
"updated_at": "NOW()"
|
||||
}
|
||||
|
||||
where_conditions = {"session_id": "session_123"}
|
||||
|
||||
rows_affected = await manager.update_record("car_info", data, where_conditions)
|
||||
|
||||
assert rows_affected == 1
|
||||
mock_cursor.execute.assert_called_once()
|
||||
mock_database_connection.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_records(self, mock_database_connection):
|
||||
"""Test deleting records."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
manager.connection = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
# Mock cursor behavior
|
||||
mock_cursor = mock_database_connection.cursor.return_value
|
||||
mock_cursor.rowcount = 3
|
||||
|
||||
where_conditions = {
|
||||
"created_at": "< NOW() - INTERVAL '30 days'",
|
||||
"processed": True
|
||||
}
|
||||
|
||||
rows_deleted = await manager.delete_records("old_detections", where_conditions)
|
||||
|
||||
assert rows_deleted == 3
|
||||
mock_cursor.execute.assert_called_once()
|
||||
mock_database_connection.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_table(self, mock_database_connection):
|
||||
"""Test creating a table."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass",
|
||||
schema="gas_station_1"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
manager.connection = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
columns = {
|
||||
"id": "SERIAL PRIMARY KEY",
|
||||
"session_id": "VARCHAR(255) UNIQUE NOT NULL",
|
||||
"camera_id": "VARCHAR(255) NOT NULL",
|
||||
"detection_data": "JSON",
|
||||
"created_at": "TIMESTAMP DEFAULT NOW()"
|
||||
}
|
||||
|
||||
await manager.create_table("test_detections", columns)
|
||||
|
||||
mock_database_connection.cursor.return_value.execute.assert_called_once()
|
||||
mock_database_connection.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_exists(self, mock_database_connection):
|
||||
"""Test checking if table exists."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass",
|
||||
schema="gas_station_1"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
manager.connection = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
# Mock cursor behavior - table exists
|
||||
mock_cursor = mock_database_connection.cursor.return_value
|
||||
mock_cursor.fetchone.return_value = (1,)
|
||||
|
||||
exists = await manager.table_exists("car_detections")
|
||||
|
||||
assert exists is True
|
||||
mock_cursor.execute.assert_called_once()
|
||||
|
||||
# Mock cursor behavior - table doesn't exist
|
||||
mock_cursor.fetchone.return_value = None
|
||||
|
||||
exists = await manager.table_exists("nonexistent_table")
|
||||
|
||||
assert exists is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transaction_context(self, mock_database_connection):
|
||||
"""Test transaction context manager."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
manager.connection = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
async with manager.transaction():
|
||||
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
|
||||
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])
|
||||
|
||||
# Should commit on successful completion
|
||||
mock_database_connection.commit.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_schema(self, mock_database_connection):
|
||||
"""Test getting table schema information."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass",
|
||||
schema="gas_station_1"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
manager.connection = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
# Mock cursor behavior
|
||||
mock_cursor = mock_database_connection.cursor.return_value
|
||||
mock_cursor.fetchall.return_value = [
|
||||
("id", "integer", "NOT NULL"),
|
||||
("session_id", "character varying", "NOT NULL"),
|
||||
("created_at", "timestamp without time zone", "DEFAULT now()")
|
||||
]
|
||||
|
||||
schema = await manager.get_table_schema("car_detections")
|
||||
|
||||
assert len(schema) == 3
|
||||
assert schema[0] == ("id", "integer", "NOT NULL")
|
||||
assert schema[1] == ("session_id", "character varying", "NOT NULL")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_insert(self, mock_database_connection):
|
||||
"""Test bulk insert operation."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
manager.connection = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
records = [
|
||||
{"name": "John", "email": "john@example.com"},
|
||||
{"name": "Jane", "email": "jane@example.com"},
|
||||
{"name": "Bob", "email": "bob@example.com"}
|
||||
]
|
||||
|
||||
mock_cursor = mock_database_connection.cursor.return_value
|
||||
mock_cursor.rowcount = 3
|
||||
|
||||
rows_inserted = await manager.bulk_insert("users", records)
|
||||
|
||||
assert rows_inserted == 3
|
||||
mock_cursor.executemany.assert_called_once()
|
||||
mock_database_connection.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_connection_stats(self, mock_database_connection):
|
||||
"""Test getting connection statistics."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
manager.connection = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
stats = manager.get_connection_stats()
|
||||
|
||||
assert "connected" in stats
|
||||
assert "host" in stats
|
||||
assert "database" in stats
|
||||
assert "schema" in stats
|
||||
assert stats["connected"] is True
|
||||
assert stats["host"] == "localhost"
|
||||
assert stats["database"] == "test_db"
|
||||
|
||||
|
||||
class TestDatabaseManagerIntegration:
|
||||
"""Integration tests for database manager."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_car_detection_workflow(self, mock_database_connection):
|
||||
"""Test complete car detection database workflow."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="gas_station_db",
|
||||
username="detector_user",
|
||||
password="detector_pass",
|
||||
schema="gas_station_1"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
manager.connection = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
# Mock cursor behaviors for different operations
|
||||
mock_cursor = mock_database_connection.cursor.return_value
|
||||
|
||||
# 1. Create initial detection record
|
||||
mock_cursor.fetchone.return_value = (456,) # Returned ID
|
||||
|
||||
detection_data = {
|
||||
"session_id": str(uuid.uuid4()),
|
||||
"camera_id": "camera_001",
|
||||
"display_id": "display_001",
|
||||
"detection_class": "car",
|
||||
"confidence": 0.92,
|
||||
"bbox_x1": 100,
|
||||
"bbox_y1": 200,
|
||||
"bbox_x2": 300,
|
||||
"bbox_y2": 400,
|
||||
"track_id": 1001,
|
||||
"created_at": "NOW()"
|
||||
}
|
||||
|
||||
detection_id = await manager.insert_record("car_detections", detection_data)
|
||||
assert detection_id == 456
|
||||
|
||||
# 2. Update with classification results
|
||||
mock_cursor.rowcount = 1
|
||||
|
||||
classification_data = {
|
||||
"car_brand": "Toyota",
|
||||
"car_model": "Camry",
|
||||
"car_body_type": "Sedan",
|
||||
"car_color": "Blue",
|
||||
"brand_confidence": 0.87,
|
||||
"bodytype_confidence": 0.82,
|
||||
"color_confidence": 0.79,
|
||||
"updated_at": "NOW()"
|
||||
}
|
||||
|
||||
where_conditions = {"session_id": detection_data["session_id"]}
|
||||
|
||||
rows_updated = await manager.update_record("car_detections", classification_data, where_conditions)
|
||||
assert rows_updated == 1
|
||||
|
||||
# 3. Query final results
|
||||
mock_cursor.fetchall.return_value = [
|
||||
(456, detection_data["session_id"], "camera_001", "car", 0.92, "Toyota", "Sedan")
|
||||
]
|
||||
|
||||
results = await manager.execute_query(
|
||||
"SELECT id, session_id, camera_id, detection_class, confidence, car_brand, car_body_type "
|
||||
"FROM gas_station_1.car_detections WHERE session_id = %s",
|
||||
[detection_data["session_id"]]
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0][0] == 456 # ID
|
||||
assert results[0][3] == "car" # detection_class
|
||||
assert results[0][5] == "Toyota" # car_brand
|
||||
|
||||
# Verify all database operations were called
|
||||
assert mock_cursor.execute.call_count == 3
|
||||
assert mock_database_connection.commit.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_and_recovery(self, mock_database_connection):
|
||||
"""Test error handling and recovery scenarios."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
manager.connection = DatabaseConnection(config, mock_database_connection)
|
||||
|
||||
# Test transaction rollback on error
|
||||
mock_cursor = mock_database_connection.cursor.return_value
|
||||
|
||||
with pytest.raises(DatabaseError):
|
||||
async with manager.transaction():
|
||||
# First operation succeeds
|
||||
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
|
||||
|
||||
# Second operation fails
|
||||
mock_cursor.execute.side_effect = psycopg2.Error("Constraint violation")
|
||||
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])
|
||||
|
||||
# Should have rolled back
|
||||
mock_database_connection.rollback.assert_called_once()
|
||||
mock_database_connection.commit.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_recovery(self):
|
||||
"""Test automatic connection recovery."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
manager = DatabaseManager(config)
|
||||
|
||||
with patch('psycopg2.connect') as mock_connect:
|
||||
# First connection attempt fails
|
||||
mock_connect.side_effect = [
|
||||
psycopg2.Error("Connection refused"),
|
||||
Mock() # Second attempt succeeds
|
||||
]
|
||||
|
||||
# First attempt should fail
|
||||
with pytest.raises(DatabaseError):
|
||||
await manager.connect()
|
||||
|
||||
# Second attempt should succeed
|
||||
await manager.connect()
|
||||
assert manager.is_connected is True
|
Loading…
Add table
Add a link
Reference in a new issue