Refactor: PHASE 8: Testing & Integration
parent af34f4fd08 · commit 9e8c6804a7
32 changed files with 17128 additions and 0 deletions
976 tests/unit/storage/test_database_manager.py Normal file
@@ -0,0 +1,976 @@
"""
|
||||
Unit tests for database management functionality.
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, MagicMock, patch, AsyncMock
|
||||
from datetime import datetime, timedelta
|
||||
import psycopg2
|
||||
import uuid
|
||||
|
||||
from detector_worker.storage.database_manager import (
|
||||
DatabaseManager,
|
||||
DatabaseConfig,
|
||||
DatabaseConnection,
|
||||
QueryBuilder,
|
||||
TransactionManager,
|
||||
DatabaseError,
|
||||
ConnectionPoolError
|
||||
)
|
||||
from detector_worker.core.exceptions import ConfigurationError
|
||||
|
||||
|
||||
class TestDatabaseConfig:
|
||||
"""Test database configuration."""
|
||||
|
||||
def test_creation_minimal(self):
|
||||
"""Test creating database config with minimal parameters."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
assert config.host == "localhost"
|
||||
assert config.port == 5432 # Default port
|
||||
assert config.database == "test_db"
|
||||
assert config.username == "test_user"
|
||||
assert config.password == "test_pass"
|
||||
assert config.schema == "public" # Default schema
|
||||
assert config.enabled is True
|
||||
|
||||
def test_creation_full(self):
|
||||
"""Test creating database config with all parameters."""
|
||||
config = DatabaseConfig(
|
||||
host="db.example.com",
|
||||
port=5433,
|
||||
database="production_db",
|
||||
username="prod_user",
|
||||
password="secure_pass",
|
||||
schema="gas_station_1",
|
||||
enabled=True,
|
||||
pool_min_conn=2,
|
||||
pool_max_conn=20,
|
||||
pool_timeout=30.0,
|
||||
connection_timeout=10.0,
|
||||
ssl_mode="require"
|
||||
)
|
||||
|
||||
assert config.host == "db.example.com"
|
||||
assert config.port == 5433
|
||||
assert config.database == "production_db"
|
||||
assert config.schema == "gas_station_1"
|
||||
assert config.pool_min_conn == 2
|
||||
assert config.pool_max_conn == 20
|
||||
assert config.ssl_mode == "require"
|
||||
|
||||
def test_get_connection_string(self):
|
||||
"""Test generating connection string."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
conn_string = config.get_connection_string()
|
||||
|
||||
expected = "host=localhost port=5432 database=test_db user=test_user password=test_pass"
|
||||
assert conn_string == expected
|
||||
|
||||
def test_get_connection_string_with_ssl(self):
|
||||
"""Test generating connection string with SSL."""
|
||||
config = DatabaseConfig(
|
||||
host="db.example.com",
|
||||
database="secure_db",
|
||||
username="user",
|
||||
password="pass",
|
||||
ssl_mode="require"
|
||||
)
|
||||
|
||||
conn_string = config.get_connection_string()
|
||||
|
||||
assert "sslmode=require" in conn_string
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test creating config from dictionary."""
|
||||
config_dict = {
|
||||
"host": "test-host",
|
||||
"port": 5433,
|
||||
"database": "test-db",
|
||||
"username": "test-user",
|
||||
"password": "test-pass",
|
||||
"schema": "test_schema",
|
||||
"pool_max_conn": 15,
|
||||
"unknown_field": "ignored"
|
||||
}
|
||||
|
||||
config = DatabaseConfig.from_dict(config_dict)
|
||||
|
||||
assert config.host == "test-host"
|
||||
assert config.port == 5433
|
||||
assert config.database == "test-db"
|
||||
assert config.schema == "test_schema"
|
||||
assert config.pool_max_conn == 15
|
||||
|
||||
|
||||
class TestQueryBuilder:
    """Test SQL query building functionality."""

    def test_build_select_query(self):
        """Test building SELECT queries."""
        builder = QueryBuilder("test_schema")

        query, params = builder.build_select_query(
            table="users",
            columns=["id", "name", "email"],
            where={"status": "active", "age": 25},
            order_by="created_at DESC",
            limit=10
        )

        expected_query = (
            "SELECT id, name, email FROM test_schema.users "
            "WHERE status = %s AND age = %s "
            "ORDER BY created_at DESC LIMIT 10"
        )

        assert query == expected_query
        assert params == ["active", 25]

    def test_build_select_all_columns(self):
        """Test building SELECT * query."""
        builder = QueryBuilder("public")

        query, params = builder.build_select_query("products")

        expected_query = "SELECT * FROM public.products"
        assert query == expected_query
        assert params == []

    def test_build_insert_query(self):
        """Test building INSERT queries."""
        builder = QueryBuilder("inventory")

        data = {
            "product_name": "Widget",
            "price": 19.99,
            "quantity": 100,
            "created_at": "NOW()"
        }

        query, params = builder.build_insert_query("products", data)

        expected_query = (
            "INSERT INTO inventory.products (product_name, price, quantity, created_at) "
            "VALUES (%s, %s, %s, NOW()) RETURNING id"
        )

        assert query == expected_query
        assert params == ["Widget", 19.99, 100]

    def test_build_update_query(self):
        """Test building UPDATE queries."""
        builder = QueryBuilder("sales")

        data = {
            "status": "shipped",
            "shipped_date": "NOW()",
            "tracking_number": "ABC123"
        }

        where_conditions = {"order_id": 12345}

        query, params = builder.build_update_query("orders", data, where_conditions)

        expected_query = (
            "UPDATE sales.orders SET status = %s, shipped_date = NOW(), tracking_number = %s "
            "WHERE order_id = %s"
        )

        assert query == expected_query
        assert params == ["shipped", "ABC123", 12345]

    def test_build_delete_query(self):
        """Test building DELETE queries."""
        builder = QueryBuilder("logs")

        where_conditions = {
            "level": "DEBUG",
            "created_at": "< NOW() - INTERVAL '7 days'"
        }

        query, params = builder.build_delete_query("application_logs", where_conditions)

        expected_query = (
            "DELETE FROM logs.application_logs "
            "WHERE level = %s AND created_at < NOW() - INTERVAL '7 days'"
        )

        assert query == expected_query
        assert params == ["DEBUG"]

    def test_build_create_table_query(self):
        """Test building CREATE TABLE queries."""
        builder = QueryBuilder("gas_station_1")

        columns = {
            "id": "SERIAL PRIMARY KEY",
            "session_id": "VARCHAR(255) UNIQUE NOT NULL",
            "camera_id": "VARCHAR(255) NOT NULL",
            "detection_class": "VARCHAR(100)",
            "confidence": "DECIMAL(4,3)",
            "bbox_data": "JSON",
            "created_at": "TIMESTAMP DEFAULT NOW()",
            "updated_at": "TIMESTAMP DEFAULT NOW()"
        }

        query = builder.build_create_table_query("detections", columns)

        expected_parts = [
            "CREATE TABLE IF NOT EXISTS gas_station_1.detections",
            "id SERIAL PRIMARY KEY",
            "session_id VARCHAR(255) UNIQUE NOT NULL",
            "camera_id VARCHAR(255) NOT NULL",
            "bbox_data JSON",
            "created_at TIMESTAMP DEFAULT NOW()"
        ]

        for part in expected_parts:
            assert part in query

    def test_escape_identifier(self):
        """Test SQL identifier escaping."""
        builder = QueryBuilder("test")

        assert builder.escape_identifier("table") == '"table"'
        assert builder.escape_identifier("column_name") == '"column_name"'
        assert builder.escape_identifier("user-table") == '"user-table"'

    def test_format_value_for_sql(self):
        """Test SQL value formatting."""
        builder = QueryBuilder("test")

        # Regular values should use placeholder
        assert builder.format_value_for_sql("string") == ("%s", "string")
        assert builder.format_value_for_sql(42) == ("%s", 42)
        assert builder.format_value_for_sql(3.14) == ("%s", 3.14)

        # SQL functions should be literal
        assert builder.format_value_for_sql("NOW()") == ("NOW()", None)
        assert builder.format_value_for_sql("CURRENT_TIMESTAMP") == ("CURRENT_TIMESTAMP", None)
        assert builder.format_value_for_sql("UUID()") == ("UUID()", None)


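# Illustrative sketch (hypothetical, not the detector_worker implementation):
# the assertions above pin down the SELECT-building contract -- schema-qualified
# table names, %s placeholders for WHERE values, optional ORDER BY and LIMIT.
# A minimal standalone helper consistent with those tests could look like this;
# the name is invented for illustration only.
def _sketch_build_select_query(schema, table, columns=None, where=None,
                               order_by=None, limit=None):
    """Return (query, params) in the shape the tests above expect."""
    cols = ", ".join(columns) if columns else "*"
    query = f"SELECT {cols} FROM {schema}.{table}"
    params = []
    if where:
        clauses = []
        for key, value in where.items():
            clauses.append(f"{key} = %s")   # one placeholder per condition
            params.append(value)
        query += " WHERE " + " AND ".join(clauses)
    if order_by:
        query += f" ORDER BY {order_by}"
    if limit is not None:
        query += f" LIMIT {limit}"
    return query, params

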
class TestDatabaseConnection:
    """Test database connection management."""

    def test_creation(self, mock_database_connection):
        """Test connection creation."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        conn = DatabaseConnection(config, mock_database_connection)

        assert conn.config == config
        assert conn.connection == mock_database_connection
        assert conn.is_connected is True

    def test_execute_query(self, mock_database_connection):
        """Test query execution."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchall.return_value = [
            (1, "John", "john@example.com"),
            (2, "Jane", "jane@example.com")
        ]
        mock_cursor.rowcount = 2

        conn = DatabaseConnection(config, mock_database_connection)

        query = "SELECT id, name, email FROM users WHERE status = %s"
        params = ["active"]

        result = conn.execute_query(query, params)

        assert result == [
            (1, "John", "john@example.com"),
            (2, "Jane", "jane@example.com")
        ]

        mock_cursor.execute.assert_called_once_with(query, params)
        mock_cursor.fetchall.assert_called_once()

    def test_execute_query_single_result(self, mock_database_connection):
        """Test query execution with single result."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchone.return_value = (1, "John", "john@example.com")

        conn = DatabaseConnection(config, mock_database_connection)

        result = conn.execute_query("SELECT * FROM users WHERE id = %s", [1], fetch_one=True)

        assert result == (1, "John", "john@example.com")
        mock_cursor.fetchone.assert_called_once()

    def test_execute_query_no_fetch(self, mock_database_connection):
        """Test query execution without fetching results."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 1

        conn = DatabaseConnection(config, mock_database_connection)

        result = conn.execute_query(
            "INSERT INTO users (name) VALUES (%s)",
            ["John"],
            fetch_results=False
        )

        assert result == 1  # Row count
        mock_cursor.execute.assert_called_once()
        mock_cursor.fetchall.assert_not_called()
        mock_cursor.fetchone.assert_not_called()

    def test_execute_query_error(self, mock_database_connection):
        """Test query execution error handling."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.execute.side_effect = psycopg2.Error("Database error")

        conn = DatabaseConnection(config, mock_database_connection)

        with pytest.raises(DatabaseError) as exc_info:
            conn.execute_query("SELECT * FROM invalid_table")

        assert "Database error" in str(exc_info.value)

    def test_commit_transaction(self, mock_database_connection):
        """Test transaction commit."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        conn = DatabaseConnection(config, mock_database_connection)
        conn.commit()

        mock_database_connection.commit.assert_called_once()

    def test_rollback_transaction(self, mock_database_connection):
        """Test transaction rollback."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        conn = DatabaseConnection(config, mock_database_connection)
        conn.rollback()

        mock_database_connection.rollback.assert_called_once()

    def test_close_connection(self, mock_database_connection):
        """Test connection closing."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        conn = DatabaseConnection(config, mock_database_connection)
        conn.close()

        assert conn.is_connected is False
        mock_database_connection.close.assert_called_once()


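# Illustrative sketch (hypothetical, not the detector_worker implementation):
# the DatabaseConnection tests assume execute_query() returns fetchall() by
# default, fetchone() when fetch_one=True, the cursor rowcount when
# fetch_results=False, and wraps psycopg2 errors in DatabaseError. A minimal
# helper with those semantics, named here purely for illustration:
def _sketch_execute_query(connection, query, params=None,
                          fetch_one=False, fetch_results=True):
    cursor = connection.cursor()
    try:
        cursor.execute(query, params)
    except psycopg2.Error as exc:
        # Surface driver failures through the project's own exception type.
        raise DatabaseError(str(exc)) from exc
    if not fetch_results:
        return cursor.rowcount
    return cursor.fetchone() if fetch_one else cursor.fetchall()

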
class TestTransactionManager:
    """Test transaction management."""

    def test_transaction_context_success(self, mock_database_connection):
        """Test successful transaction context."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        conn = DatabaseConnection(config, mock_database_connection)
        tx_manager = TransactionManager(conn)

        with tx_manager:
            # Simulate some database operations
            conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
            conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])

        # Should commit on successful exit
        mock_database_connection.commit.assert_called_once()
        mock_database_connection.rollback.assert_not_called()

    def test_transaction_context_error(self, mock_database_connection):
        """Test transaction context with error."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        conn = DatabaseConnection(config, mock_database_connection)
        tx_manager = TransactionManager(conn)

        with pytest.raises(DatabaseError):
            with tx_manager:
                conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
                # Simulate an error
                raise DatabaseError("Something went wrong")

        # Should rollback on error
        mock_database_connection.rollback.assert_called_once()
        mock_database_connection.commit.assert_not_called()


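# Illustrative sketch (hypothetical, not the detector_worker implementation):
# TestTransactionManager expects commit() on a clean exit and rollback() when
# the with-block raises, with the exception still propagating. A context
# manager with exactly that behaviour could be as small as:
class _SketchTransactionManager:
    def __init__(self, connection):
        self.connection = connection

    def __enter__(self):
        return self.connection

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            self.connection.commit()
        else:
            self.connection.rollback()
        return False  # never swallow the exception

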
class TestDatabaseManager:
    """Test main database manager functionality."""

    def test_initialization(self):
        """Test database manager initialization."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)

        assert manager.config == config
        assert isinstance(manager.query_builder, QueryBuilder)
        assert manager.query_builder.schema == "gas_station_1"
        assert manager.connection is None

    @pytest.mark.asyncio
    async def test_connect_success(self):
        """Test successful database connection."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)

        with patch('psycopg2.connect') as mock_connect:
            mock_connection = Mock()
            mock_connect.return_value = mock_connection

            await manager.connect()

            assert manager.connection is not None
            assert manager.is_connected is True
            mock_connect.assert_called_once()

    @pytest.mark.asyncio
    async def test_connect_failure(self):
        """Test database connection failure."""
        config = DatabaseConfig(
            host="nonexistent-host",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)

        with patch('psycopg2.connect') as mock_connect:
            mock_connect.side_effect = psycopg2.Error("Connection failed")

            with pytest.raises(DatabaseError) as exc_info:
                await manager.connect()

            assert "Connection failed" in str(exc_info.value)
            assert manager.is_connected is False

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """Test database disconnection."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)

        # Mock connection
        mock_connection = Mock()
        manager.connection = DatabaseConnection(config, mock_connection)

        await manager.disconnect()

        assert manager.connection is None
        mock_connection.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_query(self, mock_database_connection):
        """Test query execution through manager."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchall.return_value = [(1, "Test"), (2, "Data")]

        result = await manager.execute_query("SELECT * FROM test_table")

        assert result == [(1, "Test"), (2, "Data")]
        mock_cursor.execute.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_query_not_connected(self):
        """Test query execution when not connected."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)

        with pytest.raises(DatabaseError) as exc_info:
            await manager.execute_query("SELECT * FROM test_table")

        assert "not connected" in str(exc_info.value).lower()

    @pytest.mark.asyncio
    async def test_insert_record(self, mock_database_connection):
        """Test inserting a record."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchone.return_value = (123,)  # Returned ID

        data = {
            "session_id": "session_123",
            "camera_id": "camera_001",
            "detection_class": "car",
            "confidence": 0.95,
            "created_at": "NOW()"
        }

        record_id = await manager.insert_record("car_detections", data)

        assert record_id == 123
        mock_cursor.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_update_record(self, mock_database_connection):
        """Test updating a record."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 1

        data = {
            "car_brand": "Toyota",
            "car_body_type": "Sedan",
            "updated_at": "NOW()"
        }

        where_conditions = {"session_id": "session_123"}

        rows_affected = await manager.update_record("car_info", data, where_conditions)

        assert rows_affected == 1
        mock_cursor.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_records(self, mock_database_connection):
        """Test deleting records."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 3

        where_conditions = {
            "created_at": "< NOW() - INTERVAL '30 days'",
            "processed": True
        }

        rows_deleted = await manager.delete_records("old_detections", where_conditions)

        assert rows_deleted == 3
        mock_cursor.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_table(self, mock_database_connection):
        """Test creating a table."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        columns = {
            "id": "SERIAL PRIMARY KEY",
            "session_id": "VARCHAR(255) UNIQUE NOT NULL",
            "camera_id": "VARCHAR(255) NOT NULL",
            "detection_data": "JSON",
            "created_at": "TIMESTAMP DEFAULT NOW()"
        }

        await manager.create_table("test_detections", columns)

        mock_database_connection.cursor.return_value.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_table_exists(self, mock_database_connection):
        """Test checking if table exists."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behavior - table exists
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchone.return_value = (1,)

        exists = await manager.table_exists("car_detections")

        assert exists is True
        mock_cursor.execute.assert_called_once()

        # Mock cursor behavior - table doesn't exist
        mock_cursor.fetchone.return_value = None

        exists = await manager.table_exists("nonexistent_table")

        assert exists is False

    @pytest.mark.asyncio
    async def test_transaction_context(self, mock_database_connection):
        """Test transaction context manager."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        async with manager.transaction():
            await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
            await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])

        # Should commit on successful completion
        mock_database_connection.commit.assert_called()

    @pytest.mark.asyncio
    async def test_get_table_schema(self, mock_database_connection):
        """Test getting table schema information."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchall.return_value = [
            ("id", "integer", "NOT NULL"),
            ("session_id", "character varying", "NOT NULL"),
            ("created_at", "timestamp without time zone", "DEFAULT now()")
        ]

        schema = await manager.get_table_schema("car_detections")

        assert len(schema) == 3
        assert schema[0] == ("id", "integer", "NOT NULL")
        assert schema[1] == ("session_id", "character varying", "NOT NULL")

    @pytest.mark.asyncio
    async def test_bulk_insert(self, mock_database_connection):
        """Test bulk insert operation."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        records = [
            {"name": "John", "email": "john@example.com"},
            {"name": "Jane", "email": "jane@example.com"},
            {"name": "Bob", "email": "bob@example.com"}
        ]

        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 3

        rows_inserted = await manager.bulk_insert("users", records)

        assert rows_inserted == 3
        mock_cursor.executemany.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_connection_stats(self, mock_database_connection):
        """Test getting connection statistics."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        stats = manager.get_connection_stats()

        assert "connected" in stats
        assert "host" in stats
        assert "database" in stats
        assert "schema" in stats
        assert stats["connected"] is True
        assert stats["host"] == "localhost"
        assert stats["database"] == "test_db"


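# Illustrative sketch (hypothetical, not the detector_worker implementation):
# test_insert_record assumes the manager builds the INSERT through its
# QueryBuilder, runs it on the wrapped connection, commits, and returns the id
# produced by the RETURNING clause. One plausible async wrapper with that flow
# (all names besides the tested attributes are assumptions):
async def _sketch_insert_record(manager, table, data):
    if manager.connection is None:
        raise DatabaseError("Database manager is not connected")
    query, params = manager.query_builder.build_insert_query(table, data)
    row = manager.connection.execute_query(query, params, fetch_one=True)
    manager.connection.commit()
    return row[0] if row else None

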
class TestDatabaseManagerIntegration:
    """Integration tests for database manager."""

    @pytest.mark.asyncio
    async def test_complete_car_detection_workflow(self, mock_database_connection):
        """Test complete car detection database workflow."""
        config = DatabaseConfig(
            host="localhost",
            database="gas_station_db",
            username="detector_user",
            password="detector_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behaviors for different operations
        mock_cursor = mock_database_connection.cursor.return_value

        # 1. Create initial detection record
        mock_cursor.fetchone.return_value = (456,)  # Returned ID

        detection_data = {
            "session_id": str(uuid.uuid4()),
            "camera_id": "camera_001",
            "display_id": "display_001",
            "detection_class": "car",
            "confidence": 0.92,
            "bbox_x1": 100,
            "bbox_y1": 200,
            "bbox_x2": 300,
            "bbox_y2": 400,
            "track_id": 1001,
            "created_at": "NOW()"
        }

        detection_id = await manager.insert_record("car_detections", detection_data)
        assert detection_id == 456

        # 2. Update with classification results
        mock_cursor.rowcount = 1

        classification_data = {
            "car_brand": "Toyota",
            "car_model": "Camry",
            "car_body_type": "Sedan",
            "car_color": "Blue",
            "brand_confidence": 0.87,
            "bodytype_confidence": 0.82,
            "color_confidence": 0.79,
            "updated_at": "NOW()"
        }

        where_conditions = {"session_id": detection_data["session_id"]}

        rows_updated = await manager.update_record("car_detections", classification_data, where_conditions)
        assert rows_updated == 1

        # 3. Query final results
        mock_cursor.fetchall.return_value = [
            (456, detection_data["session_id"], "camera_001", "car", 0.92, "Toyota", "Sedan")
        ]

        results = await manager.execute_query(
            "SELECT id, session_id, camera_id, detection_class, confidence, car_brand, car_body_type "
            "FROM gas_station_1.car_detections WHERE session_id = %s",
            [detection_data["session_id"]]
        )

        assert len(results) == 1
        assert results[0][0] == 456  # ID
        assert results[0][3] == "car"  # detection_class
        assert results[0][5] == "Toyota"  # car_brand

        # Verify all database operations were called
        assert mock_cursor.execute.call_count == 3
        assert mock_database_connection.commit.call_count == 2

    @pytest.mark.asyncio
    async def test_error_handling_and_recovery(self, mock_database_connection):
        """Test error handling and recovery scenarios."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Test transaction rollback on error
        mock_cursor = mock_database_connection.cursor.return_value

        with pytest.raises(DatabaseError):
            async with manager.transaction():
                # First operation succeeds
                await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])

                # Second operation fails
                mock_cursor.execute.side_effect = psycopg2.Error("Constraint violation")
                await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])

        # Should have rolled back
        mock_database_connection.rollback.assert_called_once()
        mock_database_connection.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_connection_recovery(self):
        """Test automatic connection recovery."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)

        with patch('psycopg2.connect') as mock_connect:
            # First connection attempt fails
            mock_connect.side_effect = [
                psycopg2.Error("Connection refused"),
                Mock()  # Second attempt succeeds
            ]

            # First attempt should fail
            with pytest.raises(DatabaseError):
                await manager.connect()

            # Second attempt should succeed
            await manager.connect()
            assert manager.is_connected is True
964 tests/unit/storage/test_redis_client.py Normal file
@@ -0,0 +1,964 @@
"""
|
||||
Unit tests for Redis client functionality.
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
import time
|
||||
from unittest.mock import Mock, MagicMock, patch, AsyncMock
|
||||
from datetime import datetime, timedelta
|
||||
import redis
|
||||
import numpy as np
|
||||
|
||||
from detector_worker.storage.redis_client import (
|
||||
RedisClient,
|
||||
RedisConfig,
|
||||
RedisConnectionPool,
|
||||
RedisPublisher,
|
||||
RedisSubscriber,
|
||||
RedisImageStorage,
|
||||
RedisError,
|
||||
ConnectionPoolError
|
||||
)
|
||||
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
|
||||
from detector_worker.core.exceptions import ConfigurationError
|
||||
|
||||
|
||||
class TestRedisConfig:
    """Test Redis configuration."""

    def test_creation_minimal(self):
        """Test creating Redis config with minimal parameters."""
        config = RedisConfig(
            host="localhost"
        )

        assert config.host == "localhost"
        assert config.port == 6379  # Default port
        assert config.password is None
        assert config.db == 0  # Default database
        assert config.enabled is True

    def test_creation_full(self):
        """Test creating Redis config with all parameters."""
        config = RedisConfig(
            host="redis.example.com",
            port=6380,
            password="secure_pass",
            db=2,
            enabled=True,
            connection_timeout=5.0,
            socket_timeout=3.0,
            socket_connect_timeout=2.0,
            max_connections=50,
            retry_on_timeout=True,
            health_check_interval=30
        )

        assert config.host == "redis.example.com"
        assert config.port == 6380
        assert config.password == "secure_pass"
        assert config.db == 2
        assert config.connection_timeout == 5.0
        assert config.max_connections == 50
        assert config.retry_on_timeout is True

    def test_get_connection_params(self):
        """Test getting Redis connection parameters."""
        config = RedisConfig(
            host="localhost",
            port=6379,
            password="test_pass",
            db=1,
            connection_timeout=10.0
        )

        params = config.get_connection_params()

        assert params["host"] == "localhost"
        assert params["port"] == 6379
        assert params["password"] == "test_pass"
        assert params["db"] == 1
        assert params["socket_timeout"] == 10.0

    def test_from_dict(self):
        """Test creating config from dictionary."""
        config_dict = {
            "host": "redis-server",
            "port": 6380,
            "password": "secret",
            "db": 3,
            "max_connections": 100,
            "unknown_field": "ignored"
        }

        config = RedisConfig.from_dict(config_dict)

        assert config.host == "redis-server"
        assert config.port == 6380
        assert config.password == "secret"
        assert config.db == 3
        assert config.max_connections == 100


class TestRedisConnectionPool:
    """Test Redis connection pool management."""

    def test_creation(self):
        """Test connection pool creation."""
        config = RedisConfig(
            host="localhost",
            max_connections=20
        )

        pool = RedisConnectionPool(config)

        assert pool.config == config
        assert pool.pool is None
        assert pool.is_connected is False

    @pytest.mark.asyncio
    async def test_connect_success(self):
        """Test successful connection to Redis."""
        config = RedisConfig(host="localhost")
        pool = RedisConnectionPool(config)

        with patch('redis.ConnectionPool') as mock_pool_class:
            mock_pool = Mock()
            mock_pool_class.return_value = mock_pool

            with patch('redis.Redis') as mock_redis_class:
                mock_redis = Mock()
                mock_redis.ping.return_value = True
                mock_redis_class.return_value = mock_redis

                await pool.connect()

                assert pool.is_connected is True
                assert pool.pool is not None
                mock_pool_class.assert_called_once()

    @pytest.mark.asyncio
    async def test_connect_failure(self):
        """Test Redis connection failure."""
        config = RedisConfig(host="nonexistent-redis")
        pool = RedisConnectionPool(config)

        with patch('redis.ConnectionPool') as mock_pool_class:
            mock_pool_class.side_effect = redis.ConnectionError("Connection failed")

            with pytest.raises(RedisError) as exc_info:
                await pool.connect()

            assert "Connection failed" in str(exc_info.value)
            assert pool.is_connected is False

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """Test Redis disconnection."""
        config = RedisConfig(host="localhost")
        pool = RedisConnectionPool(config)

        # Mock connected state
        mock_pool = Mock()
        mock_redis = Mock()
        pool.pool = mock_pool
        pool._redis_client = mock_redis
        pool.is_connected = True

        await pool.disconnect()

        assert pool.is_connected is False
        assert pool.pool is None
        mock_pool.disconnect.assert_called_once()

    def test_get_client_connected(self):
        """Test getting Redis client when connected."""
        config = RedisConfig(host="localhost")
        pool = RedisConnectionPool(config)

        mock_pool = Mock()
        mock_redis = Mock()
        pool.pool = mock_pool
        pool._redis_client = mock_redis
        pool.is_connected = True

        client = pool.get_client()
        assert client == mock_redis

    def test_get_client_not_connected(self):
        """Test getting Redis client when not connected."""
        config = RedisConfig(host="localhost")
        pool = RedisConnectionPool(config)

        with pytest.raises(RedisError) as exc_info:
            pool.get_client()

        assert "not connected" in str(exc_info.value).lower()

    def test_health_check(self):
        """Test Redis health check."""
        config = RedisConfig(host="localhost")
        pool = RedisConnectionPool(config)

        mock_redis = Mock()
        mock_redis.ping.return_value = True
        pool._redis_client = mock_redis
        pool.is_connected = True

        is_healthy = pool.health_check()

        assert is_healthy is True
        mock_redis.ping.assert_called_once()

    def test_health_check_failure(self):
        """Test Redis health check failure."""
        config = RedisConfig(host="localhost")
        pool = RedisConnectionPool(config)

        mock_redis = Mock()
        mock_redis.ping.side_effect = redis.ConnectionError("Connection lost")
        pool._redis_client = mock_redis
        pool.is_connected = True

        is_healthy = pool.health_check()

        assert is_healthy is False


class TestRedisImageStorage:
    """Test Redis image storage functionality."""

    def test_creation(self, mock_redis_client):
        """Test Redis image storage creation."""
        storage = RedisImageStorage(mock_redis_client)

        assert storage.redis_client == mock_redis_client
        assert storage.default_expiry == 3600  # 1 hour
        assert storage.compression_enabled is True

    @pytest.mark.asyncio
    async def test_store_image_success(self, mock_redis_client, mock_frame):
        """Test successful image storage."""
        storage = RedisImageStorage(mock_redis_client)

        mock_redis_client.set.return_value = True
        mock_redis_client.expire.return_value = True

        with patch('cv2.imencode') as mock_imencode:
            # Mock successful encoding
            encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
            mock_imencode.return_value = (True, encoded_data)

            result = await storage.store_image("test_key", mock_frame, expire_seconds=600)

            assert result is True
            mock_redis_client.set.assert_called_once()
            mock_redis_client.expire.assert_called_once_with("test_key", 600)
            mock_imencode.assert_called_once()

    @pytest.mark.asyncio
    async def test_store_image_cropped(self, mock_redis_client, mock_frame):
        """Test storing cropped image."""
        storage = RedisImageStorage(mock_redis_client)

        mock_redis_client.set.return_value = True
        mock_redis_client.expire.return_value = True

        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        with patch('cv2.imencode') as mock_imencode:
            encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
            mock_imencode.return_value = (True, encoded_data)

            result = await storage.store_image("cropped_key", mock_frame, crop_bbox=bbox)

            assert result is True
            mock_redis_client.set.assert_called_once()

    @pytest.mark.asyncio
    async def test_store_image_encoding_failure(self, mock_redis_client, mock_frame):
        """Test image storage with encoding failure."""
        storage = RedisImageStorage(mock_redis_client)

        with patch('cv2.imencode') as mock_imencode:
            # Mock encoding failure
            mock_imencode.return_value = (False, None)

            with pytest.raises(RedisError) as exc_info:
                await storage.store_image("test_key", mock_frame)

            assert "Failed to encode image" in str(exc_info.value)
            mock_redis_client.set.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_image_redis_failure(self, mock_redis_client, mock_frame):
        """Test image storage with Redis failure."""
        storage = RedisImageStorage(mock_redis_client)

        mock_redis_client.set.side_effect = redis.RedisError("Redis error")

        with patch('cv2.imencode') as mock_imencode:
            encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
            mock_imencode.return_value = (True, encoded_data)

            with pytest.raises(RedisError) as exc_info:
                await storage.store_image("test_key", mock_frame)

            assert "Redis error" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_retrieve_image_success(self, mock_redis_client):
        """Test successful image retrieval."""
        storage = RedisImageStorage(mock_redis_client)

        # Mock encoded image data
        original_image = np.ones((100, 100, 3), dtype=np.uint8) * 128

        with patch('cv2.imencode') as mock_imencode:
            encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
            mock_imencode.return_value = (True, encoded_data)

            # Mock Redis returning base64 encoded data
            base64_data = base64.b64encode(encoded_data.tobytes()).decode('utf-8')
            mock_redis_client.get.return_value = base64_data

            with patch('cv2.imdecode') as mock_imdecode:
                mock_imdecode.return_value = original_image

                retrieved_image = await storage.retrieve_image("test_key")

                assert retrieved_image is not None
                assert retrieved_image.shape == (100, 100, 3)
                mock_redis_client.get.assert_called_once_with("test_key")

    @pytest.mark.asyncio
    async def test_retrieve_image_not_found(self, mock_redis_client):
        """Test image retrieval when key not found."""
        storage = RedisImageStorage(mock_redis_client)

        mock_redis_client.get.return_value = None

        retrieved_image = await storage.retrieve_image("nonexistent_key")

        assert retrieved_image is None
        mock_redis_client.get.assert_called_once_with("nonexistent_key")

    @pytest.mark.asyncio
    async def test_delete_image(self, mock_redis_client):
        """Test image deletion."""
        storage = RedisImageStorage(mock_redis_client)

        mock_redis_client.delete.return_value = 1

        result = await storage.delete_image("test_key")

        assert result is True
        mock_redis_client.delete.assert_called_once_with("test_key")

    @pytest.mark.asyncio
    async def test_delete_image_not_found(self, mock_redis_client):
        """Test deleting non-existent image."""
        storage = RedisImageStorage(mock_redis_client)

        mock_redis_client.delete.return_value = 0

        result = await storage.delete_image("nonexistent_key")

        assert result is False
        mock_redis_client.delete.assert_called_once_with("nonexistent_key")

    @pytest.mark.asyncio
    async def test_bulk_delete_images(self, mock_redis_client):
        """Test bulk image deletion."""
        storage = RedisImageStorage(mock_redis_client)

        keys = ["key1", "key2", "key3"]
        mock_redis_client.delete.return_value = 3

        deleted_count = await storage.bulk_delete_images(keys)

        assert deleted_count == 3
        mock_redis_client.delete.assert_called_once_with(*keys)

    @pytest.mark.asyncio
    async def test_cleanup_expired_images(self, mock_redis_client):
        """Test cleanup of expired images."""
        storage = RedisImageStorage(mock_redis_client)

        # Mock scan to return image keys
        mock_redis_client.scan_iter.return_value = [
            b"inference:camera1:image1",
            b"inference:camera2:image2",
            b"inference:camera1:image3"
        ]

        # Mock ttl to return different expiry times
        mock_redis_client.ttl.side_effect = [-1, 100, -2]  # No expiry, valid, expired
        mock_redis_client.delete.return_value = 1

        deleted_count = await storage.cleanup_expired_images("inference:*")

        assert deleted_count == 1  # Only expired images deleted
        mock_redis_client.delete.assert_called_once()

    def test_get_image_info(self, mock_redis_client):
        """Test getting image metadata."""
        storage = RedisImageStorage(mock_redis_client)

        mock_redis_client.exists.return_value = 1
        mock_redis_client.ttl.return_value = 1800  # 30 minutes
        mock_redis_client.memory_usage.return_value = 4096  # 4KB

        info = storage.get_image_info("test_key")

        assert info["exists"] is True
        assert info["ttl"] == 1800
        assert info["size_bytes"] == 4096

        mock_redis_client.exists.assert_called_once_with("test_key")
        mock_redis_client.ttl.assert_called_once_with("test_key")


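# Illustrative sketch (hypothetical, not the detector_worker implementation):
# the storage tests above assume store_image() JPEG-encodes the frame with
# cv2.imencode, optionally crops to a BoundingBox first, raises RedisError when
# encoding fails, and writes the bytes with an expiry. A synchronous helper
# with those semantics; cv2 (opencv-python) availability and the x1/y1/x2/y2
# BoundingBox attribute names are assumptions here.
def _sketch_store_image(redis_client, key, frame, crop_bbox=None,
                        expire_seconds=3600):
    import cv2  # assumed available in the worker environment

    if crop_bbox is not None:
        frame = frame[int(crop_bbox.y1):int(crop_bbox.y2),
                      int(crop_bbox.x1):int(crop_bbox.x2)]
    success, buffer = cv2.imencode(".jpg", frame)
    if not success:
        raise RedisError("Failed to encode image")
    # Store as base64 text so the payload survives text-mode Redis clients.
    redis_client.set(key, base64.b64encode(buffer.tobytes()).decode("utf-8"))
    redis_client.expire(key, expire_seconds)
    return True

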
class TestRedisPublisher:
    """Test Redis publisher functionality."""

    def test_creation(self, mock_redis_client):
        """Test Redis publisher creation."""
        publisher = RedisPublisher(mock_redis_client)

        assert publisher.redis_client == mock_redis_client

    @pytest.mark.asyncio
    async def test_publish_message_string(self, mock_redis_client):
        """Test publishing string message."""
        publisher = RedisPublisher(mock_redis_client)

        mock_redis_client.publish.return_value = 2  # 2 subscribers

        result = await publisher.publish("test_channel", "Hello, Redis!")

        assert result == 2
        mock_redis_client.publish.assert_called_once_with("test_channel", "Hello, Redis!")

    @pytest.mark.asyncio
    async def test_publish_message_json(self, mock_redis_client):
        """Test publishing JSON message."""
        publisher = RedisPublisher(mock_redis_client)

        mock_redis_client.publish.return_value = 1

        message_data = {
            "camera_id": "camera_001",
            "detection_class": "car",
            "confidence": 0.95,
            "timestamp": 1640995200000
        }

        result = await publisher.publish("detections", message_data)

        assert result == 1

        # Should have been JSON serialized
        expected_json = json.dumps(message_data)
        mock_redis_client.publish.assert_called_once_with("detections", expected_json)

    @pytest.mark.asyncio
    async def test_publish_detection_event(self, mock_redis_client):
        """Test publishing detection event."""
        publisher = RedisPublisher(mock_redis_client)

        mock_redis_client.publish.return_value = 3

        detection = DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)

        result = await publisher.publish_detection_event(
            "camera_detections",
            detection,
            camera_id="camera_001",
            session_id="session_123"
        )

        assert result == 3

        # Verify the published message structure
        call_args = mock_redis_client.publish.call_args
        channel = call_args[0][0]
        message_str = call_args[0][1]
        message_data = json.loads(message_str)

        assert channel == "camera_detections"
        assert message_data["event_type"] == "detection"
        assert message_data["camera_id"] == "camera_001"
        assert message_data["session_id"] == "session_123"
        assert message_data["detection"]["class"] == "car"
        assert message_data["detection"]["confidence"] == 0.92

    @pytest.mark.asyncio
    async def test_publish_batch_messages(self, mock_redis_client):
        """Test publishing multiple messages in batch."""
        publisher = RedisPublisher(mock_redis_client)

        mock_pipeline = Mock()
        mock_redis_client.pipeline.return_value = mock_pipeline
        mock_pipeline.execute.return_value = [1, 2, 1]  # Subscriber counts

        messages = [
            ("channel1", "message1"),
            ("channel2", {"data": "message2"}),
            ("channel1", "message3")
        ]

        results = await publisher.publish_batch(messages)

        assert results == [1, 2, 1]
        mock_redis_client.pipeline.assert_called_once()
        assert mock_pipeline.publish.call_count == 3
        mock_pipeline.execute.assert_called_once()

    @pytest.mark.asyncio
    async def test_publish_error_handling(self, mock_redis_client):
        """Test error handling in publishing."""
        publisher = RedisPublisher(mock_redis_client)

        mock_redis_client.publish.side_effect = redis.RedisError("Publish failed")

        with pytest.raises(RedisError) as exc_info:
            await publisher.publish("test_channel", "test_message")

        assert "Publish failed" in str(exc_info.value)


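# Illustrative sketch (hypothetical, not the detector_worker implementation):
# the publisher tests assume publish() passes strings through unchanged,
# json.dumps()-serialises dict payloads, returns the subscriber count, and
# wraps redis-py errors in the project's RedisError. A minimal equivalent:
def _sketch_publish(redis_client, channel, message):
    payload = message if isinstance(message, str) else json.dumps(message)
    try:
        return redis_client.publish(channel, payload)
    except redis.RedisError as exc:
        raise RedisError(str(exc)) from exc

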
class TestRedisSubscriber:
    """Test Redis subscriber functionality."""

    def test_creation(self, mock_redis_client):
        """Test Redis subscriber creation."""
        subscriber = RedisSubscriber(mock_redis_client)

        assert subscriber.redis_client == mock_redis_client
        assert subscriber.pubsub is None
        assert subscriber.subscriptions == set()

    @pytest.mark.asyncio
    async def test_subscribe_to_channel(self, mock_redis_client):
        """Test subscribing to a channel."""
        subscriber = RedisSubscriber(mock_redis_client)

        mock_pubsub = Mock()
        mock_redis_client.pubsub.return_value = mock_pubsub

        await subscriber.subscribe("test_channel")

        assert "test_channel" in subscriber.subscriptions
        mock_pubsub.subscribe.assert_called_once_with("test_channel")

    @pytest.mark.asyncio
    async def test_subscribe_to_pattern(self, mock_redis_client):
        """Test subscribing to a pattern."""
        subscriber = RedisSubscriber(mock_redis_client)

        mock_pubsub = Mock()
        mock_redis_client.pubsub.return_value = mock_pubsub

        await subscriber.subscribe_pattern("detection:*")

        assert "detection:*" in subscriber.subscriptions
        mock_pubsub.psubscribe.assert_called_once_with("detection:*")

    @pytest.mark.asyncio
    async def test_unsubscribe_from_channel(self, mock_redis_client):
        """Test unsubscribing from a channel."""
        subscriber = RedisSubscriber(mock_redis_client)

        mock_pubsub = Mock()
        mock_redis_client.pubsub.return_value = mock_pubsub
        subscriber.pubsub = mock_pubsub
        subscriber.subscriptions.add("test_channel")

        await subscriber.unsubscribe("test_channel")

        assert "test_channel" not in subscriber.subscriptions
        mock_pubsub.unsubscribe.assert_called_once_with("test_channel")

    @pytest.mark.asyncio
    async def test_listen_for_messages(self, mock_redis_client):
        """Test listening for messages."""
        subscriber = RedisSubscriber(mock_redis_client)

        mock_pubsub = Mock()
        mock_redis_client.pubsub.return_value = mock_pubsub

        # Mock message stream
        messages = [
            {"type": "subscribe", "channel": "test", "data": 1},
            {"type": "message", "channel": "test", "data": "Hello"},
            {"type": "message", "channel": "test", "data": '{"key": "value"}'}
        ]

        mock_pubsub.listen.return_value = iter(messages)

        received_messages = []
        message_count = 0

        async for message in subscriber.listen():
            received_messages.append(message)
            message_count += 1
            if message_count >= 2:  # Only process actual messages
                break

        # Should receive 2 actual messages (excluding subscribe confirmation)
        assert len(received_messages) == 2
        assert received_messages[0]["data"] == "Hello"
        assert received_messages[1]["data"] == {"key": "value"}  # Should be parsed as JSON

    @pytest.mark.asyncio
    async def test_close_subscription(self, mock_redis_client):
        """Test closing subscription."""
        subscriber = RedisSubscriber(mock_redis_client)

        mock_pubsub = Mock()
        subscriber.pubsub = mock_pubsub
        subscriber.subscriptions = {"channel1", "pattern:*"}

        await subscriber.close()

        assert len(subscriber.subscriptions) == 0
        mock_pubsub.close.assert_called_once()
        assert subscriber.pubsub is None


class TestRedisClient:
    """Test main Redis client functionality."""

    def test_initialization(self):
        """Test Redis client initialization."""
        config = RedisConfig(host="localhost", port=6379)
        client = RedisClient(config)

        assert client.config == config
        assert isinstance(client.connection_pool, RedisConnectionPool)
        assert client.image_storage is None
        assert client.publisher is None
        assert client.subscriber is None

    @pytest.mark.asyncio
    async def test_connect_and_initialize_components(self):
        """Test connecting and initializing all components."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        with patch.object(client.connection_pool, 'connect', new_callable=AsyncMock) as mock_connect:
            mock_redis_client = Mock()
            client.connection_pool.get_client = Mock(return_value=mock_redis_client)
            client.connection_pool.is_connected = True

            await client.connect()

            assert client.image_storage is not None
            assert client.publisher is not None
            assert client.subscriber is not None
            assert isinstance(client.image_storage, RedisImageStorage)
            assert isinstance(client.publisher, RedisPublisher)
            assert isinstance(client.subscriber, RedisSubscriber)

            mock_connect.assert_called_once()

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """Test disconnection."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state; keep a direct handle on the mocked subscriber
        # because disconnect() clears client.subscriber.
        client.connection_pool.is_connected = True
        mock_subscriber = Mock()
        mock_subscriber.close = AsyncMock()
        client.subscriber = mock_subscriber

        with patch.object(client.connection_pool, 'disconnect', new_callable=AsyncMock) as mock_disconnect:
            await client.disconnect()

            mock_subscriber.close.assert_called_once()
            mock_disconnect.assert_called_once()

        assert client.image_storage is None
        assert client.publisher is None
        assert client.subscriber is None

    @pytest.mark.asyncio
    async def test_store_and_retrieve_data(self, mock_redis_client):
        """Test storing and retrieving data."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        # Test storing data
        mock_redis_client.set.return_value = True
        result = await client.set("test_key", "test_value", expire_seconds=300)
        assert result is True
        mock_redis_client.set.assert_called_once_with("test_key", "test_value")
        mock_redis_client.expire.assert_called_once_with("test_key", 300)

        # Test retrieving data
        mock_redis_client.get.return_value = "test_value"
        value = await client.get("test_key")
        assert value == "test_value"
        mock_redis_client.get.assert_called_once_with("test_key")

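    # Illustrative sketch (assumption, not the actual RedisClient implementation):
    # the set-then-expire call pattern asserted above could be satisfied by a
    # thin async wrapper around a synchronous redis client, e.g.:
    @staticmethod
    async def _example_set(redis_client, key, value, expire_seconds=None):
        """Store a value, then apply an optional TTL as a separate expire call."""
        result = redis_client.set(key, value)
        if expire_seconds is not None:
            redis_client.expire(key, expire_seconds)
        return bool(result)
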
    @pytest.mark.asyncio
    async def test_delete_keys(self, mock_redis_client):
        """Test deleting keys."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        mock_redis_client.delete.return_value = 2

        result = await client.delete("key1", "key2")

        assert result == 2
        mock_redis_client.delete.assert_called_once_with("key1", "key2")

    @pytest.mark.asyncio
    async def test_exists_check(self, mock_redis_client):
        """Test checking key existence."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        mock_redis_client.exists.return_value = 1

        exists = await client.exists("test_key")

        assert exists is True
        mock_redis_client.exists.assert_called_once_with("test_key")

    @pytest.mark.asyncio
    async def test_expire_key(self, mock_redis_client):
        """Test setting key expiration."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        mock_redis_client.expire.return_value = True

        result = await client.expire("test_key", 600)

        assert result is True
        mock_redis_client.expire.assert_called_once_with("test_key", 600)

    @pytest.mark.asyncio
    async def test_get_ttl(self, mock_redis_client):
        """Test getting key TTL."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        mock_redis_client.ttl.return_value = 300

        ttl = await client.ttl("test_key")

        assert ttl == 300
        mock_redis_client.ttl.assert_called_once_with("test_key")

    @pytest.mark.asyncio
    async def test_scan_keys(self, mock_redis_client):
        """Test scanning for keys."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        mock_redis_client.scan_iter.return_value = [b"key1", b"key2", b"key3"]

        keys = await client.scan_keys("test:*")

        assert keys == ["key1", "key2", "key3"]
        mock_redis_client.scan_iter.assert_called_once_with(match="test:*")

    @pytest.mark.asyncio
    async def test_flush_database(self, mock_redis_client):
        """Test flushing database."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        mock_redis_client.flushdb.return_value = True

        result = await client.flush_db()

        assert result is True
        mock_redis_client.flushdb.assert_called_once()

    def test_get_connection_info(self):
        """Test getting connection information."""
        config = RedisConfig(
            host="redis.example.com",
            port=6380,
            db=2
        )
        client = RedisClient(config)
        client.connection_pool.is_connected = True

        info = client.get_connection_info()

        assert info["connected"] is True
        assert info["host"] == "redis.example.com"
        assert info["port"] == 6380
        assert info["database"] == 2

    @pytest.mark.asyncio
    async def test_pipeline_operations(self, mock_redis_client):
        """Test Redis pipeline operations."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        mock_pipeline = Mock()
        mock_redis_client.pipeline.return_value = mock_pipeline
        mock_pipeline.execute.return_value = [True, True, 1]

        async with client.pipeline() as pipe:
            pipe.set("key1", "value1")
            pipe.set("key2", "value2")
            pipe.delete("key3")
            results = await pipe.execute()

        assert results == [True, True, 1]
        mock_redis_client.pipeline.assert_called_once()
        mock_pipeline.execute.assert_called_once()

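# Illustrative sketch (assumption, not the actual RedisClient implementation):
# an async context manager like the one used above can be built by wrapping a
# synchronous redis pipeline and exposing an awaitable execute().
import asyncio
from contextlib import asynccontextmanager


class _ExamplePipelineWrapper:
    """Delegate queuing calls to the raw pipeline; run execute() off-thread."""

    def __init__(self, raw_pipeline):
        self._raw = raw_pipeline

    def __getattr__(self, name):
        return getattr(self._raw, name)  # set/delete/publish queue on the raw pipeline

    async def execute(self):
        return await asyncio.to_thread(self._raw.execute)


@asynccontextmanager
async def _example_pipeline(redis_client):
    """Yield a wrapper whose execute() awaits the underlying pipeline.execute()."""
    yield _ExamplePipelineWrapper(redis_client.pipeline())

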
class TestRedisClientIntegration:
    """Integration tests for Redis client."""

    @pytest.mark.asyncio
    async def test_complete_image_workflow(self, mock_redis_client, mock_frame):
        """Test complete image storage workflow."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state and components
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True
        client.image_storage = RedisImageStorage(mock_redis_client)
        client.publisher = RedisPublisher(mock_redis_client)

        # Mock Redis operations
        mock_redis_client.set.return_value = True
        mock_redis_client.expire.return_value = True
        mock_redis_client.publish.return_value = 2

        with patch('cv2.imencode') as mock_imencode:
            encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
            mock_imencode.return_value = (True, encoded_data)

            # Store image
            store_result = await client.image_storage.store_image(
                "detection:camera001:1640995200:session123",
                mock_frame,
                expire_seconds=600
            )

            # Publish detection event
            detection_event = {
                "camera_id": "camera001",
                "session_id": "session123",
                "detection_class": "car",
                "confidence": 0.95,
                "timestamp": 1640995200000
            }

            publish_result = await client.publisher.publish("detections:camera001", detection_event)

            assert store_result is True
            assert publish_result == 2

            # Verify Redis operations
            mock_redis_client.set.assert_called_once()
            mock_redis_client.expire.assert_called_once()
            mock_redis_client.publish.assert_called_once()

    @pytest.mark.asyncio
    async def test_error_recovery_and_reconnection(self):
        """Test error recovery and reconnection."""
        config = RedisConfig(host="localhost", retry_on_timeout=True)
        client = RedisClient(config)

        with patch.object(client.connection_pool, 'connect', new_callable=AsyncMock) as mock_connect:
            with patch.object(client.connection_pool, 'health_check') as mock_health_check:
                # First health check fails, second succeeds
                mock_health_check.side_effect = [False, True]

                # First connection attempt fails, second succeeds
                mock_connect.side_effect = [RedisError("Connection failed"), None]

                # Simulate connection recovery
                try:
                    await client.connect()
                except RedisError:
                    # Retry connection
                    await client.connect()

                assert mock_connect.call_count == 2

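    # Illustrative sketch (assumption): a bounded retry loop with exponential
    # backoff is one way to drive the reconnect behaviour exercised above.
    @staticmethod
    async def _example_connect_with_retry(connect, attempts=3, base_delay=0.5):
        """Call an async connect() up to `attempts` times, backing off between tries."""
        import asyncio

        last_error = None
        for attempt in range(attempts):
            try:
                await connect()
                return True
            except Exception as exc:  # the real client would catch RedisError here
                last_error = exc
                await asyncio.sleep(base_delay * (2 ** attempt))
        raise last_error
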
    @pytest.mark.asyncio
    async def test_bulk_operations_performance(self, mock_redis_client):
        """Test bulk operations for performance."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True
        client.publisher = RedisPublisher(mock_redis_client)

        # Mock pipeline operations
        mock_pipeline = Mock()
        mock_redis_client.pipeline.return_value = mock_pipeline
        mock_pipeline.execute.return_value = [1] * 100  # 100 successful operations

        # Prepare bulk messages
        messages = [
            (f"channel_{i}", f"message_{i}")
            for i in range(100)
        ]

        start_time = time.time()
        results = await client.publisher.publish_batch(messages)
        execution_time = time.time() - start_time

        assert len(results) == 100
        assert all(result == 1 for result in results)

        # Should be faster than individual operations
        assert execution_time < 1.0  # Should complete in less than 1 second

        # Pipeline should be used for efficiency
        mock_redis_client.pipeline.assert_called_once()
        assert mock_pipeline.publish.call_count == 100
        mock_pipeline.execute.assert_called_once()

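# Illustrative sketch (assumption, not the actual RedisPublisher implementation):
# batching publishes through a single pipeline, as asserted above, could look like:
import asyncio
import json


async def _example_publish_batch(redis_client, messages):
    """Queue every (channel, payload) pair on one pipeline and execute it once."""
    pipe = redis_client.pipeline()
    for channel, payload in messages:
        if not isinstance(payload, (str, bytes)):
            payload = json.dumps(payload)  # serialize dict payloads
        pipe.publish(channel, payload)
    return await asyncio.to_thread(pipe.execute)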
883
tests/unit/storage/test_session_cache.py
Normal file

@@ -0,0 +1,883 @@
"""
Unit tests for session cache management.
"""
import pytest
import time
import uuid
from unittest.mock import Mock, patch
from datetime import datetime, timedelta
from collections import defaultdict

from detector_worker.storage.session_cache import (
    SessionCache,
    SessionCacheManager,
    SessionData,
    CacheConfig,
    CacheEntry,
    CacheStats,
    SessionError,
    CacheError
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox


class TestCacheConfig:
    """Test cache configuration."""

    def test_creation_default(self):
        """Test creating cache config with default values."""
        config = CacheConfig()

        assert config.max_size == 1000
        assert config.ttl_seconds == 3600  # 1 hour
        assert config.cleanup_interval == 300  # 5 minutes
        assert config.eviction_policy == "lru"
        assert config.enable_persistence is False

    def test_creation_custom(self):
        """Test creating cache config with custom values."""
        config = CacheConfig(
            max_size=5000,
            ttl_seconds=7200,
            cleanup_interval=600,
            eviction_policy="lfu",
            enable_persistence=True,
            persistence_path="/tmp/cache"
        )

        assert config.max_size == 5000
        assert config.ttl_seconds == 7200
        assert config.cleanup_interval == 600
        assert config.eviction_policy == "lfu"
        assert config.enable_persistence is True
        assert config.persistence_path == "/tmp/cache"

    def test_from_dict(self):
        """Test creating config from dictionary."""
        config_dict = {
            "max_size": 2000,
            "ttl_seconds": 1800,
            "eviction_policy": "fifo",
            "enable_persistence": True,
            "unknown_field": "ignored"
        }

        config = CacheConfig.from_dict(config_dict)

        assert config.max_size == 2000
        assert config.ttl_seconds == 1800
        assert config.eviction_policy == "fifo"
        assert config.enable_persistence is True


class TestCacheEntry:
    """Test cache entry data structure."""

    def test_creation(self):
        """Test cache entry creation."""
        data = {"key": "value", "number": 42}
        entry = CacheEntry(data, ttl_seconds=600)

        assert entry.data == data
        assert entry.ttl_seconds == 600
        assert entry.created_at <= time.time()
        assert entry.last_accessed <= time.time()
        assert entry.access_count == 1
        assert entry.size > 0

    def test_is_expired(self):
        """Test expiration checking."""
        # Non-expired entry
        entry = CacheEntry({"data": "test"}, ttl_seconds=600)
        assert entry.is_expired() is False

        # Expired entry (simulate by setting old creation time)
        entry.created_at = time.time() - 700  # Created 700 seconds ago
        assert entry.is_expired() is True

        # Entry without expiration
        entry_no_ttl = CacheEntry({"data": "test"})
        assert entry_no_ttl.is_expired() is False

    def test_touch(self):
        """Test updating access time and count."""
        entry = CacheEntry({"data": "test"})

        original_access_time = entry.last_accessed
        original_access_count = entry.access_count

        time.sleep(0.01)  # Small delay
        entry.touch()

        assert entry.last_accessed > original_access_time
        assert entry.access_count == original_access_count + 1

    def test_age(self):
        """Test age calculation."""
        entry = CacheEntry({"data": "test"})

        time.sleep(0.01)  # Small delay
        age = entry.age()

        assert age > 0
        assert age < 1  # Should be less than 1 second

    def test_size_estimation(self):
        """Test size estimation."""
        small_entry = CacheEntry({"key": "value"})
        large_entry = CacheEntry({"key": "x" * 1000, "data": list(range(100))})

        assert large_entry.size > small_entry.size

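# Illustrative sketch (assumption, not the real CacheEntry): the behaviour the
# tests above rely on -- creation timestamps, TTL expiry, touch() bookkeeping
# and a rough size estimate -- can be modelled like this.
import time as _time


class _SketchCacheEntry:
    def __init__(self, data, ttl_seconds=None):
        self.data = data
        self.ttl_seconds = ttl_seconds
        self.created_at = _time.time()
        self.last_accessed = self.created_at
        self.access_count = 1            # creation counts as the first access
        self.size = len(repr(data))      # very rough estimate based on the repr length

    def is_expired(self) -> bool:
        if self.ttl_seconds is None:
            return False
        return (_time.time() - self.created_at) > self.ttl_seconds

    def touch(self) -> None:
        self.last_accessed = _time.time()
        self.access_count += 1

    def age(self) -> float:
        return _time.time() - self.created_at

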
class TestSessionData:
    """Test session data structure."""

    def test_creation(self):
        """Test session data creation."""
        session_data = SessionData(
            session_id="session_123",
            camera_id="camera_001",
            display_id="display_001"
        )

        assert session_data.session_id == "session_123"
        assert session_data.camera_id == "camera_001"
        assert session_data.display_id == "display_001"
        assert session_data.created_at <= time.time()
        assert session_data.last_activity <= time.time()
        assert session_data.detection_data == {}
        assert session_data.metadata == {}

    def test_update_activity(self):
        """Test updating last activity."""
        session_data = SessionData("session_123", "camera_001", "display_001")

        original_activity = session_data.last_activity
        time.sleep(0.01)
        session_data.update_activity()

        assert session_data.last_activity > original_activity

    def test_add_detection_data(self):
        """Test adding detection data."""
        session_data = SessionData("session_123", "camera_001", "display_001")

        detection_data = {
            "class": "car",
            "confidence": 0.95,
            "bbox": [100, 200, 300, 400]
        }

        session_data.add_detection_data("main_detection", detection_data)

        assert "main_detection" in session_data.detection_data
        assert session_data.detection_data["main_detection"] == detection_data

    def test_add_metadata(self):
        """Test adding metadata."""
        session_data = SessionData("session_123", "camera_001", "display_001")

        session_data.add_metadata("model_version", "v2.1")
        session_data.add_metadata("inference_time", 0.15)

        assert session_data.metadata["model_version"] == "v2.1"
        assert session_data.metadata["inference_time"] == 0.15

    def test_is_expired(self):
        """Test session expiration."""
        session_data = SessionData("session_123", "camera_001", "display_001")

        # Not expired with default timeout
        assert session_data.is_expired() is False

        # Expired with short timeout (let a moment pass so the timeout elapses)
        time.sleep(0.01)
        assert session_data.is_expired(timeout_seconds=0.001) is True

    def test_to_dict(self):
        """Test converting session to dictionary."""
        session_data = SessionData("session_123", "camera_001", "display_001")
        session_data.add_detection_data("detection", {"class": "car", "confidence": 0.9})
        session_data.add_metadata("model_id", "yolo_v8")

        data_dict = session_data.to_dict()

        assert data_dict["session_id"] == "session_123"
        assert data_dict["camera_id"] == "camera_001"
        assert data_dict["detection_data"]["detection"]["class"] == "car"
        assert data_dict["metadata"]["model_id"] == "yolo_v8"
        assert "created_at" in data_dict
        assert "last_activity" in data_dict


class TestCacheStats:
    """Test cache statistics."""

    def test_creation(self):
        """Test cache stats creation."""
        stats = CacheStats()

        assert stats.hits == 0
        assert stats.misses == 0
        assert stats.evictions == 0
        assert stats.size == 0
        assert stats.memory_usage == 0

    def test_hit_rate_calculation(self):
        """Test hit rate calculation."""
        stats = CacheStats()

        # No requests yet
        assert stats.hit_rate() == 0.0

        # Some hits and misses
        stats.hits = 8
        stats.misses = 2

        assert stats.hit_rate() == 0.8  # 8 / (8 + 2)

    def test_total_requests(self):
        """Test total requests calculation."""
        stats = CacheStats()

        stats.hits = 15
        stats.misses = 5

        assert stats.total_requests() == 20

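# Illustrative sketch (assumption, not the real CacheStats): the statistics
# behaviour asserted above reduces to a couple of counters and two helpers.
from dataclasses import dataclass


@dataclass
class _SketchCacheStats:
    hits: int = 0
    misses: int = 0
    evictions: int = 0
    size: int = 0
    memory_usage: int = 0

    def total_requests(self) -> int:
        return self.hits + self.misses

    def hit_rate(self) -> float:
        total = self.total_requests()
        return self.hits / total if total else 0.0

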
class TestSessionCache:
    """Test session cache functionality."""

    def test_creation(self):
        """Test session cache creation."""
        config = CacheConfig(max_size=100, ttl_seconds=300)
        cache = SessionCache(config)

        assert cache.config == config
        assert cache.max_size == 100
        assert cache.ttl_seconds == 300
        assert len(cache._cache) == 0
        assert len(cache._access_order) == 0

    def test_put_and_get_session(self):
        """Test putting and getting session data."""
        cache = SessionCache(CacheConfig(max_size=10))

        session_data = SessionData("session_123", "camera_001", "display_001")
        session_data.add_detection_data("main", {"class": "car", "confidence": 0.9})

        # Put session
        cache.put("session_123", session_data)

        # Get session
        retrieved_data = cache.get("session_123")

        assert retrieved_data is not None
        assert retrieved_data.session_id == "session_123"
        assert retrieved_data.camera_id == "camera_001"
        assert "main" in retrieved_data.detection_data

    def test_get_nonexistent_session(self):
        """Test getting non-existent session."""
        cache = SessionCache(CacheConfig(max_size=10))

        result = cache.get("nonexistent_session")

        assert result is None

    def test_contains_check(self):
        """Test checking if session exists."""
        cache = SessionCache(CacheConfig(max_size=10))

        session_data = SessionData("session_123", "camera_001", "display_001")
        cache.put("session_123", session_data)

        assert cache.contains("session_123") is True
        assert cache.contains("nonexistent_session") is False

    def test_remove_session(self):
        """Test removing session."""
        cache = SessionCache(CacheConfig(max_size=10))

        session_data = SessionData("session_123", "camera_001", "display_001")
        cache.put("session_123", session_data)

        assert cache.contains("session_123") is True

        removed_data = cache.remove("session_123")

        assert removed_data is not None
        assert removed_data.session_id == "session_123"
        assert cache.contains("session_123") is False

    def test_size_tracking(self):
        """Test cache size tracking."""
        cache = SessionCache(CacheConfig(max_size=10))

        assert cache.size() == 0
        assert cache.is_empty() is True

        # Add sessions
        for i in range(3):
            session_data = SessionData(f"session_{i}", "camera_001", "display_001")
            cache.put(f"session_{i}", session_data)

        assert cache.size() == 3
        assert cache.is_empty() is False

    def test_lru_eviction(self):
        """Test LRU eviction policy."""
        cache = SessionCache(CacheConfig(max_size=3, eviction_policy="lru"))

        # Fill cache to capacity
        for i in range(3):
            session_data = SessionData(f"session_{i}", "camera_001", "display_001")
            cache.put(f"session_{i}", session_data)

        # Access session_1 to make it recently used
        cache.get("session_1")

        # Add another session (should evict session_0, the least recently used)
        new_session = SessionData("session_3", "camera_001", "display_001")
        cache.put("session_3", new_session)

        assert cache.size() == 3
        assert cache.contains("session_0") is False  # Evicted
        assert cache.contains("session_1") is True   # Recently accessed
        assert cache.contains("session_2") is True
        assert cache.contains("session_3") is True   # Newly added

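    # Illustrative sketch (assumption, not the real SessionCache internals): the
    # LRU behaviour exercised above is commonly built on an OrderedDict, moving a
    # key to the end on every access and evicting from the front once over capacity.
    class _SketchLRUCache:
        def __init__(self, max_size):
            from collections import OrderedDict
            self.max_size = max_size
            self._items = OrderedDict()

        def put(self, key, value):
            if key in self._items:
                self._items.move_to_end(key)
            self._items[key] = value
            if len(self._items) > self.max_size:
                self._items.popitem(last=False)  # drop the least recently used key

        def get(self, key):
            if key not in self._items:
                return None
            self._items.move_to_end(key)  # mark as most recently used
            return self._items[key]
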
    def test_ttl_expiration(self):
        """Test TTL-based expiration."""
        cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1))  # 100ms TTL

        session_data = SessionData("session_123", "camera_001", "display_001")
        cache.put("session_123", session_data)

        # Should exist immediately
        assert cache.contains("session_123") is True

        # Wait for expiration
        time.sleep(0.2)

        # Should be expired (but might still be in cache until cleanup)
        entry = cache._cache.get("session_123")
        if entry:
            assert entry.is_expired() is True

        # Getting expired entry should return None and clean it up
        retrieved = cache.get("session_123")
        assert retrieved is None
        assert cache.contains("session_123") is False

    def test_cleanup_expired_entries(self):
        """Test cleanup of expired entries."""
        cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1))

        # Add multiple sessions
        for i in range(3):
            session_data = SessionData(f"session_{i}", "camera_001", "display_001")
            cache.put(f"session_{i}", session_data)

        assert cache.size() == 3

        # Wait for expiration
        time.sleep(0.2)

        # Cleanup expired entries
        cleaned_count = cache.cleanup_expired()

        assert cleaned_count == 3
        assert cache.size() == 0

    def test_clear_cache(self):
        """Test clearing entire cache."""
        cache = SessionCache(CacheConfig(max_size=10))

        # Add sessions
        for i in range(5):
            session_data = SessionData(f"session_{i}", "camera_001", "display_001")
            cache.put(f"session_{i}", session_data)

        assert cache.size() == 5

        cache.clear()

        assert cache.size() == 0
        assert cache.is_empty() is True

    def test_get_all_sessions(self):
        """Test getting all sessions."""
        cache = SessionCache(CacheConfig(max_size=10))

        sessions = []
        for i in range(3):
            session_data = SessionData(f"session_{i}", f"camera_{i}", "display_001")
            cache.put(f"session_{i}", session_data)
            sessions.append(session_data)

        all_sessions = cache.get_all()

        assert len(all_sessions) == 3
        for session_id, session_data in all_sessions.items():
            assert session_id.startswith("session_")
            assert session_data.session_id == session_id

    def test_get_sessions_by_camera(self):
        """Test getting sessions by camera ID."""
        cache = SessionCache(CacheConfig(max_size=10))

        # Add sessions for different cameras
        for i in range(2):
            session_data1 = SessionData(f"session_cam1_{i}", "camera_001", "display_001")
            session_data2 = SessionData(f"session_cam2_{i}", "camera_002", "display_001")
            cache.put(f"session_cam1_{i}", session_data1)
            cache.put(f"session_cam2_{i}", session_data2)

        camera1_sessions = cache.get_by_camera("camera_001")
        camera2_sessions = cache.get_by_camera("camera_002")

        assert len(camera1_sessions) == 2
        assert len(camera2_sessions) == 2

        for session_data in camera1_sessions:
            assert session_data.camera_id == "camera_001"

        for session_data in camera2_sessions:
            assert session_data.camera_id == "camera_002"

    def test_statistics_tracking(self):
        """Test cache statistics tracking."""
        cache = SessionCache(CacheConfig(max_size=10))

        session_data = SessionData("session_123", "camera_001", "display_001")
        cache.put("session_123", session_data)

        # Cache miss
        cache.get("nonexistent_session")

        # Cache hit
        cache.get("session_123")
        cache.get("session_123")  # Another hit

        stats = cache.get_stats()

        assert stats.hits == 2
        assert stats.misses == 1
        assert stats.size == 1
        assert stats.hit_rate() == 2 / 3  # 2 hits out of 3 total requests

    def test_memory_usage_estimation(self):
        """Test memory usage estimation."""
        cache = SessionCache(CacheConfig(max_size=10))

        initial_memory = cache.get_memory_usage()

        # Add large session
        session_data = SessionData("session_123", "camera_001", "display_001")
        session_data.add_detection_data("large_data", {"data": "x" * 1000})
        cache.put("session_123", session_data)

        after_memory = cache.get_memory_usage()

        assert after_memory > initial_memory


class TestSessionCacheManager:
    """Test session cache manager."""

    def test_singleton_behavior(self):
        """Test that SessionCacheManager is a singleton."""
        manager1 = SessionCacheManager()
        manager2 = SessionCacheManager()

        assert manager1 is manager2

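    # Illustrative sketch (assumption, not the real SessionCacheManager): singleton
    # behaviour like the one asserted above is often implemented by caching the
    # instance on the class in __new__, so repeated calls return the same object.
    class _SketchSingleton:
        _instance = None

        def __new__(cls, *args, **kwargs):
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance  # _SketchSingleton() is _SketchSingleton() -> True
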
    def test_initialization(self):
        """Test session cache manager initialization."""
        manager = SessionCacheManager()

        assert manager.detection_cache is not None
        assert manager.pipeline_cache is not None
        assert manager.session_cache is not None
        assert isinstance(manager.detection_cache, SessionCache)
        assert isinstance(manager.pipeline_cache, SessionCache)
        assert isinstance(manager.session_cache, SessionCache)

    def test_cache_detection_result(self):
        """Test caching detection results."""
        manager = SessionCacheManager()
        manager.clear_all()  # Start fresh

        detection_data = {
            "class": "car",
            "confidence": 0.95,
            "bbox": [100, 200, 300, 400],
            "track_id": 1001
        }

        manager.cache_detection("camera_001", detection_data)

        cached_detection = manager.get_cached_detection("camera_001")

        assert cached_detection is not None
        assert cached_detection["class"] == "car"
        assert cached_detection["confidence"] == 0.95
        assert cached_detection["track_id"] == 1001

    def test_cache_pipeline_result(self):
        """Test caching pipeline results."""
        manager = SessionCacheManager()
        manager.clear_all()

        pipeline_result = {
            "status": "success",
            "detections": [{"class": "car", "confidence": 0.9}],
            "execution_time": 0.15,
            "model_id": "yolo_v8"
        }

        manager.cache_pipeline_result("camera_001", pipeline_result)

        cached_result = manager.get_cached_pipeline_result("camera_001")

        assert cached_result is not None
        assert cached_result["status"] == "success"
        assert cached_result["execution_time"] == 0.15
        assert len(cached_result["detections"]) == 1

    def test_manage_session_data(self):
        """Test session data management."""
        manager = SessionCacheManager()
        manager.clear_all()

        session_id = str(uuid.uuid4())

        # Create session
        manager.create_session(session_id, "camera_001", {"initial": "data"})

        # Update session
        manager.update_session_detection(session_id, {"car_brand": "Toyota"})

        # Get session
        session_data = manager.get_session_detection(session_id)

        assert session_data is not None
        assert "initial" in session_data
        assert session_data["car_brand"] == "Toyota"

    def test_set_latest_frame(self):
        """Test setting and getting latest frame."""
        manager = SessionCacheManager()
        manager.clear_all()

        frame_data = b"fake_frame_data"

        manager.set_latest_frame("camera_001", frame_data)

        retrieved_frame = manager.get_latest_frame("camera_001")

        assert retrieved_frame == frame_data

    def test_frame_skip_flag_management(self):
        """Test frame skip flag management."""
        manager = SessionCacheManager()
        manager.clear_all()

        # Initially should be False
        assert manager.get_frame_skip_flag("camera_001") is False

        # Set to True
        manager.set_frame_skip_flag("camera_001", True)
        assert manager.get_frame_skip_flag("camera_001") is True

        # Set back to False
        manager.set_frame_skip_flag("camera_001", False)
        assert manager.get_frame_skip_flag("camera_001") is False

    def test_cleanup_expired_sessions(self):
        """Test cleanup of expired sessions."""
        manager = SessionCacheManager()
        manager.clear_all()

        # Create sessions with short TTL
        manager.session_cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1))

        # Add sessions
        for i in range(3):
            session_id = f"session_{i}"
            manager.create_session(session_id, "camera_001", {"test": "data"})

        assert manager.session_cache.size() == 3

        # Wait for expiration
        time.sleep(0.2)

        # Cleanup
        expired_count = manager.cleanup_expired_sessions()

        assert expired_count == 3
        assert manager.session_cache.size() == 0

    def test_clear_camera_cache(self):
        """Test clearing cache for specific camera."""
        manager = SessionCacheManager()
        manager.clear_all()

        # Add data for multiple cameras
        manager.cache_detection("camera_001", {"class": "car"})
        manager.cache_detection("camera_002", {"class": "truck"})
        manager.cache_pipeline_result("camera_001", {"status": "success"})
        manager.set_latest_frame("camera_001", b"frame1")
        manager.set_latest_frame("camera_002", b"frame2")

        # Clear camera_001 cache
        manager.clear_camera_cache("camera_001")

        # camera_001 data should be gone
        assert manager.get_cached_detection("camera_001") is None
        assert manager.get_cached_pipeline_result("camera_001") is None
        assert manager.get_latest_frame("camera_001") is None

        # camera_002 data should remain
        assert manager.get_cached_detection("camera_002") is not None
        assert manager.get_latest_frame("camera_002") is not None

    def test_get_cache_statistics(self):
        """Test getting cache statistics."""
        manager = SessionCacheManager()
        manager.clear_all()

        # Add some data to generate statistics
        manager.cache_detection("camera_001", {"class": "car"})
        manager.cache_pipeline_result("camera_001", {"status": "success"})
        manager.create_session("session_123", "camera_001", {"initial": "data"})

        # Access data to generate hits/misses
        manager.get_cached_detection("camera_001")  # Hit
        manager.get_cached_detection("camera_002")  # Miss

        stats = manager.get_cache_statistics()

        assert "detection_cache" in stats
        assert "pipeline_cache" in stats
        assert "session_cache" in stats
        assert "total_memory_usage" in stats

        detection_stats = stats["detection_cache"]
        assert detection_stats["size"] >= 1
        assert detection_stats["hits"] >= 1
        assert detection_stats["misses"] >= 1

    def test_memory_pressure_handling(self):
        """Test handling memory pressure."""
        # Create manager with small cache sizes
        config = CacheConfig(max_size=3)
        manager = SessionCacheManager()
        manager.detection_cache = SessionCache(config)
        manager.pipeline_cache = SessionCache(config)
        manager.session_cache = SessionCache(config)

        # Fill caches beyond capacity
        for i in range(5):
            manager.cache_detection(f"camera_{i}", {"class": "car", "data": "x" * 100})
            manager.cache_pipeline_result(f"camera_{i}", {"status": "success", "data": "y" * 100})
            manager.create_session(f"session_{i}", f"camera_{i}", {"data": "z" * 100})

        # Caches should not exceed max size due to eviction
        assert manager.detection_cache.size() <= 3
        assert manager.pipeline_cache.size() <= 3
        assert manager.session_cache.size() <= 3

    def test_concurrent_access_thread_safety(self):
        """Test thread safety of concurrent cache access."""
        import threading
        import concurrent.futures

        manager = SessionCacheManager()
        manager.clear_all()

        results = []
        errors = []

        def cache_operation(thread_id):
            try:
                # Each thread performs multiple cache operations
                for i in range(10):
                    session_id = f"thread_{thread_id}_session_{i}"

                    # Create session
                    manager.create_session(session_id, f"camera_{thread_id}", {"thread": thread_id, "index": i})

                    # Update session
                    manager.update_session_detection(session_id, {"updated": True})

                    # Read session
                    data = manager.get_session_detection(session_id)
                    if data and data.get("thread") == thread_id:
                        results.append((thread_id, i))

            except Exception as e:
                errors.append((thread_id, str(e)))

        # Run operations concurrently
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(cache_operation, i) for i in range(5)]
            concurrent.futures.wait(futures)

        # Should have no errors and successful operations
        assert len(errors) == 0
        assert len(results) >= 25  # At least some operations should succeed

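# Illustrative sketch (assumption, not the real SessionCache internals): the
# concurrency test above only passes if shared state is guarded; a re-entrant
# lock around every reading and mutating operation is the simplest approach.
import threading


class _SketchThreadSafeStore:
    def __init__(self):
        self._lock = threading.RLock()
        self._data = {}

    def put(self, key, value):
        with self._lock:
            self._data[key] = value

    def update(self, key, fields):
        with self._lock:
            self._data.setdefault(key, {}).update(fields)

    def get(self, key):
        with self._lock:
            return self._data.get(key)

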
class TestSessionCacheIntegration:
    """Integration tests for session cache."""

    def test_complete_detection_workflow(self):
        """Test complete detection workflow with caching."""
        manager = SessionCacheManager()
        manager.clear_all()

        camera_id = "camera_001"
        session_id = str(uuid.uuid4())

        # 1. Cache initial detection
        detection_data = {
            "class": "car",
            "confidence": 0.92,
            "bbox": [100, 200, 300, 400],
            "track_id": 1001,
            "timestamp": int(time.time() * 1000)
        }

        manager.cache_detection(camera_id, detection_data)

        # 2. Create session for tracking
        initial_session_data = {
            "detection_class": detection_data["class"],
            "confidence": detection_data["confidence"],
            "track_id": detection_data["track_id"]
        }

        manager.create_session(session_id, camera_id, initial_session_data)

        # 3. Cache pipeline processing result
        pipeline_result = {
            "status": "processing",
            "stage": "classification",
            "detections": [detection_data],
            "branches_completed": [],
            "branches_pending": ["car_brand_cls", "car_bodytype_cls"]
        }

        manager.cache_pipeline_result(camera_id, pipeline_result)

        # 4. Update session with classification results
        classification_updates = [
            {"car_brand": "Toyota", "brand_confidence": 0.87},
            {"car_body_type": "Sedan", "bodytype_confidence": 0.82}
        ]

        for update in classification_updates:
            manager.update_session_detection(session_id, update)

        # 5. Update pipeline result to completed
        final_pipeline_result = {
            "status": "completed",
            "stage": "finished",
            "detections": [detection_data],
            "branches_completed": ["car_brand_cls", "car_bodytype_cls"],
            "branches_pending": [],
            "execution_time": 0.25
        }

        manager.cache_pipeline_result(camera_id, final_pipeline_result)

        # 6. Verify all cached data
        cached_detection = manager.get_cached_detection(camera_id)
        cached_pipeline = manager.get_cached_pipeline_result(camera_id)
        cached_session = manager.get_session_detection(session_id)

        # Assertions
        assert cached_detection["class"] == "car"
        assert cached_detection["track_id"] == 1001

        assert cached_pipeline["status"] == "completed"
        assert len(cached_pipeline["branches_completed"]) == 2

        assert cached_session["detection_class"] == "car"
        assert cached_session["car_brand"] == "Toyota"
        assert cached_session["car_body_type"] == "Sedan"
        assert cached_session["brand_confidence"] == 0.87

    def test_cache_performance_under_load(self):
        """Test cache performance under load."""
        manager = SessionCacheManager()
        manager.clear_all()

        # Measure performance of cache operations
        start_time = time.time()

        # Perform many cache operations
        for i in range(1000):
            camera_id = f"camera_{i % 10}"  # 10 different cameras
            session_id = f"session_{i}"

            # Cache detection
            detection_data = {
                "class": "car",
                "confidence": 0.9 + (i % 10) * 0.01,
                "track_id": i,
                "bbox": [i % 100, i % 100, (i % 100) + 200, (i % 100) + 200]
            }
            manager.cache_detection(camera_id, detection_data)

            # Create session
            manager.create_session(session_id, camera_id, {"index": i})

            # Read back (every 10th operation)
            if i % 10 == 0:
                manager.get_cached_detection(camera_id)
                manager.get_session_detection(session_id)

        end_time = time.time()
        total_time = end_time - start_time

        # Should complete in reasonable time (less than 1 second)
        assert total_time < 1.0

        # Verify cache statistics
        stats = manager.get_cache_statistics()
        assert stats["detection_cache"]["size"] > 0
        assert stats["session_cache"]["size"] > 0
        assert stats["detection_cache"]["hits"] > 0

    def test_cache_persistence_and_recovery(self):
        """Test cache persistence and recovery (if enabled)."""
        # This test would be more meaningful with actual persistence;
        # for now it only exercises the configuration and structure.
        persistence_config = CacheConfig(
            max_size=100,
            enable_persistence=True,
            persistence_path="/tmp/detector_cache_test"
        )

        cache = SessionCache(persistence_config)

        # Add some data
        session_data = SessionData("session_123", "camera_001", "display_001")
        session_data.add_detection_data("main", {"class": "car", "confidence": 0.95})

        cache.put("session_123", session_data)

        # Verify data exists
        assert cache.contains("session_123") is True

        # In a real implementation, this would also test:
        # 1. Saving the cache to disk
        # 2. Loading the cache from disk
        # 3. Verifying data integrity after reload
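# Illustrative sketch (assumption, not the real persistence layer): a minimal
# JSON save/load round-trip of session payloads, of the kind the comment above
# says a full persistence test would cover.
import json
import os
import tempfile


def _example_save_sessions(sessions, path):
    """Write a {session_id: payload} mapping to disk as JSON."""
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(sessions, fh)


def _example_load_sessions(path):
    """Read the mapping back; return an empty dict if nothing was saved yet."""
    if not os.path.exists(path):
        return {}
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


def _example_roundtrip_demo():
    """Demonstrate that a save/load round-trip preserves the payload."""
    payload = {"session_123": {"class": "car", "confidence": 0.95}}
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "cache.json")
        _example_save_sessions(payload, path)
        return _example_load_sessions(path) == payload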