""" Unit tests for database management functionality. """ import pytest import asyncio from unittest.mock import Mock, MagicMock, patch, AsyncMock from datetime import datetime, timedelta import psycopg2 import uuid from detector_worker.storage.database_manager import ( DatabaseManager, DatabaseConfig, DatabaseConnection, QueryBuilder, TransactionManager, DatabaseError, ConnectionPoolError ) from detector_worker.core.exceptions import ConfigurationError class TestDatabaseConfig: """Test database configuration.""" def test_creation_minimal(self): """Test creating database config with minimal parameters.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) assert config.host == "localhost" assert config.port == 5432 # Default port assert config.database == "test_db" assert config.username == "test_user" assert config.password == "test_pass" assert config.schema == "public" # Default schema assert config.enabled is True def test_creation_full(self): """Test creating database config with all parameters.""" config = DatabaseConfig( host="db.example.com", port=5433, database="production_db", username="prod_user", password="secure_pass", schema="gas_station_1", enabled=True, pool_min_conn=2, pool_max_conn=20, pool_timeout=30.0, connection_timeout=10.0, ssl_mode="require" ) assert config.host == "db.example.com" assert config.port == 5433 assert config.database == "production_db" assert config.schema == "gas_station_1" assert config.pool_min_conn == 2 assert config.pool_max_conn == 20 assert config.ssl_mode == "require" def test_get_connection_string(self): """Test generating connection string.""" config = DatabaseConfig( host="localhost", port=5432, database="test_db", username="test_user", password="test_pass" ) conn_string = config.get_connection_string() expected = "host=localhost port=5432 database=test_db user=test_user password=test_pass" assert conn_string == expected def test_get_connection_string_with_ssl(self): """Test generating connection string with SSL.""" config = DatabaseConfig( host="db.example.com", database="secure_db", username="user", password="pass", ssl_mode="require" ) conn_string = config.get_connection_string() assert "sslmode=require" in conn_string def test_from_dict(self): """Test creating config from dictionary.""" config_dict = { "host": "test-host", "port": 5433, "database": "test-db", "username": "test-user", "password": "test-pass", "schema": "test_schema", "pool_max_conn": 15, "unknown_field": "ignored" } config = DatabaseConfig.from_dict(config_dict) assert config.host == "test-host" assert config.port == 5433 assert config.database == "test-db" assert config.schema == "test_schema" assert config.pool_max_conn == 15 class TestQueryBuilder: """Test SQL query building functionality.""" def test_build_select_query(self): """Test building SELECT queries.""" builder = QueryBuilder("test_schema") query, params = builder.build_select_query( table="users", columns=["id", "name", "email"], where={"status": "active", "age": 25}, order_by="created_at DESC", limit=10 ) expected_query = ( "SELECT id, name, email FROM test_schema.users " "WHERE status = %s AND age = %s " "ORDER BY created_at DESC LIMIT 10" ) assert query == expected_query assert params == ["active", 25] def test_build_select_all_columns(self): """Test building SELECT * query.""" builder = QueryBuilder("public") query, params = builder.build_select_query("products") expected_query = "SELECT * FROM public.products" assert query == expected_query assert params == [] def test_build_insert_query(self): """Test building INSERT queries.""" builder = QueryBuilder("inventory") data = { "product_name": "Widget", "price": 19.99, "quantity": 100, "created_at": "NOW()" } query, params = builder.build_insert_query("products", data) expected_query = ( "INSERT INTO inventory.products (product_name, price, quantity, created_at) " "VALUES (%s, %s, %s, NOW()) RETURNING id" ) assert query == expected_query assert params == ["Widget", 19.99, 100] def test_build_update_query(self): """Test building UPDATE queries.""" builder = QueryBuilder("sales") data = { "status": "shipped", "shipped_date": "NOW()", "tracking_number": "ABC123" } where_conditions = {"order_id": 12345} query, params = builder.build_update_query("orders", data, where_conditions) expected_query = ( "UPDATE sales.orders SET status = %s, shipped_date = NOW(), tracking_number = %s " "WHERE order_id = %s" ) assert query == expected_query assert params == ["shipped", "ABC123", 12345] def test_build_delete_query(self): """Test building DELETE queries.""" builder = QueryBuilder("logs") where_conditions = { "level": "DEBUG", "created_at": "< NOW() - INTERVAL '7 days'" } query, params = builder.build_delete_query("application_logs", where_conditions) expected_query = ( "DELETE FROM logs.application_logs " "WHERE level = %s AND created_at < NOW() - INTERVAL '7 days'" ) assert query == expected_query assert params == ["DEBUG"] def test_build_create_table_query(self): """Test building CREATE TABLE queries.""" builder = QueryBuilder("gas_station_1") columns = { "id": "SERIAL PRIMARY KEY", "session_id": "VARCHAR(255) UNIQUE NOT NULL", "camera_id": "VARCHAR(255) NOT NULL", "detection_class": "VARCHAR(100)", "confidence": "DECIMAL(4,3)", "bbox_data": "JSON", "created_at": "TIMESTAMP DEFAULT NOW()", "updated_at": "TIMESTAMP DEFAULT NOW()" } query = builder.build_create_table_query("detections", columns) expected_parts = [ "CREATE TABLE IF NOT EXISTS gas_station_1.detections", "id SERIAL PRIMARY KEY", "session_id VARCHAR(255) UNIQUE NOT NULL", "camera_id VARCHAR(255) NOT NULL", "bbox_data JSON", "created_at TIMESTAMP DEFAULT NOW()" ] for part in expected_parts: assert part in query def test_escape_identifier(self): """Test SQL identifier escaping.""" builder = QueryBuilder("test") assert builder.escape_identifier("table") == '"table"' assert builder.escape_identifier("column_name") == '"column_name"' assert builder.escape_identifier("user-table") == '"user-table"' def test_format_value_for_sql(self): """Test SQL value formatting.""" builder = QueryBuilder("test") # Regular values should use placeholder assert builder.format_value_for_sql("string") == ("%s", "string") assert builder.format_value_for_sql(42) == ("%s", 42) assert builder.format_value_for_sql(3.14) == ("%s", 3.14) # SQL functions should be literal assert builder.format_value_for_sql("NOW()") == ("NOW()", None) assert builder.format_value_for_sql("CURRENT_TIMESTAMP") == ("CURRENT_TIMESTAMP", None) assert builder.format_value_for_sql("UUID()") == ("UUID()", None) class TestDatabaseConnection: """Test database connection management.""" def test_creation(self, mock_database_connection): """Test connection creation.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) conn = DatabaseConnection(config, mock_database_connection) assert conn.config == config assert conn.connection == mock_database_connection assert conn.is_connected is True def test_execute_query(self, mock_database_connection): """Test query execution.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) # Mock cursor behavior mock_cursor = mock_database_connection.cursor.return_value mock_cursor.fetchall.return_value = [ (1, "John", "john@example.com"), (2, "Jane", "jane@example.com") ] mock_cursor.rowcount = 2 conn = DatabaseConnection(config, mock_database_connection) query = "SELECT id, name, email FROM users WHERE status = %s" params = ["active"] result = conn.execute_query(query, params) assert result == [ (1, "John", "john@example.com"), (2, "Jane", "jane@example.com") ] mock_cursor.execute.assert_called_once_with(query, params) mock_cursor.fetchall.assert_called_once() def test_execute_query_single_result(self, mock_database_connection): """Test query execution with single result.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) mock_cursor = mock_database_connection.cursor.return_value mock_cursor.fetchone.return_value = (1, "John", "john@example.com") conn = DatabaseConnection(config, mock_database_connection) result = conn.execute_query("SELECT * FROM users WHERE id = %s", [1], fetch_one=True) assert result == (1, "John", "john@example.com") mock_cursor.fetchone.assert_called_once() def test_execute_query_no_fetch(self, mock_database_connection): """Test query execution without fetching results.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) mock_cursor = mock_database_connection.cursor.return_value mock_cursor.rowcount = 1 conn = DatabaseConnection(config, mock_database_connection) result = conn.execute_query( "INSERT INTO users (name) VALUES (%s)", ["John"], fetch_results=False ) assert result == 1 # Row count mock_cursor.execute.assert_called_once() mock_cursor.fetchall.assert_not_called() mock_cursor.fetchone.assert_not_called() def test_execute_query_error(self, mock_database_connection): """Test query execution error handling.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) mock_cursor = mock_database_connection.cursor.return_value mock_cursor.execute.side_effect = psycopg2.Error("Database error") conn = DatabaseConnection(config, mock_database_connection) with pytest.raises(DatabaseError) as exc_info: conn.execute_query("SELECT * FROM invalid_table") assert "Database error" in str(exc_info.value) def test_commit_transaction(self, mock_database_connection): """Test transaction commit.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) conn = DatabaseConnection(config, mock_database_connection) conn.commit() mock_database_connection.commit.assert_called_once() def test_rollback_transaction(self, mock_database_connection): """Test transaction rollback.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) conn = DatabaseConnection(config, mock_database_connection) conn.rollback() mock_database_connection.rollback.assert_called_once() def test_close_connection(self, mock_database_connection): """Test connection closing.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) conn = DatabaseConnection(config, mock_database_connection) conn.close() assert conn.is_connected is False mock_database_connection.close.assert_called_once() class TestTransactionManager: """Test transaction management.""" def test_transaction_context_success(self, mock_database_connection): """Test successful transaction context.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) conn = DatabaseConnection(config, mock_database_connection) tx_manager = TransactionManager(conn) with tx_manager: # Simulate some database operations conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"]) conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"]) # Should commit on successful exit mock_database_connection.commit.assert_called_once() mock_database_connection.rollback.assert_not_called() def test_transaction_context_error(self, mock_database_connection): """Test transaction context with error.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) conn = DatabaseConnection(config, mock_database_connection) tx_manager = TransactionManager(conn) with pytest.raises(DatabaseError): with tx_manager: conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"]) # Simulate an error raise DatabaseError("Something went wrong") # Should rollback on error mock_database_connection.rollback.assert_called_once() mock_database_connection.commit.assert_not_called() class TestDatabaseManager: """Test main database manager functionality.""" def test_initialization(self): """Test database manager initialization.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass", schema="gas_station_1" ) manager = DatabaseManager(config) assert manager.config == config assert isinstance(manager.query_builder, QueryBuilder) assert manager.query_builder.schema == "gas_station_1" assert manager.connection is None @pytest.mark.asyncio async def test_connect_success(self): """Test successful database connection.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) manager = DatabaseManager(config) with patch('psycopg2.connect') as mock_connect: mock_connection = Mock() mock_connect.return_value = mock_connection await manager.connect() assert manager.connection is not None assert manager.is_connected is True mock_connect.assert_called_once() @pytest.mark.asyncio async def test_connect_failure(self): """Test database connection failure.""" config = DatabaseConfig( host="nonexistent-host", database="test_db", username="test_user", password="test_pass" ) manager = DatabaseManager(config) with patch('psycopg2.connect') as mock_connect: mock_connect.side_effect = psycopg2.Error("Connection failed") with pytest.raises(DatabaseError) as exc_info: await manager.connect() assert "Connection failed" in str(exc_info.value) assert manager.is_connected is False @pytest.mark.asyncio async def test_disconnect(self): """Test database disconnection.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) manager = DatabaseManager(config) # Mock connection mock_connection = Mock() manager.connection = DatabaseConnection(config, mock_connection) await manager.disconnect() assert manager.connection is None mock_connection.close.assert_called_once() @pytest.mark.asyncio async def test_execute_query(self, mock_database_connection): """Test query execution through manager.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) manager = DatabaseManager(config) manager.connection = DatabaseConnection(config, mock_database_connection) # Mock cursor behavior mock_cursor = mock_database_connection.cursor.return_value mock_cursor.fetchall.return_value = [(1, "Test"), (2, "Data")] result = await manager.execute_query("SELECT * FROM test_table") assert result == [(1, "Test"), (2, "Data")] mock_cursor.execute.assert_called_once() @pytest.mark.asyncio async def test_execute_query_not_connected(self): """Test query execution when not connected.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) manager = DatabaseManager(config) with pytest.raises(DatabaseError) as exc_info: await manager.execute_query("SELECT * FROM test_table") assert "not connected" in str(exc_info.value).lower() @pytest.mark.asyncio async def test_insert_record(self, mock_database_connection): """Test inserting a record.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass", schema="gas_station_1" ) manager = DatabaseManager(config) manager.connection = DatabaseConnection(config, mock_database_connection) # Mock cursor behavior mock_cursor = mock_database_connection.cursor.return_value mock_cursor.fetchone.return_value = (123,) # Returned ID data = { "session_id": "session_123", "camera_id": "camera_001", "detection_class": "car", "confidence": 0.95, "created_at": "NOW()" } record_id = await manager.insert_record("car_detections", data) assert record_id == 123 mock_cursor.execute.assert_called_once() mock_database_connection.commit.assert_called_once() @pytest.mark.asyncio async def test_update_record(self, mock_database_connection): """Test updating a record.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass", schema="gas_station_1" ) manager = DatabaseManager(config) manager.connection = DatabaseConnection(config, mock_database_connection) # Mock cursor behavior mock_cursor = mock_database_connection.cursor.return_value mock_cursor.rowcount = 1 data = { "car_brand": "Toyota", "car_body_type": "Sedan", "updated_at": "NOW()" } where_conditions = {"session_id": "session_123"} rows_affected = await manager.update_record("car_info", data, where_conditions) assert rows_affected == 1 mock_cursor.execute.assert_called_once() mock_database_connection.commit.assert_called_once() @pytest.mark.asyncio async def test_delete_records(self, mock_database_connection): """Test deleting records.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) manager = DatabaseManager(config) manager.connection = DatabaseConnection(config, mock_database_connection) # Mock cursor behavior mock_cursor = mock_database_connection.cursor.return_value mock_cursor.rowcount = 3 where_conditions = { "created_at": "< NOW() - INTERVAL '30 days'", "processed": True } rows_deleted = await manager.delete_records("old_detections", where_conditions) assert rows_deleted == 3 mock_cursor.execute.assert_called_once() mock_database_connection.commit.assert_called_once() @pytest.mark.asyncio async def test_create_table(self, mock_database_connection): """Test creating a table.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass", schema="gas_station_1" ) manager = DatabaseManager(config) manager.connection = DatabaseConnection(config, mock_database_connection) columns = { "id": "SERIAL PRIMARY KEY", "session_id": "VARCHAR(255) UNIQUE NOT NULL", "camera_id": "VARCHAR(255) NOT NULL", "detection_data": "JSON", "created_at": "TIMESTAMP DEFAULT NOW()" } await manager.create_table("test_detections", columns) mock_database_connection.cursor.return_value.execute.assert_called_once() mock_database_connection.commit.assert_called_once() @pytest.mark.asyncio async def test_table_exists(self, mock_database_connection): """Test checking if table exists.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass", schema="gas_station_1" ) manager = DatabaseManager(config) manager.connection = DatabaseConnection(config, mock_database_connection) # Mock cursor behavior - table exists mock_cursor = mock_database_connection.cursor.return_value mock_cursor.fetchone.return_value = (1,) exists = await manager.table_exists("car_detections") assert exists is True mock_cursor.execute.assert_called_once() # Mock cursor behavior - table doesn't exist mock_cursor.fetchone.return_value = None exists = await manager.table_exists("nonexistent_table") assert exists is False @pytest.mark.asyncio async def test_transaction_context(self, mock_database_connection): """Test transaction context manager.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) manager = DatabaseManager(config) manager.connection = DatabaseConnection(config, mock_database_connection) async with manager.transaction(): await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"]) await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"]) # Should commit on successful completion mock_database_connection.commit.assert_called() @pytest.mark.asyncio async def test_get_table_schema(self, mock_database_connection): """Test getting table schema information.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass", schema="gas_station_1" ) manager = DatabaseManager(config) manager.connection = DatabaseConnection(config, mock_database_connection) # Mock cursor behavior mock_cursor = mock_database_connection.cursor.return_value mock_cursor.fetchall.return_value = [ ("id", "integer", "NOT NULL"), ("session_id", "character varying", "NOT NULL"), ("created_at", "timestamp without time zone", "DEFAULT now()") ] schema = await manager.get_table_schema("car_detections") assert len(schema) == 3 assert schema[0] == ("id", "integer", "NOT NULL") assert schema[1] == ("session_id", "character varying", "NOT NULL") @pytest.mark.asyncio async def test_bulk_insert(self, mock_database_connection): """Test bulk insert operation.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) manager = DatabaseManager(config) manager.connection = DatabaseConnection(config, mock_database_connection) records = [ {"name": "John", "email": "john@example.com"}, {"name": "Jane", "email": "jane@example.com"}, {"name": "Bob", "email": "bob@example.com"} ] mock_cursor = mock_database_connection.cursor.return_value mock_cursor.rowcount = 3 rows_inserted = await manager.bulk_insert("users", records) assert rows_inserted == 3 mock_cursor.executemany.assert_called_once() mock_database_connection.commit.assert_called_once() @pytest.mark.asyncio async def test_get_connection_stats(self, mock_database_connection): """Test getting connection statistics.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) manager = DatabaseManager(config) manager.connection = DatabaseConnection(config, mock_database_connection) stats = manager.get_connection_stats() assert "connected" in stats assert "host" in stats assert "database" in stats assert "schema" in stats assert stats["connected"] is True assert stats["host"] == "localhost" assert stats["database"] == "test_db" class TestDatabaseManagerIntegration: """Integration tests for database manager.""" @pytest.mark.asyncio async def test_complete_car_detection_workflow(self, mock_database_connection): """Test complete car detection database workflow.""" config = DatabaseConfig( host="localhost", database="gas_station_db", username="detector_user", password="detector_pass", schema="gas_station_1" ) manager = DatabaseManager(config) manager.connection = DatabaseConnection(config, mock_database_connection) # Mock cursor behaviors for different operations mock_cursor = mock_database_connection.cursor.return_value # 1. Create initial detection record mock_cursor.fetchone.return_value = (456,) # Returned ID detection_data = { "session_id": str(uuid.uuid4()), "camera_id": "camera_001", "display_id": "display_001", "detection_class": "car", "confidence": 0.92, "bbox_x1": 100, "bbox_y1": 200, "bbox_x2": 300, "bbox_y2": 400, "track_id": 1001, "created_at": "NOW()" } detection_id = await manager.insert_record("car_detections", detection_data) assert detection_id == 456 # 2. Update with classification results mock_cursor.rowcount = 1 classification_data = { "car_brand": "Toyota", "car_model": "Camry", "car_body_type": "Sedan", "car_color": "Blue", "brand_confidence": 0.87, "bodytype_confidence": 0.82, "color_confidence": 0.79, "updated_at": "NOW()" } where_conditions = {"session_id": detection_data["session_id"]} rows_updated = await manager.update_record("car_detections", classification_data, where_conditions) assert rows_updated == 1 # 3. Query final results mock_cursor.fetchall.return_value = [ (456, detection_data["session_id"], "camera_001", "car", 0.92, "Toyota", "Sedan") ] results = await manager.execute_query( "SELECT id, session_id, camera_id, detection_class, confidence, car_brand, car_body_type " "FROM gas_station_1.car_detections WHERE session_id = %s", [detection_data["session_id"]] ) assert len(results) == 1 assert results[0][0] == 456 # ID assert results[0][3] == "car" # detection_class assert results[0][5] == "Toyota" # car_brand # Verify all database operations were called assert mock_cursor.execute.call_count == 3 assert mock_database_connection.commit.call_count == 2 @pytest.mark.asyncio async def test_error_handling_and_recovery(self, mock_database_connection): """Test error handling and recovery scenarios.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) manager = DatabaseManager(config) manager.connection = DatabaseConnection(config, mock_database_connection) # Test transaction rollback on error mock_cursor = mock_database_connection.cursor.return_value with pytest.raises(DatabaseError): async with manager.transaction(): # First operation succeeds await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"]) # Second operation fails mock_cursor.execute.side_effect = psycopg2.Error("Constraint violation") await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"]) # Should have rolled back mock_database_connection.rollback.assert_called_once() mock_database_connection.commit.assert_not_called() @pytest.mark.asyncio async def test_connection_recovery(self): """Test automatic connection recovery.""" config = DatabaseConfig( host="localhost", database="test_db", username="test_user", password="test_pass" ) manager = DatabaseManager(config) with patch('psycopg2.connect') as mock_connect: # First connection attempt fails mock_connect.side_effect = [ psycopg2.Error("Connection refused"), Mock() # Second attempt succeeds ] # First attempt should fail with pytest.raises(DatabaseError): await manager.connect() # Second attempt should succeed await manager.connect() assert manager.is_connected is True