"""
Unit tests for database management functionality.
"""

import asyncio
import uuid
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import psycopg2
import pytest

from detector_worker.core.exceptions import ConfigurationError
from detector_worker.storage.database_manager import (
    ConnectionPoolError,
    DatabaseConfig,
    DatabaseConnection,
    DatabaseError,
    DatabaseManager,
    QueryBuilder,
    TransactionManager,
)


class TestDatabaseConfig:
    """Test database configuration."""

    def test_creation_minimal(self):
        """Test creating database config with minimal parameters."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
        )

        assert config.host == "localhost"
        assert config.port == 5432  # Default port
        assert config.database == "test_db"
        assert config.username == "test_user"
        assert config.password == "test_pass"
        assert config.schema == "public"  # Default schema
        assert config.enabled is True

    def test_creation_full(self):
        """Test creating database config with all parameters.

        Every keyword passed to the constructor is asserted back, so a
        parameter that is silently dropped is caught.
        """
        config = DatabaseConfig(
            host="db.example.com",
            port=5433,
            database="production_db",
            username="prod_user",
            password="secure_pass",
            schema="gas_station_1",
            enabled=True,
            pool_min_conn=2,
            pool_max_conn=20,
            pool_timeout=30.0,
            connection_timeout=10.0,
            ssl_mode="require",
        )

        assert config.host == "db.example.com"
        assert config.port == 5433
        assert config.database == "production_db"
        assert config.username == "prod_user"
        assert config.password == "secure_pass"
        assert config.schema == "gas_station_1"
        assert config.enabled is True
        assert config.pool_min_conn == 2
        assert config.pool_max_conn == 20
        assert config.pool_timeout == 30.0
        assert config.connection_timeout == 10.0
        assert config.ssl_mode == "require"

    def test_get_connection_string(self):
        """Test generating connection string."""
        config = DatabaseConfig(
            host="localhost",
            port=5432,
            database="test_db",
            username="test_user",
            password="test_pass",
        )

        conn_string = config.get_connection_string()

        # NOTE(review): libpq's connection-string keyword for the database
        # name is `dbname`; `database` is only accepted by psycopg2.connect()
        # as a keyword-argument alias. If this string is handed to libpq as a
        # DSN, confirm the implementation emits a keyword libpq accepts.
        expected = "host=localhost port=5432 database=test_db user=test_user password=test_pass"
        assert conn_string == expected

    def test_get_connection_string_with_ssl(self):
        """Test generating connection string with SSL."""
        config = DatabaseConfig(
            host="db.example.com",
            database="secure_db",
            username="user",
            password="pass",
            ssl_mode="require",
        )

        conn_string = config.get_connection_string()

        assert "sslmode=require" in conn_string

    def test_from_dict(self):
        """Test creating config from a dictionary; unknown keys are ignored."""
        config_dict = {
            "host": "test-host",
            "port": 5433,
            "database": "test-db",
            "username": "test-user",
            "password": "test-pass",
            "schema": "test_schema",
            "pool_max_conn": 15,
            "unknown_field": "ignored",
        }

        config = DatabaseConfig.from_dict(config_dict)

        assert config.host == "test-host"
        assert config.port == 5433
        assert config.database == "test-db"
        assert config.username == "test-user"
        assert config.password == "test-pass"
        assert config.schema == "test_schema"
        assert config.pool_max_conn == 15
        # The extra key must be dropped rather than attached to the config.
        assert not hasattr(config, "unknown_field")


class TestQueryBuilder:
    """Test SQL query building functionality."""

    def test_build_select_query(self):
        """Test building SELECT queries."""
        builder = QueryBuilder("test_schema")

        query, params = builder.build_select_query(
            table="users",
            columns=["id", "name", "email"],
            where={"status": "active", "age": 25},
            order_by="created_at DESC",
            limit=10,
        )

        expected_query = (
            "SELECT id, name, email FROM test_schema.users "
            "WHERE status = %s AND age = %s "
            "ORDER BY created_at DESC LIMIT 10"
        )

        assert query == expected_query
        # Placeholders are bound in the insertion order of the `where` dict.
        assert params == ["active", 25]

    def test_build_select_all_columns(self):
        """Test building SELECT * query."""
        builder = QueryBuilder("public")

        query, params = builder.build_select_query("products")

        expected_query = "SELECT * FROM public.products"
        assert query == expected_query
        assert params == []

    def test_build_insert_query(self):
        """Test building INSERT queries; SQL functions are inlined, not bound."""
        builder = QueryBuilder("inventory")

        data = {
            "product_name": "Widget",
            "price": 19.99,
            "quantity": 100,
            "created_at": "NOW()",
        }

        query, params = builder.build_insert_query("products", data)

        expected_query = (
            "INSERT INTO inventory.products (product_name, price, quantity, created_at) "
            "VALUES (%s, %s, %s, NOW()) RETURNING id"
        )

        assert query == expected_query
        assert params == ["Widget", 19.99, 100]

    def test_build_update_query(self):
        """Test building UPDATE queries."""
        builder = QueryBuilder("sales")

        data = {
            "status": "shipped",
            "shipped_date": "NOW()",
            "tracking_number": "ABC123",
        }

        where_conditions = {"order_id": 12345}

        query, params = builder.build_update_query("orders", data, where_conditions)

        expected_query = (
            "UPDATE sales.orders SET status = %s, shipped_date = NOW(), tracking_number = %s "
            "WHERE order_id = %s"
        )

        assert query == expected_query
        # SET values come first, then WHERE values.
        assert params == ["shipped", "ABC123", 12345]

    def test_build_delete_query(self):
        """Test building DELETE queries."""
        builder = QueryBuilder("logs")

        # NOTE(review): a condition value beginning with a comparison operator
        # is spliced into the SQL verbatim (see expected_query below), so such
        # values must never originate from untrusted input.
        where_conditions = {
            "level": "DEBUG",
            "created_at": "< NOW() - INTERVAL '7 days'",
        }

        query, params = builder.build_delete_query("application_logs", where_conditions)

        expected_query = (
            "DELETE FROM logs.application_logs "
            "WHERE level = %s AND created_at < NOW() - INTERVAL '7 days'"
        )

        assert query == expected_query
        assert params == ["DEBUG"]

    def test_build_create_table_query(self):
        """Test building CREATE TABLE queries."""
        builder = QueryBuilder("gas_station_1")

        columns = {
            "id": "SERIAL PRIMARY KEY",
            "session_id": "VARCHAR(255) UNIQUE NOT NULL",
            "camera_id": "VARCHAR(255) NOT NULL",
            "detection_class": "VARCHAR(100)",
            "confidence": "DECIMAL(4,3)",
            "bbox_data": "JSON",
            "created_at": "TIMESTAMP DEFAULT NOW()",
            "updated_at": "TIMESTAMP DEFAULT NOW()",
        }

        query = builder.build_create_table_query("detections", columns)

        expected_parts = [
            "CREATE TABLE IF NOT EXISTS gas_station_1.detections",
            "id SERIAL PRIMARY KEY",
            "session_id VARCHAR(255) UNIQUE NOT NULL",
            "camera_id VARCHAR(255) NOT NULL",
            "bbox_data JSON",
            "created_at TIMESTAMP DEFAULT NOW()",
        ]

        # Check fragments individually so a failure names the missing part.
        for part in expected_parts:
            assert part in query

    def test_escape_identifier(self):
        """Test SQL identifier escaping."""
        builder = QueryBuilder("test")

        assert builder.escape_identifier("table") == '"table"'
        assert builder.escape_identifier("column_name") == '"column_name"'
        assert builder.escape_identifier("user-table") == '"user-table"'

    def test_format_value_for_sql(self):
        """Test SQL value formatting."""
        builder = QueryBuilder("test")

        # Regular values should use placeholder
        assert builder.format_value_for_sql("string") == ("%s", "string")
        assert builder.format_value_for_sql(42) == ("%s", 42)
        assert builder.format_value_for_sql(3.14) == ("%s", 3.14)

        # SQL functions should be literal
        assert builder.format_value_for_sql("NOW()") == ("NOW()", None)
        assert builder.format_value_for_sql("CURRENT_TIMESTAMP") == ("CURRENT_TIMESTAMP", None)
        assert builder.format_value_for_sql("UUID()") == ("UUID()", None)


class TestDatabaseConnection:
    """Test database connection management.

    `mock_database_connection` is a fixture defined outside this module
    (presumably conftest.py); its `cursor.return_value` is used as the cursor.
    """

    @staticmethod
    def _make_config():
        """Return the minimal local config shared by these tests."""
        return DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
        )

    def test_creation(self, mock_database_connection):
        """Test connection creation."""
        config = self._make_config()

        conn = DatabaseConnection(config, mock_database_connection)

        assert conn.config == config
        assert conn.connection == mock_database_connection
        assert conn.is_connected is True

    def test_execute_query(self, mock_database_connection):
        """Test query execution."""
        config = self._make_config()

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchall.return_value = [
            (1, "John", "john@example.com"),
            (2, "Jane", "jane@example.com"),
        ]
        mock_cursor.rowcount = 2

        conn = DatabaseConnection(config, mock_database_connection)

        query = "SELECT id, name, email FROM users WHERE status = %s"
        params = ["active"]

        result = conn.execute_query(query, params)

        assert result == [
            (1, "John", "john@example.com"),
            (2, "Jane", "jane@example.com"),
        ]

        mock_cursor.execute.assert_called_once_with(query, params)
        mock_cursor.fetchall.assert_called_once()

    def test_execute_query_single_result(self, mock_database_connection):
        """Test query execution with single result."""
        config = self._make_config()

        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchone.return_value = (1, "John", "john@example.com")

        conn = DatabaseConnection(config, mock_database_connection)

        result = conn.execute_query("SELECT * FROM users WHERE id = %s", [1], fetch_one=True)

        assert result == (1, "John", "john@example.com")
        mock_cursor.fetchone.assert_called_once()

    def test_execute_query_no_fetch(self, mock_database_connection):
        """Test query execution without fetching results."""
        config = self._make_config()

        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 1

        conn = DatabaseConnection(config, mock_database_connection)

        result = conn.execute_query(
            "INSERT INTO users (name) VALUES (%s)",
            ["John"],
            fetch_results=False,
        )

        assert result == 1  # Row count
        mock_cursor.execute.assert_called_once()
        mock_cursor.fetchall.assert_not_called()
        mock_cursor.fetchone.assert_not_called()

    def test_execute_query_error(self, mock_database_connection):
        """Test that driver errors are wrapped in DatabaseError."""
        config = self._make_config()

        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.execute.side_effect = psycopg2.Error("Database error")

        conn = DatabaseConnection(config, mock_database_connection)

        with pytest.raises(DatabaseError) as exc_info:
            conn.execute_query("SELECT * FROM invalid_table")

        assert "Database error" in str(exc_info.value)

    def test_commit_transaction(self, mock_database_connection):
        """Test transaction commit."""
        config = self._make_config()

        conn = DatabaseConnection(config, mock_database_connection)
        conn.commit()

        mock_database_connection.commit.assert_called_once()

    def test_rollback_transaction(self, mock_database_connection):
        """Test transaction rollback."""
        config = self._make_config()

        conn = DatabaseConnection(config, mock_database_connection)
        conn.rollback()

        mock_database_connection.rollback.assert_called_once()

    def test_close_connection(self, mock_database_connection):
        """Test connection closing."""
        config = self._make_config()

        conn = DatabaseConnection(config, mock_database_connection)
        conn.close()

        assert conn.is_connected is False
        mock_database_connection.close.assert_called_once()


class TestTransactionManager:
    """Test transaction management."""

    @staticmethod
    def _make_config():
        """Return the minimal local config shared by these tests."""
        return DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
        )

    def test_transaction_context_success(self, mock_database_connection):
        """Test that a clean exit from the context commits exactly once."""
        config = self._make_config()

        conn = DatabaseConnection(config, mock_database_connection)
        tx_manager = TransactionManager(conn)

        with tx_manager:
            # Simulate some database operations
            conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
            conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])

        # Should commit on successful exit
        mock_database_connection.commit.assert_called_once()
        mock_database_connection.rollback.assert_not_called()

    def test_transaction_context_error(self, mock_database_connection):
        """Test that an exception inside the context rolls back and propagates."""
        config = self._make_config()

        conn = DatabaseConnection(config, mock_database_connection)
        tx_manager = TransactionManager(conn)

        with pytest.raises(DatabaseError):
            with tx_manager:
                conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
                # Simulate an error
                raise DatabaseError("Something went wrong")

        # Should rollback on error
        mock_database_connection.rollback.assert_called_once()
        mock_database_connection.commit.assert_not_called()


class TestDatabaseManager:
    """Test main database manager functionality."""

    @staticmethod
    def _make_config(**overrides):
        """Build a local test config; keyword overrides replace the defaults."""
        params = {
            "host": "localhost",
            "database": "test_db",
            "username": "test_user",
            "password": "test_pass",
        }
        params.update(overrides)
        return DatabaseConfig(**params)

    @staticmethod
    def _connected_manager(config, raw_connection):
        """Return a DatabaseManager wired to *raw_connection* without calling connect()."""
        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, raw_connection)
        return manager

    def test_initialization(self):
        """Test database manager initialization."""
        config = self._make_config(schema="gas_station_1")

        manager = DatabaseManager(config)

        assert manager.config == config
        assert isinstance(manager.query_builder, QueryBuilder)
        assert manager.query_builder.schema == "gas_station_1"
        assert manager.connection is None

    @pytest.mark.asyncio
    async def test_connect_success(self):
        """Test successful database connection."""
        manager = DatabaseManager(self._make_config())

        # NOTE(review): this only intercepts the call if the implementation
        # looks up `psycopg2.connect` at call time (i.e. `import psycopg2`).
        with patch('psycopg2.connect') as mock_connect:
            mock_connection = Mock()
            mock_connect.return_value = mock_connection

            await manager.connect()

            assert manager.connection is not None
            assert manager.is_connected is True
            mock_connect.assert_called_once()

    @pytest.mark.asyncio
    async def test_connect_failure(self):
        """Test that a driver connection error surfaces as DatabaseError."""
        manager = DatabaseManager(self._make_config(host="nonexistent-host"))

        with patch('psycopg2.connect') as mock_connect:
            mock_connect.side_effect = psycopg2.Error("Connection failed")

            with pytest.raises(DatabaseError) as exc_info:
                await manager.connect()

            assert "Connection failed" in str(exc_info.value)
            assert manager.is_connected is False

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """Test database disconnection."""
        mock_connection = Mock()
        manager = self._connected_manager(self._make_config(), mock_connection)

        await manager.disconnect()

        assert manager.connection is None
        mock_connection.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_query(self, mock_database_connection):
        """Test query execution through manager."""
        manager = self._connected_manager(self._make_config(), mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchall.return_value = [(1, "Test"), (2, "Data")]

        result = await manager.execute_query("SELECT * FROM test_table")

        assert result == [(1, "Test"), (2, "Data")]
        mock_cursor.execute.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_query_not_connected(self):
        """Test query execution when not connected."""
        manager = DatabaseManager(self._make_config())

        with pytest.raises(DatabaseError) as exc_info:
            await manager.execute_query("SELECT * FROM test_table")

        assert "not connected" in str(exc_info.value).lower()

    @pytest.mark.asyncio
    async def test_insert_record(self, mock_database_connection):
        """Test inserting a record returns the new id and commits."""
        config = self._make_config(schema="gas_station_1")
        manager = self._connected_manager(config, mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchone.return_value = (123,)  # Returned ID

        data = {
            "session_id": "session_123",
            "camera_id": "camera_001",
            "detection_class": "car",
            "confidence": 0.95,
            "created_at": "NOW()",
        }

        record_id = await manager.insert_record("car_detections", data)

        assert record_id == 123
        mock_cursor.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_update_record(self, mock_database_connection):
        """Test updating a record returns the affected row count and commits."""
        config = self._make_config(schema="gas_station_1")
        manager = self._connected_manager(config, mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 1

        data = {
            "car_brand": "Toyota",
            "car_body_type": "Sedan",
            "updated_at": "NOW()",
        }

        where_conditions = {"session_id": "session_123"}

        rows_affected = await manager.update_record("car_info", data, where_conditions)

        assert rows_affected == 1
        mock_cursor.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_records(self, mock_database_connection):
        """Test deleting records returns the deleted row count and commits."""
        manager = self._connected_manager(self._make_config(), mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 3

        where_conditions = {
            "created_at": "< NOW() - INTERVAL '30 days'",
            "processed": True,
        }

        rows_deleted = await manager.delete_records("old_detections", where_conditions)

        assert rows_deleted == 3
        mock_cursor.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_table(self, mock_database_connection):
        """Test creating a table issues one statement and commits."""
        config = self._make_config(schema="gas_station_1")
        manager = self._connected_manager(config, mock_database_connection)

        columns = {
            "id": "SERIAL PRIMARY KEY",
            "session_id": "VARCHAR(255) UNIQUE NOT NULL",
            "camera_id": "VARCHAR(255) NOT NULL",
            "detection_data": "JSON",
            "created_at": "TIMESTAMP DEFAULT NOW()",
        }

        await manager.create_table("test_detections", columns)

        mock_database_connection.cursor.return_value.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_table_exists(self, mock_database_connection):
        """Test checking if table exists."""
        config = self._make_config(schema="gas_station_1")
        manager = self._connected_manager(config, mock_database_connection)

        # Mock cursor behavior - table exists
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchone.return_value = (1,)

        exists = await manager.table_exists("car_detections")

        assert exists is True
        mock_cursor.execute.assert_called_once()

        # Mock cursor behavior - table doesn't exist
        mock_cursor.fetchone.return_value = None

        exists = await manager.table_exists("nonexistent_table")

        assert exists is False
        # The second lookup must also have hit the database.
        assert mock_cursor.execute.call_count == 2

    @pytest.mark.asyncio
    async def test_transaction_context(self, mock_database_connection):
        """Test transaction context manager commits exactly once on success.

        Neither entering transaction() nor execute_query() commits on its own
        (see TestDatabaseManagerIntegration.test_error_handling_and_recovery),
        so a single commit is expected at exit.
        """
        manager = self._connected_manager(self._make_config(), mock_database_connection)

        async with manager.transaction():
            await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
            await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])

        # Should commit on successful completion
        mock_database_connection.commit.assert_called_once()
        mock_database_connection.rollback.assert_not_called()

    @pytest.mark.asyncio
    async def test_get_table_schema(self, mock_database_connection):
        """Test getting table schema information."""
        config = self._make_config(schema="gas_station_1")
        manager = self._connected_manager(config, mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchall.return_value = [
            ("id", "integer", "NOT NULL"),
            ("session_id", "character varying", "NOT NULL"),
            ("created_at", "timestamp without time zone", "DEFAULT now()"),
        ]

        schema = await manager.get_table_schema("car_detections")

        assert len(schema) == 3
        assert schema[0] == ("id", "integer", "NOT NULL")
        assert schema[1] == ("session_id", "character varying", "NOT NULL")

    @pytest.mark.asyncio
    async def test_bulk_insert(self, mock_database_connection):
        """Test bulk insert uses executemany and commits once."""
        manager = self._connected_manager(self._make_config(), mock_database_connection)

        records = [
            {"name": "John", "email": "john@example.com"},
            {"name": "Jane", "email": "jane@example.com"},
            {"name": "Bob", "email": "bob@example.com"},
        ]

        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 3

        rows_inserted = await manager.bulk_insert("users", records)

        assert rows_inserted == 3
        mock_cursor.executemany.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    def test_get_connection_stats(self, mock_database_connection):
        """Test getting connection statistics.

        get_connection_stats() is synchronous, so this is a plain test rather
        than an asyncio one.
        """
        manager = self._connected_manager(self._make_config(), mock_database_connection)

        stats = manager.get_connection_stats()

        assert "connected" in stats
        assert "host" in stats
        assert "database" in stats
        assert "schema" in stats
        assert stats["connected"] is True
        assert stats["host"] == "localhost"
        assert stats["database"] == "test_db"


class TestDatabaseManagerIntegration:
    """Integration tests for database manager."""

    @staticmethod
    def _make_config(**overrides):
        """Build a local test config; keyword overrides replace the defaults."""
        params = {
            "host": "localhost",
            "database": "test_db",
            "username": "test_user",
            "password": "test_pass",
        }
        params.update(overrides)
        return DatabaseConfig(**params)

    @pytest.mark.asyncio
    async def test_complete_car_detection_workflow(self, mock_database_connection):
        """Test complete car detection database workflow: insert, update, query."""
        config = self._make_config(
            database="gas_station_db",
            username="detector_user",
            password="detector_pass",
            schema="gas_station_1",
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behaviors for different operations
        mock_cursor = mock_database_connection.cursor.return_value

        # 1. Create initial detection record
        mock_cursor.fetchone.return_value = (456,)  # Returned ID

        detection_data = {
            "session_id": str(uuid.uuid4()),
            "camera_id": "camera_001",
            "display_id": "display_001",
            "detection_class": "car",
            "confidence": 0.92,
            "bbox_x1": 100,
            "bbox_y1": 200,
            "bbox_x2": 300,
            "bbox_y2": 400,
            "track_id": 1001,
            "created_at": "NOW()",
        }

        detection_id = await manager.insert_record("car_detections", detection_data)
        assert detection_id == 456

        # 2. Update with classification results
        mock_cursor.rowcount = 1

        classification_data = {
            "car_brand": "Toyota",
            "car_model": "Camry",
            "car_body_type": "Sedan",
            "car_color": "Blue",
            "brand_confidence": 0.87,
            "bodytype_confidence": 0.82,
            "color_confidence": 0.79,
            "updated_at": "NOW()",
        }

        where_conditions = {"session_id": detection_data["session_id"]}

        rows_updated = await manager.update_record("car_detections", classification_data, where_conditions)
        assert rows_updated == 1

        # 3. Query final results
        mock_cursor.fetchall.return_value = [
            (456, detection_data["session_id"], "camera_001", "car", 0.92, "Toyota", "Sedan")
        ]

        results = await manager.execute_query(
            "SELECT id, session_id, camera_id, detection_class, confidence, car_brand, car_body_type "
            "FROM gas_station_1.car_detections WHERE session_id = %s",
            [detection_data["session_id"]],
        )

        assert len(results) == 1
        assert results[0][0] == 456  # ID
        assert results[0][3] == "car"  # detection_class
        assert results[0][5] == "Toyota"  # car_brand

        # Three statements ran; only the insert and update commit (the SELECT does not).
        assert mock_cursor.execute.call_count == 3
        assert mock_database_connection.commit.call_count == 2

    @pytest.mark.asyncio
    async def test_error_handling_and_recovery(self, mock_database_connection):
        """Test that a failure inside a transaction rolls back without committing."""
        config = self._make_config()

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Test transaction rollback on error
        mock_cursor = mock_database_connection.cursor.return_value

        with pytest.raises(DatabaseError):
            async with manager.transaction():
                # First operation succeeds
                await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])

                # Second operation fails
                mock_cursor.execute.side_effect = psycopg2.Error("Constraint violation")
                await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])

        # Should have rolled back
        mock_database_connection.rollback.assert_called_once()
        mock_database_connection.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_connection_recovery(self):
        """Test that connect() can be retried successfully after a failed attempt."""
        manager = DatabaseManager(self._make_config())

        with patch('psycopg2.connect') as mock_connect:
            # First connection attempt fails
            mock_connect.side_effect = [
                psycopg2.Error("Connection refused"),
                Mock(),  # Second attempt succeeds
            ]

            # First attempt should fail
            with pytest.raises(DatabaseError):
                await manager.connect()

            assert manager.is_connected is False

            # Second attempt should succeed
            await manager.connect()

            assert manager.is_connected is True
            # One driver call per connect(); no hidden internal retries.
            assert mock_connect.call_count == 2