Refactor: PHASE 8: Testing & Integration

This commit is contained in:
ziesorx 2025-09-12 18:55:23 +07:00
parent af34f4fd08
commit 9e8c6804a7
32 changed files with 17128 additions and 0 deletions

View file

@ -0,0 +1,976 @@
"""
Unit tests for database management functionality.
"""
import pytest
import asyncio
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from datetime import datetime, timedelta
import psycopg2
import uuid
from detector_worker.storage.database_manager import (
DatabaseManager,
DatabaseConfig,
DatabaseConnection,
QueryBuilder,
TransactionManager,
DatabaseError,
ConnectionPoolError
)
from detector_worker.core.exceptions import ConfigurationError
class TestDatabaseConfig:
"""Test database configuration."""
def test_creation_minimal(self):
"""Test creating database config with minimal parameters."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
assert config.host == "localhost"
assert config.port == 5432 # Default port
assert config.database == "test_db"
assert config.username == "test_user"
assert config.password == "test_pass"
assert config.schema == "public" # Default schema
assert config.enabled is True
def test_creation_full(self):
"""Test creating database config with all parameters."""
config = DatabaseConfig(
host="db.example.com",
port=5433,
database="production_db",
username="prod_user",
password="secure_pass",
schema="gas_station_1",
enabled=True,
pool_min_conn=2,
pool_max_conn=20,
pool_timeout=30.0,
connection_timeout=10.0,
ssl_mode="require"
)
assert config.host == "db.example.com"
assert config.port == 5433
assert config.database == "production_db"
assert config.schema == "gas_station_1"
assert config.pool_min_conn == 2
assert config.pool_max_conn == 20
assert config.ssl_mode == "require"
def test_get_connection_string(self):
"""Test generating connection string."""
config = DatabaseConfig(
host="localhost",
port=5432,
database="test_db",
username="test_user",
password="test_pass"
)
conn_string = config.get_connection_string()
expected = "host=localhost port=5432 database=test_db user=test_user password=test_pass"
assert conn_string == expected
def test_get_connection_string_with_ssl(self):
"""Test generating connection string with SSL."""
config = DatabaseConfig(
host="db.example.com",
database="secure_db",
username="user",
password="pass",
ssl_mode="require"
)
conn_string = config.get_connection_string()
assert "sslmode=require" in conn_string
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"host": "test-host",
"port": 5433,
"database": "test-db",
"username": "test-user",
"password": "test-pass",
"schema": "test_schema",
"pool_max_conn": 15,
"unknown_field": "ignored"
}
config = DatabaseConfig.from_dict(config_dict)
assert config.host == "test-host"
assert config.port == 5433
assert config.database == "test-db"
assert config.schema == "test_schema"
assert config.pool_max_conn == 15
class TestQueryBuilder:
"""Test SQL query building functionality."""
def test_build_select_query(self):
"""Test building SELECT queries."""
builder = QueryBuilder("test_schema")
query, params = builder.build_select_query(
table="users",
columns=["id", "name", "email"],
where={"status": "active", "age": 25},
order_by="created_at DESC",
limit=10
)
expected_query = (
"SELECT id, name, email FROM test_schema.users "
"WHERE status = %s AND age = %s "
"ORDER BY created_at DESC LIMIT 10"
)
assert query == expected_query
assert params == ["active", 25]
def test_build_select_all_columns(self):
"""Test building SELECT * query."""
builder = QueryBuilder("public")
query, params = builder.build_select_query("products")
expected_query = "SELECT * FROM public.products"
assert query == expected_query
assert params == []
def test_build_insert_query(self):
"""Test building INSERT queries."""
builder = QueryBuilder("inventory")
data = {
"product_name": "Widget",
"price": 19.99,
"quantity": 100,
"created_at": "NOW()"
}
query, params = builder.build_insert_query("products", data)
expected_query = (
"INSERT INTO inventory.products (product_name, price, quantity, created_at) "
"VALUES (%s, %s, %s, NOW()) RETURNING id"
)
assert query == expected_query
assert params == ["Widget", 19.99, 100]
def test_build_update_query(self):
"""Test building UPDATE queries."""
builder = QueryBuilder("sales")
data = {
"status": "shipped",
"shipped_date": "NOW()",
"tracking_number": "ABC123"
}
where_conditions = {"order_id": 12345}
query, params = builder.build_update_query("orders", data, where_conditions)
expected_query = (
"UPDATE sales.orders SET status = %s, shipped_date = NOW(), tracking_number = %s "
"WHERE order_id = %s"
)
assert query == expected_query
assert params == ["shipped", "ABC123", 12345]
def test_build_delete_query(self):
"""Test building DELETE queries."""
builder = QueryBuilder("logs")
where_conditions = {
"level": "DEBUG",
"created_at": "< NOW() - INTERVAL '7 days'"
}
query, params = builder.build_delete_query("application_logs", where_conditions)
expected_query = (
"DELETE FROM logs.application_logs "
"WHERE level = %s AND created_at < NOW() - INTERVAL '7 days'"
)
assert query == expected_query
assert params == ["DEBUG"]
def test_build_create_table_query(self):
"""Test building CREATE TABLE queries."""
builder = QueryBuilder("gas_station_1")
columns = {
"id": "SERIAL PRIMARY KEY",
"session_id": "VARCHAR(255) UNIQUE NOT NULL",
"camera_id": "VARCHAR(255) NOT NULL",
"detection_class": "VARCHAR(100)",
"confidence": "DECIMAL(4,3)",
"bbox_data": "JSON",
"created_at": "TIMESTAMP DEFAULT NOW()",
"updated_at": "TIMESTAMP DEFAULT NOW()"
}
query = builder.build_create_table_query("detections", columns)
expected_parts = [
"CREATE TABLE IF NOT EXISTS gas_station_1.detections",
"id SERIAL PRIMARY KEY",
"session_id VARCHAR(255) UNIQUE NOT NULL",
"camera_id VARCHAR(255) NOT NULL",
"bbox_data JSON",
"created_at TIMESTAMP DEFAULT NOW()"
]
for part in expected_parts:
assert part in query
def test_escape_identifier(self):
"""Test SQL identifier escaping."""
builder = QueryBuilder("test")
assert builder.escape_identifier("table") == '"table"'
assert builder.escape_identifier("column_name") == '"column_name"'
assert builder.escape_identifier("user-table") == '"user-table"'
def test_format_value_for_sql(self):
"""Test SQL value formatting."""
builder = QueryBuilder("test")
# Regular values should use placeholder
assert builder.format_value_for_sql("string") == ("%s", "string")
assert builder.format_value_for_sql(42) == ("%s", 42)
assert builder.format_value_for_sql(3.14) == ("%s", 3.14)
# SQL functions should be literal
assert builder.format_value_for_sql("NOW()") == ("NOW()", None)
assert builder.format_value_for_sql("CURRENT_TIMESTAMP") == ("CURRENT_TIMESTAMP", None)
assert builder.format_value_for_sql("UUID()") == ("UUID()", None)
class TestDatabaseConnection:
"""Test database connection management."""
def test_creation(self, mock_database_connection):
"""Test connection creation."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
assert conn.config == config
assert conn.connection == mock_database_connection
assert conn.is_connected is True
def test_execute_query(self, mock_database_connection):
"""Test query execution."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchall.return_value = [
(1, "John", "john@example.com"),
(2, "Jane", "jane@example.com")
]
mock_cursor.rowcount = 2
conn = DatabaseConnection(config, mock_database_connection)
query = "SELECT id, name, email FROM users WHERE status = %s"
params = ["active"]
result = conn.execute_query(query, params)
assert result == [
(1, "John", "john@example.com"),
(2, "Jane", "jane@example.com")
]
mock_cursor.execute.assert_called_once_with(query, params)
mock_cursor.fetchall.assert_called_once()
def test_execute_query_single_result(self, mock_database_connection):
"""Test query execution with single result."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchone.return_value = (1, "John", "john@example.com")
conn = DatabaseConnection(config, mock_database_connection)
result = conn.execute_query("SELECT * FROM users WHERE id = %s", [1], fetch_one=True)
assert result == (1, "John", "john@example.com")
mock_cursor.fetchone.assert_called_once()
def test_execute_query_no_fetch(self, mock_database_connection):
"""Test query execution without fetching results."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.rowcount = 1
conn = DatabaseConnection(config, mock_database_connection)
result = conn.execute_query(
"INSERT INTO users (name) VALUES (%s)",
["John"],
fetch_results=False
)
assert result == 1 # Row count
mock_cursor.execute.assert_called_once()
mock_cursor.fetchall.assert_not_called()
mock_cursor.fetchone.assert_not_called()
def test_execute_query_error(self, mock_database_connection):
"""Test query execution error handling."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.execute.side_effect = psycopg2.Error("Database error")
conn = DatabaseConnection(config, mock_database_connection)
with pytest.raises(DatabaseError) as exc_info:
conn.execute_query("SELECT * FROM invalid_table")
assert "Database error" in str(exc_info.value)
def test_commit_transaction(self, mock_database_connection):
"""Test transaction commit."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
conn.commit()
mock_database_connection.commit.assert_called_once()
def test_rollback_transaction(self, mock_database_connection):
"""Test transaction rollback."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
conn.rollback()
mock_database_connection.rollback.assert_called_once()
def test_close_connection(self, mock_database_connection):
"""Test connection closing."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
conn.close()
assert conn.is_connected is False
mock_database_connection.close.assert_called_once()
class TestTransactionManager:
"""Test transaction management."""
def test_transaction_context_success(self, mock_database_connection):
"""Test successful transaction context."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
tx_manager = TransactionManager(conn)
with tx_manager:
# Simulate some database operations
conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])
# Should commit on successful exit
mock_database_connection.commit.assert_called_once()
mock_database_connection.rollback.assert_not_called()
def test_transaction_context_error(self, mock_database_connection):
"""Test transaction context with error."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
tx_manager = TransactionManager(conn)
with pytest.raises(DatabaseError):
with tx_manager:
conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
# Simulate an error
raise DatabaseError("Something went wrong")
# Should rollback on error
mock_database_connection.rollback.assert_called_once()
mock_database_connection.commit.assert_not_called()
class TestDatabaseManager:
"""Test main database manager functionality."""
def test_initialization(self):
"""Test database manager initialization."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
assert manager.config == config
assert isinstance(manager.query_builder, QueryBuilder)
assert manager.query_builder.schema == "gas_station_1"
assert manager.connection is None
@pytest.mark.asyncio
async def test_connect_success(self):
"""Test successful database connection."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
with patch('psycopg2.connect') as mock_connect:
mock_connection = Mock()
mock_connect.return_value = mock_connection
await manager.connect()
assert manager.connection is not None
assert manager.is_connected is True
mock_connect.assert_called_once()
@pytest.mark.asyncio
async def test_connect_failure(self):
"""Test database connection failure."""
config = DatabaseConfig(
host="nonexistent-host",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
with patch('psycopg2.connect') as mock_connect:
mock_connect.side_effect = psycopg2.Error("Connection failed")
with pytest.raises(DatabaseError) as exc_info:
await manager.connect()
assert "Connection failed" in str(exc_info.value)
assert manager.is_connected is False
@pytest.mark.asyncio
async def test_disconnect(self):
"""Test database disconnection."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
# Mock connection
mock_connection = Mock()
manager.connection = DatabaseConnection(config, mock_connection)
await manager.disconnect()
assert manager.connection is None
mock_connection.close.assert_called_once()
@pytest.mark.asyncio
async def test_execute_query(self, mock_database_connection):
"""Test query execution through manager."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchall.return_value = [(1, "Test"), (2, "Data")]
result = await manager.execute_query("SELECT * FROM test_table")
assert result == [(1, "Test"), (2, "Data")]
mock_cursor.execute.assert_called_once()
@pytest.mark.asyncio
async def test_execute_query_not_connected(self):
"""Test query execution when not connected."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
with pytest.raises(DatabaseError) as exc_info:
await manager.execute_query("SELECT * FROM test_table")
assert "not connected" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_insert_record(self, mock_database_connection):
"""Test inserting a record."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchone.return_value = (123,) # Returned ID
data = {
"session_id": "session_123",
"camera_id": "camera_001",
"detection_class": "car",
"confidence": 0.95,
"created_at": "NOW()"
}
record_id = await manager.insert_record("car_detections", data)
assert record_id == 123
mock_cursor.execute.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_update_record(self, mock_database_connection):
"""Test updating a record."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.rowcount = 1
data = {
"car_brand": "Toyota",
"car_body_type": "Sedan",
"updated_at": "NOW()"
}
where_conditions = {"session_id": "session_123"}
rows_affected = await manager.update_record("car_info", data, where_conditions)
assert rows_affected == 1
mock_cursor.execute.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_delete_records(self, mock_database_connection):
"""Test deleting records."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.rowcount = 3
where_conditions = {
"created_at": "< NOW() - INTERVAL '30 days'",
"processed": True
}
rows_deleted = await manager.delete_records("old_detections", where_conditions)
assert rows_deleted == 3
mock_cursor.execute.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_create_table(self, mock_database_connection):
"""Test creating a table."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
columns = {
"id": "SERIAL PRIMARY KEY",
"session_id": "VARCHAR(255) UNIQUE NOT NULL",
"camera_id": "VARCHAR(255) NOT NULL",
"detection_data": "JSON",
"created_at": "TIMESTAMP DEFAULT NOW()"
}
await manager.create_table("test_detections", columns)
mock_database_connection.cursor.return_value.execute.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_table_exists(self, mock_database_connection):
"""Test checking if table exists."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior - table exists
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchone.return_value = (1,)
exists = await manager.table_exists("car_detections")
assert exists is True
mock_cursor.execute.assert_called_once()
# Mock cursor behavior - table doesn't exist
mock_cursor.fetchone.return_value = None
exists = await manager.table_exists("nonexistent_table")
assert exists is False
@pytest.mark.asyncio
async def test_transaction_context(self, mock_database_connection):
"""Test transaction context manager."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
async with manager.transaction():
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])
# Should commit on successful completion
mock_database_connection.commit.assert_called()
@pytest.mark.asyncio
async def test_get_table_schema(self, mock_database_connection):
"""Test getting table schema information."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchall.return_value = [
("id", "integer", "NOT NULL"),
("session_id", "character varying", "NOT NULL"),
("created_at", "timestamp without time zone", "DEFAULT now()")
]
schema = await manager.get_table_schema("car_detections")
assert len(schema) == 3
assert schema[0] == ("id", "integer", "NOT NULL")
assert schema[1] == ("session_id", "character varying", "NOT NULL")
@pytest.mark.asyncio
async def test_bulk_insert(self, mock_database_connection):
"""Test bulk insert operation."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
records = [
{"name": "John", "email": "john@example.com"},
{"name": "Jane", "email": "jane@example.com"},
{"name": "Bob", "email": "bob@example.com"}
]
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.rowcount = 3
rows_inserted = await manager.bulk_insert("users", records)
assert rows_inserted == 3
mock_cursor.executemany.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_get_connection_stats(self, mock_database_connection):
"""Test getting connection statistics."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
stats = manager.get_connection_stats()
assert "connected" in stats
assert "host" in stats
assert "database" in stats
assert "schema" in stats
assert stats["connected"] is True
assert stats["host"] == "localhost"
assert stats["database"] == "test_db"
class TestDatabaseManagerIntegration:
"""Integration tests for database manager."""
@pytest.mark.asyncio
async def test_complete_car_detection_workflow(self, mock_database_connection):
"""Test complete car detection database workflow."""
config = DatabaseConfig(
host="localhost",
database="gas_station_db",
username="detector_user",
password="detector_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behaviors for different operations
mock_cursor = mock_database_connection.cursor.return_value
# 1. Create initial detection record
mock_cursor.fetchone.return_value = (456,) # Returned ID
detection_data = {
"session_id": str(uuid.uuid4()),
"camera_id": "camera_001",
"display_id": "display_001",
"detection_class": "car",
"confidence": 0.92,
"bbox_x1": 100,
"bbox_y1": 200,
"bbox_x2": 300,
"bbox_y2": 400,
"track_id": 1001,
"created_at": "NOW()"
}
detection_id = await manager.insert_record("car_detections", detection_data)
assert detection_id == 456
# 2. Update with classification results
mock_cursor.rowcount = 1
classification_data = {
"car_brand": "Toyota",
"car_model": "Camry",
"car_body_type": "Sedan",
"car_color": "Blue",
"brand_confidence": 0.87,
"bodytype_confidence": 0.82,
"color_confidence": 0.79,
"updated_at": "NOW()"
}
where_conditions = {"session_id": detection_data["session_id"]}
rows_updated = await manager.update_record("car_detections", classification_data, where_conditions)
assert rows_updated == 1
# 3. Query final results
mock_cursor.fetchall.return_value = [
(456, detection_data["session_id"], "camera_001", "car", 0.92, "Toyota", "Sedan")
]
results = await manager.execute_query(
"SELECT id, session_id, camera_id, detection_class, confidence, car_brand, car_body_type "
"FROM gas_station_1.car_detections WHERE session_id = %s",
[detection_data["session_id"]]
)
assert len(results) == 1
assert results[0][0] == 456 # ID
assert results[0][3] == "car" # detection_class
assert results[0][5] == "Toyota" # car_brand
# Verify all database operations were called
assert mock_cursor.execute.call_count == 3
assert mock_database_connection.commit.call_count == 2
@pytest.mark.asyncio
async def test_error_handling_and_recovery(self, mock_database_connection):
"""Test error handling and recovery scenarios."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Test transaction rollback on error
mock_cursor = mock_database_connection.cursor.return_value
with pytest.raises(DatabaseError):
async with manager.transaction():
# First operation succeeds
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
# Second operation fails
mock_cursor.execute.side_effect = psycopg2.Error("Constraint violation")
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])
# Should have rolled back
mock_database_connection.rollback.assert_called_once()
mock_database_connection.commit.assert_not_called()
@pytest.mark.asyncio
async def test_connection_recovery(self):
"""Test automatic connection recovery."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
with patch('psycopg2.connect') as mock_connect:
# First connection attempt fails
mock_connect.side_effect = [
psycopg2.Error("Connection refused"),
Mock() # Second attempt succeeds
]
# First attempt should fail
with pytest.raises(DatabaseError):
await manager.connect()
# Second attempt should succeed
await manager.connect()
assert manager.is_connected is True

View file

@ -0,0 +1,964 @@
"""
Unit tests for Redis client functionality.
"""
import pytest
import asyncio
import json
import base64
import time
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from datetime import datetime, timedelta
import redis
import numpy as np
from detector_worker.storage.redis_client import (
RedisClient,
RedisConfig,
RedisConnectionPool,
RedisPublisher,
RedisSubscriber,
RedisImageStorage,
RedisError,
ConnectionPoolError
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import ConfigurationError
class TestRedisConfig:
"""Test Redis configuration."""
def test_creation_minimal(self):
"""Test creating Redis config with minimal parameters."""
config = RedisConfig(
host="localhost"
)
assert config.host == "localhost"
assert config.port == 6379 # Default port
assert config.password is None
assert config.db == 0 # Default database
assert config.enabled is True
def test_creation_full(self):
"""Test creating Redis config with all parameters."""
config = RedisConfig(
host="redis.example.com",
port=6380,
password="secure_pass",
db=2,
enabled=True,
connection_timeout=5.0,
socket_timeout=3.0,
socket_connect_timeout=2.0,
max_connections=50,
retry_on_timeout=True,
health_check_interval=30
)
assert config.host == "redis.example.com"
assert config.port == 6380
assert config.password == "secure_pass"
assert config.db == 2
assert config.connection_timeout == 5.0
assert config.max_connections == 50
assert config.retry_on_timeout is True
def test_get_connection_params(self):
"""Test getting Redis connection parameters."""
config = RedisConfig(
host="localhost",
port=6379,
password="test_pass",
db=1,
connection_timeout=10.0
)
params = config.get_connection_params()
assert params["host"] == "localhost"
assert params["port"] == 6379
assert params["password"] == "test_pass"
assert params["db"] == 1
assert params["socket_timeout"] == 10.0
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"host": "redis-server",
"port": 6380,
"password": "secret",
"db": 3,
"max_connections": 100,
"unknown_field": "ignored"
}
config = RedisConfig.from_dict(config_dict)
assert config.host == "redis-server"
assert config.port == 6380
assert config.password == "secret"
assert config.db == 3
assert config.max_connections == 100
class TestRedisConnectionPool:
"""Test Redis connection pool management."""
def test_creation(self):
"""Test connection pool creation."""
config = RedisConfig(
host="localhost",
max_connections=20
)
pool = RedisConnectionPool(config)
assert pool.config == config
assert pool.pool is None
assert pool.is_connected is False
@pytest.mark.asyncio
async def test_connect_success(self):
"""Test successful connection to Redis."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
with patch('redis.ConnectionPool') as mock_pool_class:
mock_pool = Mock()
mock_pool_class.return_value = mock_pool
with patch('redis.Redis') as mock_redis_class:
mock_redis = Mock()
mock_redis.ping.return_value = True
mock_redis_class.return_value = mock_redis
await pool.connect()
assert pool.is_connected is True
assert pool.pool is not None
mock_pool_class.assert_called_once()
@pytest.mark.asyncio
async def test_connect_failure(self):
"""Test Redis connection failure."""
config = RedisConfig(host="nonexistent-redis")
pool = RedisConnectionPool(config)
with patch('redis.ConnectionPool') as mock_pool_class:
mock_pool_class.side_effect = redis.ConnectionError("Connection failed")
with pytest.raises(RedisError) as exc_info:
await pool.connect()
assert "Connection failed" in str(exc_info.value)
assert pool.is_connected is False
@pytest.mark.asyncio
async def test_disconnect(self):
"""Test Redis disconnection."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
# Mock connected state
mock_pool = Mock()
mock_redis = Mock()
pool.pool = mock_pool
pool._redis_client = mock_redis
pool.is_connected = True
await pool.disconnect()
assert pool.is_connected is False
assert pool.pool is None
mock_pool.disconnect.assert_called_once()
def test_get_client_connected(self):
"""Test getting Redis client when connected."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
mock_pool = Mock()
mock_redis = Mock()
pool.pool = mock_pool
pool._redis_client = mock_redis
pool.is_connected = True
client = pool.get_client()
assert client == mock_redis
def test_get_client_not_connected(self):
"""Test getting Redis client when not connected."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
with pytest.raises(RedisError) as exc_info:
pool.get_client()
assert "not connected" in str(exc_info.value).lower()
def test_health_check(self):
"""Test Redis health check."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
mock_redis = Mock()
mock_redis.ping.return_value = True
pool._redis_client = mock_redis
pool.is_connected = True
is_healthy = pool.health_check()
assert is_healthy is True
mock_redis.ping.assert_called_once()
def test_health_check_failure(self):
"""Test Redis health check failure."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
mock_redis = Mock()
mock_redis.ping.side_effect = redis.ConnectionError("Connection lost")
pool._redis_client = mock_redis
pool.is_connected = True
is_healthy = pool.health_check()
assert is_healthy is False
class TestRedisImageStorage:
"""Test Redis image storage functionality."""
def test_creation(self, mock_redis_client):
"""Test Redis image storage creation."""
storage = RedisImageStorage(mock_redis_client)
assert storage.redis_client == mock_redis_client
assert storage.default_expiry == 3600 # 1 hour
assert storage.compression_enabled is True
@pytest.mark.asyncio
async def test_store_image_success(self, mock_redis_client, mock_frame):
"""Test successful image storage."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
with patch('cv2.imencode') as mock_imencode:
# Mock successful encoding
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
result = await storage.store_image("test_key", mock_frame, expire_seconds=600)
assert result is True
mock_redis_client.set.assert_called_once()
mock_redis_client.expire.assert_called_once_with("test_key", 600)
mock_imencode.assert_called_once()
@pytest.mark.asyncio
async def test_store_image_cropped(self, mock_redis_client, mock_frame):
"""Test storing cropped image."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
result = await storage.store_image("cropped_key", mock_frame, crop_bbox=bbox)
assert result is True
mock_redis_client.set.assert_called_once()
@pytest.mark.asyncio
async def test_store_image_encoding_failure(self, mock_redis_client, mock_frame):
"""Test image storage with encoding failure."""
storage = RedisImageStorage(mock_redis_client)
with patch('cv2.imencode') as mock_imencode:
# Mock encoding failure
mock_imencode.return_value = (False, None)
with pytest.raises(RedisError) as exc_info:
await storage.store_image("test_key", mock_frame)
assert "Failed to encode image" in str(exc_info.value)
mock_redis_client.set.assert_not_called()
@pytest.mark.asyncio
async def test_store_image_redis_failure(self, mock_redis_client, mock_frame):
"""Test image storage with Redis failure."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.set.side_effect = redis.RedisError("Redis error")
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
with pytest.raises(RedisError) as exc_info:
await storage.store_image("test_key", mock_frame)
assert "Redis error" in str(exc_info.value)
@pytest.mark.asyncio
async def test_retrieve_image_success(self, mock_redis_client):
"""Test successful image retrieval."""
storage = RedisImageStorage(mock_redis_client)
# Mock encoded image data
original_image = np.ones((100, 100, 3), dtype=np.uint8) * 128
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
# Mock Redis returning base64 encoded data
base64_data = base64.b64encode(encoded_data.tobytes()).decode('utf-8')
mock_redis_client.get.return_value = base64_data
with patch('cv2.imdecode') as mock_imdecode:
mock_imdecode.return_value = original_image
retrieved_image = await storage.retrieve_image("test_key")
assert retrieved_image is not None
assert retrieved_image.shape == (100, 100, 3)
mock_redis_client.get.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_retrieve_image_not_found(self, mock_redis_client):
"""Test image retrieval when key not found."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.get.return_value = None
retrieved_image = await storage.retrieve_image("nonexistent_key")
assert retrieved_image is None
mock_redis_client.get.assert_called_once_with("nonexistent_key")
@pytest.mark.asyncio
async def test_delete_image(self, mock_redis_client):
"""Test image deletion."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.delete.return_value = 1
result = await storage.delete_image("test_key")
assert result is True
mock_redis_client.delete.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_delete_image_not_found(self, mock_redis_client):
"""Test deleting non-existent image."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.delete.return_value = 0
result = await storage.delete_image("nonexistent_key")
assert result is False
mock_redis_client.delete.assert_called_once_with("nonexistent_key")
@pytest.mark.asyncio
async def test_bulk_delete_images(self, mock_redis_client):
"""Test bulk image deletion."""
storage = RedisImageStorage(mock_redis_client)
keys = ["key1", "key2", "key3"]
mock_redis_client.delete.return_value = 3
deleted_count = await storage.bulk_delete_images(keys)
assert deleted_count == 3
mock_redis_client.delete.assert_called_once_with(*keys)
@pytest.mark.asyncio
async def test_cleanup_expired_images(self, mock_redis_client):
"""Test cleanup of expired images."""
storage = RedisImageStorage(mock_redis_client)
# Mock scan to return image keys
mock_redis_client.scan_iter.return_value = [
b"inference:camera1:image1",
b"inference:camera2:image2",
b"inference:camera1:image3"
]
# Mock ttl to return different expiry times
mock_redis_client.ttl.side_effect = [-1, 100, -2] # No expiry, valid, expired
mock_redis_client.delete.return_value = 1
deleted_count = await storage.cleanup_expired_images("inference:*")
assert deleted_count == 1 # Only expired images deleted
mock_redis_client.delete.assert_called_once()
def test_get_image_info(self, mock_redis_client):
"""Test getting image metadata."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.exists.return_value = 1
mock_redis_client.ttl.return_value = 1800 # 30 minutes
mock_redis_client.memory_usage.return_value = 4096 # 4KB
info = storage.get_image_info("test_key")
assert info["exists"] is True
assert info["ttl"] == 1800
assert info["size_bytes"] == 4096
mock_redis_client.exists.assert_called_once_with("test_key")
mock_redis_client.ttl.assert_called_once_with("test_key")
class TestRedisPublisher:
"""Test Redis publisher functionality."""
def test_creation(self, mock_redis_client):
"""Test Redis publisher creation."""
publisher = RedisPublisher(mock_redis_client)
assert publisher.redis_client == mock_redis_client
@pytest.mark.asyncio
async def test_publish_message_string(self, mock_redis_client):
"""Test publishing string message."""
publisher = RedisPublisher(mock_redis_client)
mock_redis_client.publish.return_value = 2 # 2 subscribers
result = await publisher.publish("test_channel", "Hello, Redis!")
assert result == 2
mock_redis_client.publish.assert_called_once_with("test_channel", "Hello, Redis!")
@pytest.mark.asyncio
async def test_publish_message_json(self, mock_redis_client):
"""Test publishing JSON message."""
publisher = RedisPublisher(mock_redis_client)
mock_redis_client.publish.return_value = 1
message_data = {
"camera_id": "camera_001",
"detection_class": "car",
"confidence": 0.95,
"timestamp": 1640995200000
}
result = await publisher.publish("detections", message_data)
assert result == 1
# Should have been JSON serialized
expected_json = json.dumps(message_data)
mock_redis_client.publish.assert_called_once_with("detections", expected_json)
@pytest.mark.asyncio
async def test_publish_detection_event(self, mock_redis_client):
"""Test publishing detection event."""
publisher = RedisPublisher(mock_redis_client)
mock_redis_client.publish.return_value = 3
detection = DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
result = await publisher.publish_detection_event(
"camera_detections",
detection,
camera_id="camera_001",
session_id="session_123"
)
assert result == 3
# Verify the published message structure
call_args = mock_redis_client.publish.call_args
channel = call_args[0][0]
message_str = call_args[0][1]
message_data = json.loads(message_str)
assert channel == "camera_detections"
assert message_data["event_type"] == "detection"
assert message_data["camera_id"] == "camera_001"
assert message_data["session_id"] == "session_123"
assert message_data["detection"]["class"] == "car"
assert message_data["detection"]["confidence"] == 0.92
@pytest.mark.asyncio
async def test_publish_batch_messages(self, mock_redis_client):
"""Test publishing multiple messages in batch."""
publisher = RedisPublisher(mock_redis_client)
mock_pipeline = Mock()
mock_redis_client.pipeline.return_value = mock_pipeline
mock_pipeline.execute.return_value = [1, 2, 1] # Subscriber counts
messages = [
("channel1", "message1"),
("channel2", {"data": "message2"}),
("channel1", "message3")
]
results = await publisher.publish_batch(messages)
assert results == [1, 2, 1]
mock_redis_client.pipeline.assert_called_once()
assert mock_pipeline.publish.call_count == 3
mock_pipeline.execute.assert_called_once()
@pytest.mark.asyncio
async def test_publish_error_handling(self, mock_redis_client):
"""Test error handling in publishing."""
publisher = RedisPublisher(mock_redis_client)
mock_redis_client.publish.side_effect = redis.RedisError("Publish failed")
with pytest.raises(RedisError) as exc_info:
await publisher.publish("test_channel", "test_message")
assert "Publish failed" in str(exc_info.value)
class TestRedisSubscriber:
"""Test Redis subscriber functionality."""
def test_creation(self, mock_redis_client):
"""Test Redis subscriber creation."""
subscriber = RedisSubscriber(mock_redis_client)
assert subscriber.redis_client == mock_redis_client
assert subscriber.pubsub is None
assert subscriber.subscriptions == set()
@pytest.mark.asyncio
async def test_subscribe_to_channel(self, mock_redis_client):
"""Test subscribing to a channel."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
mock_redis_client.pubsub.return_value = mock_pubsub
await subscriber.subscribe("test_channel")
assert "test_channel" in subscriber.subscriptions
mock_pubsub.subscribe.assert_called_once_with("test_channel")
@pytest.mark.asyncio
async def test_subscribe_to_pattern(self, mock_redis_client):
"""Test subscribing to a pattern."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
mock_redis_client.pubsub.return_value = mock_pubsub
await subscriber.subscribe_pattern("detection:*")
assert "detection:*" in subscriber.subscriptions
mock_pubsub.psubscribe.assert_called_once_with("detection:*")
@pytest.mark.asyncio
async def test_unsubscribe_from_channel(self, mock_redis_client):
"""Test unsubscribing from a channel."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
mock_redis_client.pubsub.return_value = mock_pubsub
subscriber.pubsub = mock_pubsub
subscriber.subscriptions.add("test_channel")
await subscriber.unsubscribe("test_channel")
assert "test_channel" not in subscriber.subscriptions
mock_pubsub.unsubscribe.assert_called_once_with("test_channel")
@pytest.mark.asyncio
async def test_listen_for_messages(self, mock_redis_client):
"""Test listening for messages."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
mock_redis_client.pubsub.return_value = mock_pubsub
# Mock message stream
messages = [
{"type": "subscribe", "channel": "test", "data": 1},
{"type": "message", "channel": "test", "data": "Hello"},
{"type": "message", "channel": "test", "data": '{"key": "value"}'}
]
mock_pubsub.listen.return_value = iter(messages)
received_messages = []
message_count = 0
async for message in subscriber.listen():
received_messages.append(message)
message_count += 1
if message_count >= 2: # Only process actual messages
break
# Should receive 2 actual messages (excluding subscribe confirmation)
assert len(received_messages) == 2
assert received_messages[0]["data"] == "Hello"
assert received_messages[1]["data"] == {"key": "value"} # Should be parsed as JSON
@pytest.mark.asyncio
async def test_close_subscription(self, mock_redis_client):
"""Test closing subscription."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
subscriber.pubsub = mock_pubsub
subscriber.subscriptions = {"channel1", "pattern:*"}
await subscriber.close()
assert len(subscriber.subscriptions) == 0
mock_pubsub.close.assert_called_once()
assert subscriber.pubsub is None
class TestRedisClient:
"""Test main Redis client functionality."""
def test_initialization(self):
"""Test Redis client initialization."""
config = RedisConfig(host="localhost", port=6379)
client = RedisClient(config)
assert client.config == config
assert isinstance(client.connection_pool, RedisConnectionPool)
assert client.image_storage is None
assert client.publisher is None
assert client.subscriber is None
@pytest.mark.asyncio
async def test_connect_and_initialize_components(self):
"""Test connecting and initializing all components."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
with patch.object(client.connection_pool, 'connect', new_callable=AsyncMock) as mock_connect:
mock_redis_client = Mock()
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
await client.connect()
assert client.image_storage is not None
assert client.publisher is not None
assert client.subscriber is not None
assert isinstance(client.image_storage, RedisImageStorage)
assert isinstance(client.publisher, RedisPublisher)
assert isinstance(client.subscriber, RedisSubscriber)
mock_connect.assert_called_once()
@pytest.mark.asyncio
async def test_disconnect(self):
"""Test disconnection."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.is_connected = True
client.subscriber = Mock()
client.subscriber.close = AsyncMock()
with patch.object(client.connection_pool, 'disconnect', new_callable=AsyncMock) as mock_disconnect:
await client.disconnect()
client.subscriber.close.assert_called_once()
mock_disconnect.assert_called_once()
assert client.image_storage is None
assert client.publisher is None
assert client.subscriber is None
@pytest.mark.asyncio
async def test_store_and_retrieve_data(self, mock_redis_client):
"""Test storing and retrieving data."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
# Test storing data
mock_redis_client.set.return_value = True
result = await client.set("test_key", "test_value", expire_seconds=300)
assert result is True
mock_redis_client.set.assert_called_once_with("test_key", "test_value")
mock_redis_client.expire.assert_called_once_with("test_key", 300)
# Test retrieving data
mock_redis_client.get.return_value = "test_value"
value = await client.get("test_key")
assert value == "test_value"
mock_redis_client.get.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_delete_keys(self, mock_redis_client):
"""Test deleting keys."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.delete.return_value = 2
result = await client.delete("key1", "key2")
assert result == 2
mock_redis_client.delete.assert_called_once_with("key1", "key2")
@pytest.mark.asyncio
async def test_exists_check(self, mock_redis_client):
"""Test checking key existence."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.exists.return_value = 1
exists = await client.exists("test_key")
assert exists is True
mock_redis_client.exists.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_expire_key(self, mock_redis_client):
"""Test setting key expiration."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.expire.return_value = True
result = await client.expire("test_key", 600)
assert result is True
mock_redis_client.expire.assert_called_once_with("test_key", 600)
@pytest.mark.asyncio
async def test_get_ttl(self, mock_redis_client):
"""Test getting key TTL."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.ttl.return_value = 300
ttl = await client.ttl("test_key")
assert ttl == 300
mock_redis_client.ttl.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_scan_keys(self, mock_redis_client):
"""Test scanning for keys."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.scan_iter.return_value = [b"key1", b"key2", b"key3"]
keys = await client.scan_keys("test:*")
assert keys == ["key1", "key2", "key3"]
mock_redis_client.scan_iter.assert_called_once_with(match="test:*")
@pytest.mark.asyncio
async def test_flush_database(self, mock_redis_client):
"""Test flushing database."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.flushdb.return_value = True
result = await client.flush_db()
assert result is True
mock_redis_client.flushdb.assert_called_once()
def test_get_connection_info(self):
"""Test getting connection information."""
config = RedisConfig(
host="redis.example.com",
port=6380,
db=2
)
client = RedisClient(config)
client.connection_pool.is_connected = True
info = client.get_connection_info()
assert info["connected"] is True
assert info["host"] == "redis.example.com"
assert info["port"] == 6380
assert info["database"] == 2
@pytest.mark.asyncio
async def test_pipeline_operations(self, mock_redis_client):
"""Test Redis pipeline operations."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_pipeline = Mock()
mock_redis_client.pipeline.return_value = mock_pipeline
mock_pipeline.execute.return_value = [True, True, 1]
async with client.pipeline() as pipe:
pipe.set("key1", "value1")
pipe.set("key2", "value2")
pipe.delete("key3")
results = await pipe.execute()
assert results == [True, True, 1]
mock_redis_client.pipeline.assert_called_once()
mock_pipeline.execute.assert_called_once()
class TestRedisClientIntegration:
"""Integration tests for Redis client."""
@pytest.mark.asyncio
async def test_complete_image_workflow(self, mock_redis_client, mock_frame):
"""Test complete image storage workflow."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state and components
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
client.image_storage = RedisImageStorage(mock_redis_client)
client.publisher = RedisPublisher(mock_redis_client)
# Mock Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
mock_redis_client.publish.return_value = 2
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
# Store image
store_result = await client.image_storage.store_image(
"detection:camera001:1640995200:session123",
mock_frame,
expire_seconds=600
)
# Publish detection event
detection_event = {
"camera_id": "camera001",
"session_id": "session123",
"detection_class": "car",
"confidence": 0.95,
"timestamp": 1640995200000
}
publish_result = await client.publisher.publish("detections:camera001", detection_event)
assert store_result is True
assert publish_result == 2
# Verify Redis operations
mock_redis_client.set.assert_called_once()
mock_redis_client.expire.assert_called_once()
mock_redis_client.publish.assert_called_once()
@pytest.mark.asyncio
async def test_error_recovery_and_reconnection(self):
"""Test error recovery and reconnection."""
config = RedisConfig(host="localhost", retry_on_timeout=True)
client = RedisClient(config)
with patch.object(client.connection_pool, 'connect', new_callable=AsyncMock) as mock_connect:
with patch.object(client.connection_pool, 'health_check') as mock_health_check:
# First health check fails, second succeeds
mock_health_check.side_effect = [False, True]
# First connection attempt fails, second succeeds
mock_connect.side_effect = [RedisError("Connection failed"), None]
# Simulate connection recovery
try:
await client.connect()
except RedisError:
# Retry connection
await client.connect()
assert mock_connect.call_count == 2
@pytest.mark.asyncio
async def test_bulk_operations_performance(self, mock_redis_client):
"""Test bulk operations for performance."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
client.publisher = RedisPublisher(mock_redis_client)
# Mock pipeline operations
mock_pipeline = Mock()
mock_redis_client.pipeline.return_value = mock_pipeline
mock_pipeline.execute.return_value = [1] * 100 # 100 successful operations
# Prepare bulk messages
messages = [
(f"channel_{i}", f"message_{i}")
for i in range(100)
]
start_time = time.time()
results = await client.publisher.publish_batch(messages)
execution_time = time.time() - start_time
assert len(results) == 100
assert all(result == 1 for result in results)
# Should be faster than individual operations
assert execution_time < 1.0 # Should complete in less than 1 second
# Pipeline should be used for efficiency
mock_redis_client.pipeline.assert_called_once()
assert mock_pipeline.publish.call_count == 100
mock_pipeline.execute.assert_called_once()

View file

@ -0,0 +1,883 @@
"""
Unit tests for session cache management.
"""
import pytest
import time
import uuid
from unittest.mock import Mock, patch
from datetime import datetime, timedelta
from collections import defaultdict
from detector_worker.storage.session_cache import (
SessionCache,
SessionCacheManager,
SessionData,
CacheConfig,
CacheEntry,
CacheStats,
SessionError,
CacheError
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
class TestCacheConfig:
"""Test cache configuration."""
def test_creation_default(self):
"""Test creating cache config with default values."""
config = CacheConfig()
assert config.max_size == 1000
assert config.ttl_seconds == 3600 # 1 hour
assert config.cleanup_interval == 300 # 5 minutes
assert config.eviction_policy == "lru"
assert config.enable_persistence is False
def test_creation_custom(self):
"""Test creating cache config with custom values."""
config = CacheConfig(
max_size=5000,
ttl_seconds=7200,
cleanup_interval=600,
eviction_policy="lfu",
enable_persistence=True,
persistence_path="/tmp/cache"
)
assert config.max_size == 5000
assert config.ttl_seconds == 7200
assert config.cleanup_interval == 600
assert config.eviction_policy == "lfu"
assert config.enable_persistence is True
assert config.persistence_path == "/tmp/cache"
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"max_size": 2000,
"ttl_seconds": 1800,
"eviction_policy": "fifo",
"enable_persistence": True,
"unknown_field": "ignored"
}
config = CacheConfig.from_dict(config_dict)
assert config.max_size == 2000
assert config.ttl_seconds == 1800
assert config.eviction_policy == "fifo"
assert config.enable_persistence is True
class TestCacheEntry:
"""Test cache entry data structure."""
def test_creation(self):
"""Test cache entry creation."""
data = {"key": "value", "number": 42}
entry = CacheEntry(data, ttl_seconds=600)
assert entry.data == data
assert entry.ttl_seconds == 600
assert entry.created_at <= time.time()
assert entry.last_accessed <= time.time()
assert entry.access_count == 1
assert entry.size > 0
def test_is_expired(self):
"""Test expiration checking."""
# Non-expired entry
entry = CacheEntry({"data": "test"}, ttl_seconds=600)
assert entry.is_expired() is False
# Expired entry (simulate by setting old creation time)
entry.created_at = time.time() - 700 # Created 700 seconds ago
assert entry.is_expired() is True
# Entry without expiration
entry_no_ttl = CacheEntry({"data": "test"})
assert entry_no_ttl.is_expired() is False
def test_touch(self):
"""Test updating access time and count."""
entry = CacheEntry({"data": "test"})
original_access_time = entry.last_accessed
original_access_count = entry.access_count
time.sleep(0.01) # Small delay
entry.touch()
assert entry.last_accessed > original_access_time
assert entry.access_count == original_access_count + 1
def test_age(self):
"""Test age calculation."""
entry = CacheEntry({"data": "test"})
time.sleep(0.01) # Small delay
age = entry.age()
assert age > 0
assert age < 1 # Should be less than 1 second
def test_size_estimation(self):
"""Test size estimation."""
small_entry = CacheEntry({"key": "value"})
large_entry = CacheEntry({"key": "x" * 1000, "data": list(range(100))})
assert large_entry.size > small_entry.size
class TestSessionData:
"""Test session data structure."""
def test_creation(self):
"""Test session data creation."""
session_data = SessionData(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
assert session_data.session_id == "session_123"
assert session_data.camera_id == "camera_001"
assert session_data.display_id == "display_001"
assert session_data.created_at <= time.time()
assert session_data.last_activity <= time.time()
assert session_data.detection_data == {}
assert session_data.metadata == {}
def test_update_activity(self):
"""Test updating last activity."""
session_data = SessionData("session_123", "camera_001", "display_001")
original_activity = session_data.last_activity
time.sleep(0.01)
session_data.update_activity()
assert session_data.last_activity > original_activity
def test_add_detection_data(self):
"""Test adding detection data."""
session_data = SessionData("session_123", "camera_001", "display_001")
detection_data = {
"class": "car",
"confidence": 0.95,
"bbox": [100, 200, 300, 400]
}
session_data.add_detection_data("main_detection", detection_data)
assert "main_detection" in session_data.detection_data
assert session_data.detection_data["main_detection"] == detection_data
def test_add_metadata(self):
"""Test adding metadata."""
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_metadata("model_version", "v2.1")
session_data.add_metadata("inference_time", 0.15)
assert session_data.metadata["model_version"] == "v2.1"
assert session_data.metadata["inference_time"] == 0.15
def test_is_expired(self):
"""Test session expiration."""
session_data = SessionData("session_123", "camera_001", "display_001")
# Not expired with default timeout
assert session_data.is_expired() is False
# Expired with short timeout
assert session_data.is_expired(timeout_seconds=0.001) is True
def test_to_dict(self):
"""Test converting session to dictionary."""
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_detection_data("detection", {"class": "car", "confidence": 0.9})
session_data.add_metadata("model_id", "yolo_v8")
data_dict = session_data.to_dict()
assert data_dict["session_id"] == "session_123"
assert data_dict["camera_id"] == "camera_001"
assert data_dict["detection_data"]["detection"]["class"] == "car"
assert data_dict["metadata"]["model_id"] == "yolo_v8"
assert "created_at" in data_dict
assert "last_activity" in data_dict
class TestCacheStats:
"""Test cache statistics."""
def test_creation(self):
"""Test cache stats creation."""
stats = CacheStats()
assert stats.hits == 0
assert stats.misses == 0
assert stats.evictions == 0
assert stats.size == 0
assert stats.memory_usage == 0
def test_hit_rate_calculation(self):
"""Test hit rate calculation."""
stats = CacheStats()
# No requests yet
assert stats.hit_rate() == 0.0
# Some hits and misses
stats.hits = 8
stats.misses = 2
assert stats.hit_rate() == 0.8 # 8 / (8 + 2)
def test_total_requests(self):
"""Test total requests calculation."""
stats = CacheStats()
stats.hits = 15
stats.misses = 5
assert stats.total_requests() == 20
class TestSessionCache:
"""Test session cache functionality."""
def test_creation(self):
"""Test session cache creation."""
config = CacheConfig(max_size=100, ttl_seconds=300)
cache = SessionCache(config)
assert cache.config == config
assert cache.max_size == 100
assert cache.ttl_seconds == 300
assert len(cache._cache) == 0
assert len(cache._access_order) == 0
def test_put_and_get_session(self):
"""Test putting and getting session data."""
cache = SessionCache(CacheConfig(max_size=10))
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_detection_data("main", {"class": "car", "confidence": 0.9})
# Put session
cache.put("session_123", session_data)
# Get session
retrieved_data = cache.get("session_123")
assert retrieved_data is not None
assert retrieved_data.session_id == "session_123"
assert retrieved_data.camera_id == "camera_001"
assert "main" in retrieved_data.detection_data
def test_get_nonexistent_session(self):
"""Test getting non-existent session."""
cache = SessionCache(CacheConfig(max_size=10))
result = cache.get("nonexistent_session")
assert result is None
def test_contains_check(self):
"""Test checking if session exists."""
cache = SessionCache(CacheConfig(max_size=10))
session_data = SessionData("session_123", "camera_001", "display_001")
cache.put("session_123", session_data)
assert cache.contains("session_123") is True
assert cache.contains("nonexistent_session") is False
def test_remove_session(self):
"""Test removing session."""
cache = SessionCache(CacheConfig(max_size=10))
session_data = SessionData("session_123", "camera_001", "display_001")
cache.put("session_123", session_data)
assert cache.contains("session_123") is True
removed_data = cache.remove("session_123")
assert removed_data is not None
assert removed_data.session_id == "session_123"
assert cache.contains("session_123") is False
def test_size_tracking(self):
"""Test cache size tracking."""
cache = SessionCache(CacheConfig(max_size=10))
assert cache.size() == 0
assert cache.is_empty() is True
# Add sessions
for i in range(3):
session_data = SessionData(f"session_{i}", "camera_001", "display_001")
cache.put(f"session_{i}", session_data)
assert cache.size() == 3
assert cache.is_empty() is False
def test_lru_eviction(self):
"""Test LRU eviction policy."""
cache = SessionCache(CacheConfig(max_size=3, eviction_policy="lru"))
# Fill cache to capacity
for i in range(3):
session_data = SessionData(f"session_{i}", "camera_001", "display_001")
cache.put(f"session_{i}", session_data)
# Access session_1 to make it recently used
cache.get("session_1")
# Add another session (should evict session_0, the least recently used)
new_session = SessionData("session_3", "camera_001", "display_001")
cache.put("session_3", new_session)
assert cache.size() == 3
assert cache.contains("session_0") is False # Evicted
assert cache.contains("session_1") is True # Recently accessed
assert cache.contains("session_2") is True
assert cache.contains("session_3") is True # Newly added
def test_ttl_expiration(self):
"""Test TTL-based expiration."""
cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1)) # 100ms TTL
session_data = SessionData("session_123", "camera_001", "display_001")
cache.put("session_123", session_data)
# Should exist immediately
assert cache.contains("session_123") is True
# Wait for expiration
time.sleep(0.2)
# Should be expired (but might still be in cache until cleanup)
entry = cache._cache.get("session_123")
if entry:
assert entry.is_expired() is True
# Getting expired entry should return None and clean it up
retrieved = cache.get("session_123")
assert retrieved is None
assert cache.contains("session_123") is False
def test_cleanup_expired_entries(self):
"""Test cleanup of expired entries."""
cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1))
# Add multiple sessions
for i in range(3):
session_data = SessionData(f"session_{i}", "camera_001", "display_001")
cache.put(f"session_{i}", session_data)
assert cache.size() == 3
# Wait for expiration
time.sleep(0.2)
# Cleanup expired entries
cleaned_count = cache.cleanup_expired()
assert cleaned_count == 3
assert cache.size() == 0
def test_clear_cache(self):
"""Test clearing entire cache."""
cache = SessionCache(CacheConfig(max_size=10))
# Add sessions
for i in range(5):
session_data = SessionData(f"session_{i}", "camera_001", "display_001")
cache.put(f"session_{i}", session_data)
assert cache.size() == 5
cache.clear()
assert cache.size() == 0
assert cache.is_empty() is True
def test_get_all_sessions(self):
"""Test getting all sessions."""
cache = SessionCache(CacheConfig(max_size=10))
sessions = []
for i in range(3):
session_data = SessionData(f"session_{i}", f"camera_{i}", "display_001")
cache.put(f"session_{i}", session_data)
sessions.append(session_data)
all_sessions = cache.get_all()
assert len(all_sessions) == 3
for session_id, session_data in all_sessions.items():
assert session_id.startswith("session_")
assert session_data.session_id == session_id
def test_get_sessions_by_camera(self):
"""Test getting sessions by camera ID."""
cache = SessionCache(CacheConfig(max_size=10))
# Add sessions for different cameras
for i in range(2):
session_data1 = SessionData(f"session_cam1_{i}", "camera_001", "display_001")
session_data2 = SessionData(f"session_cam2_{i}", "camera_002", "display_001")
cache.put(f"session_cam1_{i}", session_data1)
cache.put(f"session_cam2_{i}", session_data2)
camera1_sessions = cache.get_by_camera("camera_001")
camera2_sessions = cache.get_by_camera("camera_002")
assert len(camera1_sessions) == 2
assert len(camera2_sessions) == 2
for session_data in camera1_sessions:
assert session_data.camera_id == "camera_001"
for session_data in camera2_sessions:
assert session_data.camera_id == "camera_002"
def test_statistics_tracking(self):
"""Test cache statistics tracking."""
cache = SessionCache(CacheConfig(max_size=10))
session_data = SessionData("session_123", "camera_001", "display_001")
cache.put("session_123", session_data)
# Cache miss
cache.get("nonexistent_session")
# Cache hit
cache.get("session_123")
cache.get("session_123") # Another hit
stats = cache.get_stats()
assert stats.hits == 2
assert stats.misses == 1
assert stats.size == 1
assert stats.hit_rate() == 2/3 # 2 hits out of 3 total requests
def test_memory_usage_estimation(self):
"""Test memory usage estimation."""
cache = SessionCache(CacheConfig(max_size=10))
initial_memory = cache.get_memory_usage()
# Add large session
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_detection_data("large_data", {"data": "x" * 1000})
cache.put("session_123", session_data)
after_memory = cache.get_memory_usage()
assert after_memory > initial_memory
class TestSessionCacheManager:
"""Test session cache manager."""
def test_singleton_behavior(self):
"""Test that SessionCacheManager is a singleton."""
manager1 = SessionCacheManager()
manager2 = SessionCacheManager()
assert manager1 is manager2
def test_initialization(self):
"""Test session cache manager initialization."""
manager = SessionCacheManager()
assert manager.detection_cache is not None
assert manager.pipeline_cache is not None
assert manager.session_cache is not None
assert isinstance(manager.detection_cache, SessionCache)
assert isinstance(manager.pipeline_cache, SessionCache)
assert isinstance(manager.session_cache, SessionCache)
def test_cache_detection_result(self):
"""Test caching detection results."""
manager = SessionCacheManager()
manager.clear_all() # Start fresh
detection_data = {
"class": "car",
"confidence": 0.95,
"bbox": [100, 200, 300, 400],
"track_id": 1001
}
manager.cache_detection("camera_001", detection_data)
cached_detection = manager.get_cached_detection("camera_001")
assert cached_detection is not None
assert cached_detection["class"] == "car"
assert cached_detection["confidence"] == 0.95
assert cached_detection["track_id"] == 1001
def test_cache_pipeline_result(self):
"""Test caching pipeline results."""
manager = SessionCacheManager()
manager.clear_all()
pipeline_result = {
"status": "success",
"detections": [{"class": "car", "confidence": 0.9}],
"execution_time": 0.15,
"model_id": "yolo_v8"
}
manager.cache_pipeline_result("camera_001", pipeline_result)
cached_result = manager.get_cached_pipeline_result("camera_001")
assert cached_result is not None
assert cached_result["status"] == "success"
assert cached_result["execution_time"] == 0.15
assert len(cached_result["detections"]) == 1
def test_manage_session_data(self):
"""Test session data management."""
manager = SessionCacheManager()
manager.clear_all()
session_id = str(uuid.uuid4())
# Create session
manager.create_session(session_id, "camera_001", {"initial": "data"})
# Update session
manager.update_session_detection(session_id, {"car_brand": "Toyota"})
# Get session
session_data = manager.get_session_detection(session_id)
assert session_data is not None
assert "initial" in session_data
assert session_data["car_brand"] == "Toyota"
def test_set_latest_frame(self):
"""Test setting and getting latest frame."""
manager = SessionCacheManager()
manager.clear_all()
frame_data = b"fake_frame_data"
manager.set_latest_frame("camera_001", frame_data)
retrieved_frame = manager.get_latest_frame("camera_001")
assert retrieved_frame == frame_data
def test_frame_skip_flag_management(self):
"""Test frame skip flag management."""
manager = SessionCacheManager()
manager.clear_all()
# Initially should be False
assert manager.get_frame_skip_flag("camera_001") is False
# Set to True
manager.set_frame_skip_flag("camera_001", True)
assert manager.get_frame_skip_flag("camera_001") is True
# Set back to False
manager.set_frame_skip_flag("camera_001", False)
assert manager.get_frame_skip_flag("camera_001") is False
def test_cleanup_expired_sessions(self):
"""Test cleanup of expired sessions."""
manager = SessionCacheManager()
manager.clear_all()
# Create sessions with short TTL
manager.session_cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1))
# Add sessions
for i in range(3):
session_id = f"session_{i}"
manager.create_session(session_id, "camera_001", {"test": "data"})
assert manager.session_cache.size() == 3
# Wait for expiration
time.sleep(0.2)
# Cleanup
expired_count = manager.cleanup_expired_sessions()
assert expired_count == 3
assert manager.session_cache.size() == 0
def test_clear_camera_cache(self):
"""Test clearing cache for specific camera."""
manager = SessionCacheManager()
manager.clear_all()
# Add data for multiple cameras
manager.cache_detection("camera_001", {"class": "car"})
manager.cache_detection("camera_002", {"class": "truck"})
manager.cache_pipeline_result("camera_001", {"status": "success"})
manager.set_latest_frame("camera_001", b"frame1")
manager.set_latest_frame("camera_002", b"frame2")
# Clear camera_001 cache
manager.clear_camera_cache("camera_001")
# camera_001 data should be gone
assert manager.get_cached_detection("camera_001") is None
assert manager.get_cached_pipeline_result("camera_001") is None
assert manager.get_latest_frame("camera_001") is None
# camera_002 data should remain
assert manager.get_cached_detection("camera_002") is not None
assert manager.get_latest_frame("camera_002") is not None
def test_get_cache_statistics(self):
"""Test getting cache statistics."""
manager = SessionCacheManager()
manager.clear_all()
# Add some data to generate statistics
manager.cache_detection("camera_001", {"class": "car"})
manager.cache_pipeline_result("camera_001", {"status": "success"})
manager.create_session("session_123", "camera_001", {"initial": "data"})
# Access data to generate hits/misses
manager.get_cached_detection("camera_001") # Hit
manager.get_cached_detection("camera_002") # Miss
stats = manager.get_cache_statistics()
assert "detection_cache" in stats
assert "pipeline_cache" in stats
assert "session_cache" in stats
assert "total_memory_usage" in stats
detection_stats = stats["detection_cache"]
assert detection_stats["size"] >= 1
assert detection_stats["hits"] >= 1
assert detection_stats["misses"] >= 1
def test_memory_pressure_handling(self):
"""Test handling memory pressure."""
# Create manager with small cache sizes
config = CacheConfig(max_size=3)
manager = SessionCacheManager()
manager.detection_cache = SessionCache(config)
manager.pipeline_cache = SessionCache(config)
manager.session_cache = SessionCache(config)
# Fill caches beyond capacity
for i in range(5):
manager.cache_detection(f"camera_{i}", {"class": "car", "data": "x" * 100})
manager.cache_pipeline_result(f"camera_{i}", {"status": "success", "data": "y" * 100})
manager.create_session(f"session_{i}", f"camera_{i}", {"data": "z" * 100})
# Caches should not exceed max size due to eviction
assert manager.detection_cache.size() <= 3
assert manager.pipeline_cache.size() <= 3
assert manager.session_cache.size() <= 3
def test_concurrent_access_thread_safety(self):
"""Test thread safety of concurrent cache access."""
import threading
import concurrent.futures
manager = SessionCacheManager()
manager.clear_all()
results = []
errors = []
def cache_operation(thread_id):
try:
# Each thread performs multiple cache operations
for i in range(10):
session_id = f"thread_{thread_id}_session_{i}"
# Create session
manager.create_session(session_id, f"camera_{thread_id}", {"thread": thread_id, "index": i})
# Update session
manager.update_session_detection(session_id, {"updated": True})
# Read session
data = manager.get_session_detection(session_id)
if data and data.get("thread") == thread_id:
results.append((thread_id, i))
except Exception as e:
errors.append((thread_id, str(e)))
# Run operations concurrently
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(cache_operation, i) for i in range(5)]
concurrent.futures.wait(futures)
# Should have no errors and successful operations
assert len(errors) == 0
assert len(results) >= 25 # At least some operations should succeed
class TestSessionCacheIntegration:
"""Integration tests for session cache."""
def test_complete_detection_workflow(self):
"""Test complete detection workflow with caching."""
manager = SessionCacheManager()
manager.clear_all()
camera_id = "camera_001"
session_id = str(uuid.uuid4())
# 1. Cache initial detection
detection_data = {
"class": "car",
"confidence": 0.92,
"bbox": [100, 200, 300, 400],
"track_id": 1001,
"timestamp": int(time.time() * 1000)
}
manager.cache_detection(camera_id, detection_data)
# 2. Create session for tracking
initial_session_data = {
"detection_class": detection_data["class"],
"confidence": detection_data["confidence"],
"track_id": detection_data["track_id"]
}
manager.create_session(session_id, camera_id, initial_session_data)
# 3. Cache pipeline processing result
pipeline_result = {
"status": "processing",
"stage": "classification",
"detections": [detection_data],
"branches_completed": [],
"branches_pending": ["car_brand_cls", "car_bodytype_cls"]
}
manager.cache_pipeline_result(camera_id, pipeline_result)
# 4. Update session with classification results
classification_updates = [
{"car_brand": "Toyota", "brand_confidence": 0.87},
{"car_body_type": "Sedan", "bodytype_confidence": 0.82}
]
for update in classification_updates:
manager.update_session_detection(session_id, update)
# 5. Update pipeline result to completed
final_pipeline_result = {
"status": "completed",
"stage": "finished",
"detections": [detection_data],
"branches_completed": ["car_brand_cls", "car_bodytype_cls"],
"branches_pending": [],
"execution_time": 0.25
}
manager.cache_pipeline_result(camera_id, final_pipeline_result)
# 6. Verify all cached data
cached_detection = manager.get_cached_detection(camera_id)
cached_pipeline = manager.get_cached_pipeline_result(camera_id)
cached_session = manager.get_session_detection(session_id)
# Assertions
assert cached_detection["class"] == "car"
assert cached_detection["track_id"] == 1001
assert cached_pipeline["status"] == "completed"
assert len(cached_pipeline["branches_completed"]) == 2
assert cached_session["detection_class"] == "car"
assert cached_session["car_brand"] == "Toyota"
assert cached_session["car_body_type"] == "Sedan"
assert cached_session["brand_confidence"] == 0.87
def test_cache_performance_under_load(self):
"""Test cache performance under load."""
manager = SessionCacheManager()
manager.clear_all()
import time
# Measure performance of cache operations
start_time = time.time()
# Perform many cache operations
for i in range(1000):
camera_id = f"camera_{i % 10}" # 10 different cameras
session_id = f"session_{i}"
# Cache detection
detection_data = {
"class": "car",
"confidence": 0.9 + (i % 10) * 0.01,
"track_id": i,
"bbox": [i % 100, i % 100, (i % 100) + 200, (i % 100) + 200]
}
manager.cache_detection(camera_id, detection_data)
# Create session
manager.create_session(session_id, camera_id, {"index": i})
# Read back (every 10th operation)
if i % 10 == 0:
manager.get_cached_detection(camera_id)
manager.get_session_detection(session_id)
end_time = time.time()
total_time = end_time - start_time
# Should complete in reasonable time (less than 1 second)
assert total_time < 1.0
# Verify cache statistics
stats = manager.get_cache_statistics()
assert stats["detection_cache"]["size"] > 0
assert stats["session_cache"]["size"] > 0
assert stats["detection_cache"]["hits"] > 0
def test_cache_persistence_and_recovery(self):
"""Test cache persistence and recovery (if enabled)."""
# This test would be more meaningful with actual persistence
# For now, test the configuration and structure
persistence_config = CacheConfig(
max_size=100,
enable_persistence=True,
persistence_path="/tmp/detector_cache_test"
)
cache = SessionCache(persistence_config)
# Add some data
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_detection_data("main", {"class": "car", "confidence": 0.95})
cache.put("session_123", session_data)
# Verify data exists
assert cache.contains("session_123") is True
# In a real implementation, this would test:
# 1. Saving cache to disk
# 2. Loading cache from disk
# 3. Verifying data integrity after reload