Refactor: PHASE 3: Action & Storage Extraction

This commit is contained in:
ziesorx 2025-09-12 14:58:28 +07:00
parent 4e9ae6bcc4
commit cdeaaf4a4f
5 changed files with 3048 additions and 0 deletions

View file

@ -0,0 +1,617 @@
"""
Database management and operations.
This module provides comprehensive database functionality including:
- PostgreSQL connection management
- Table schema management
- Dynamic query execution
- Transaction handling
"""
import psycopg2
import psycopg2.extras
import uuid
import logging
from typing import Optional, Dict, Any, List, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from ..core.constants import (
DB_CONNECTION_TIMEOUT,
DB_OPERATION_TIMEOUT,
DB_RETRY_ATTEMPTS
)
from ..core.exceptions import DatabaseError, create_detection_error
logger = logging.getLogger(__name__)
@dataclass
class DatabaseConfig:
"""Database connection configuration."""
host: str
port: int
database: str
username: str
password: str
schema: str = "gas_station_1"
connection_timeout: int = DB_CONNECTION_TIMEOUT
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"host": self.host,
"port": self.port,
"database": self.database,
"username": self.username,
"password": self.password,
"schema": self.schema,
"connection_timeout": self.connection_timeout
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'DatabaseConfig':
"""Create from dictionary data."""
return cls(
host=data["host"],
port=data["port"],
database=data["database"],
username=data.get("username", data.get("user", "")),
password=data["password"],
schema=data.get("schema", "gas_station_1"),
connection_timeout=data.get("connection_timeout", DB_CONNECTION_TIMEOUT)
)
@dataclass
class QueryResult:
"""Result from database query execution."""
success: bool
affected_rows: int = 0
data: Optional[List[Dict[str, Any]]] = None
error: Optional[str] = None
query: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
result = {
"success": self.success,
"affected_rows": self.affected_rows
}
if self.data:
result["data"] = self.data
if self.error:
result["error"] = self.error
if self.query:
result["query"] = self.query
return result
class DatabaseManager:
"""
Comprehensive database manager for PostgreSQL operations.
This class provides connection management, schema operations,
and dynamic query execution with transaction support.
"""
def __init__(self, config: Dict[str, Any]):
"""
Initialize database manager.
Args:
config: Database configuration dictionary
"""
if isinstance(config, dict):
self.config = DatabaseConfig.from_dict(config)
else:
self.config = config
self.connection: Optional[psycopg2.extensions.connection] = None
self._lock = None
def _ensure_thread_safety(self):
"""Initialize thread safety if not already done."""
if self._lock is None:
import threading
self._lock = threading.RLock()
def connect(self) -> bool:
"""
Establish database connection.
Returns:
True if connection successful, False otherwise
"""
self._ensure_thread_safety()
with self._lock:
try:
if self.connection and not self.connection.closed:
# Connection already exists and is open
return True
self.connection = psycopg2.connect(
host=self.config.host,
port=self.config.port,
database=self.config.database,
user=self.config.username,
password=self.config.password,
connect_timeout=self.config.connection_timeout
)
# Set connection properties
self.connection.set_client_encoding('UTF8')
logger.info(f"PostgreSQL connection established successfully to {self.config.host}:{self.config.port}")
return True
except Exception as e:
logger.error(f"Failed to connect to PostgreSQL: {e}")
self.connection = None
return False
def disconnect(self) -> None:
"""Close database connection."""
self._ensure_thread_safety()
with self._lock:
if self.connection:
try:
self.connection.close()
logger.info("PostgreSQL connection closed")
except Exception as e:
logger.error(f"Error closing PostgreSQL connection: {e}")
finally:
self.connection = None
def is_connected(self) -> bool:
"""
Check if database connection is active.
Returns:
True if connected, False otherwise
"""
self._ensure_thread_safety()
with self._lock:
try:
if self.connection and not self.connection.closed:
cur = self.connection.cursor()
cur.execute("SELECT 1")
cur.fetchone()
cur.close()
return True
except Exception as e:
logger.debug(f"Connection check failed: {e}")
return False
def _ensure_connected(self) -> bool:
"""Ensure database connection is active."""
if not self.is_connected():
return self.connect()
return True
def execute_query(self,
query: str,
params: Optional[Tuple] = None,
fetch_results: bool = False) -> QueryResult:
"""
Execute a database query.
Args:
query: SQL query string
params: Query parameters
fetch_results: Whether to fetch and return results
Returns:
QueryResult with execution details
"""
self._ensure_thread_safety()
with self._lock:
if not self._ensure_connected():
return QueryResult(
success=False,
error="Failed to establish database connection",
query=query
)
cursor = None
try:
cursor = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
logger.debug(f"Executing query: {query}")
if params:
logger.debug(f"Query parameters: {params}")
cursor.execute(query, params)
affected_rows = cursor.rowcount
data = None
if fetch_results:
data = [dict(row) for row in cursor.fetchall()]
self.connection.commit()
logger.debug(f"Query executed successfully, affected rows: {affected_rows}")
return QueryResult(
success=True,
affected_rows=affected_rows,
data=data,
query=query
)
except Exception as e:
error_msg = f"Database query execution failed: {e}"
logger.error(error_msg)
logger.debug(f"Failed query: {query}")
if params:
logger.debug(f"Query parameters: {params}")
if self.connection:
try:
self.connection.rollback()
except Exception as rollback_error:
logger.error(f"Rollback failed: {rollback_error}")
return QueryResult(
success=False,
error=error_msg,
query=query
)
finally:
if cursor:
cursor.close()
def create_schema_if_not_exists(self, schema_name: Optional[str] = None) -> bool:
"""
Create database schema if it doesn't exist.
Args:
schema_name: Schema name (uses config schema if not provided)
Returns:
True if successful, False otherwise
"""
schema_name = schema_name or self.config.schema
query = f"CREATE SCHEMA IF NOT EXISTS {schema_name}"
result = self.execute_query(query)
if result.success:
logger.info(f"Schema '{schema_name}' created or verified successfully")
return result.success
def create_car_frontal_info_table(self) -> bool:
"""
Create the car_frontal_info table in the configured schema if it doesn't exist.
Returns:
True if successful, False otherwise
"""
schema_name = self.config.schema
# Ensure schema exists
if not self.create_schema_if_not_exists(schema_name):
return False
try:
# Create table if it doesn't exist
create_table_query = f"""
CREATE TABLE IF NOT EXISTS {schema_name}.car_frontal_info (
display_id VARCHAR(255),
captured_timestamp VARCHAR(255),
session_id VARCHAR(255) PRIMARY KEY,
license_character VARCHAR(255) DEFAULT NULL,
license_type VARCHAR(255) DEFAULT 'No model available',
car_brand VARCHAR(255) DEFAULT NULL,
car_model VARCHAR(255) DEFAULT NULL,
car_body_type VARCHAR(255) DEFAULT NULL,
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
)
"""
result = self.execute_query(create_table_query)
if not result.success:
return False
# Add columns if they don't exist (for existing tables)
alter_queries = [
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_brand VARCHAR(255) DEFAULT NULL",
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_model VARCHAR(255) DEFAULT NULL",
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_body_type VARCHAR(255) DEFAULT NULL",
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS created_at TIMESTAMP DEFAULT NOW()",
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT NOW()"
]
for alter_query in alter_queries:
try:
alter_result = self.execute_query(alter_query)
if alter_result.success:
logger.debug(f"Executed: {alter_query}")
else:
# Check if it's just a "column already exists" error
if "already exists" not in (alter_result.error or "").lower():
logger.warning(f"ALTER TABLE failed: {alter_result.error}")
except Exception as e:
# Ignore errors if column already exists (for older PostgreSQL versions)
if "already exists" in str(e).lower():
logger.debug(f"Column already exists, skipping: {alter_query}")
else:
logger.warning(f"Error in ALTER TABLE: {e}")
logger.info("Successfully created/verified car_frontal_info table with all required columns")
return True
except Exception as e:
logger.error(f"Failed to create car_frontal_info table: {e}")
return False
def insert_initial_detection(self,
display_id: str,
captured_timestamp: str,
session_id: Optional[str] = None) -> Optional[str]:
"""
Insert initial detection record and return the session_id.
Args:
display_id: Display identifier
captured_timestamp: Timestamp of capture
session_id: Optional session ID (generated if not provided)
Returns:
Session ID if successful, None otherwise
"""
# Generate session_id if not provided
if not session_id:
session_id = str(uuid.uuid4())
try:
# Ensure table exists
if not self.create_car_frontal_info_table():
logger.error("Failed to create/verify table before insertion")
return None
schema_name = self.config.schema
insert_query = f"""
INSERT INTO {schema_name}.car_frontal_info
(display_id, captured_timestamp, session_id, license_character, license_type, car_brand, car_model, car_body_type, created_at)
VALUES (%s, %s, %s, NULL, 'No model available', NULL, NULL, NULL, NOW())
ON CONFLICT (session_id) DO NOTHING
"""
result = self.execute_query(insert_query, (display_id, captured_timestamp, session_id))
if result.success:
logger.info(f"Inserted initial detection record with session_id: {session_id}")
return session_id
else:
logger.error(f"Failed to insert initial detection record: {result.error}")
return None
except Exception as e:
logger.error(f"Failed to insert initial detection record: {e}")
return None
def execute_update(self,
table: str,
key_field: str,
key_value: str,
fields: Dict[str, Any]) -> bool:
"""
Execute dynamic update/insert operation.
Args:
table: Table name
key_field: Primary key field name
key_value: Primary key value
fields: Dictionary of fields to update
Returns:
True if successful, False otherwise
"""
try:
# Add schema prefix if table doesn't already have it
if '.' not in table:
table = f"{self.config.schema}.{table}"
# Build the INSERT and UPDATE query dynamically
insert_placeholders = []
insert_values = [key_value] # Start with key_value
set_clauses = []
update_values = []
for field, value in fields.items():
if value == "NOW()":
# Special handling for NOW()
insert_placeholders.append("NOW()")
set_clauses.append(f"{field} = NOW()")
else:
insert_placeholders.append("%s")
insert_values.append(value)
set_clauses.append(f"{field} = %s")
update_values.append(value)
# Build the complete query
query = f"""
INSERT INTO {table} ({key_field}, {', '.join(fields.keys())})
VALUES (%s, {', '.join(insert_placeholders)})
ON CONFLICT ({key_field})
DO UPDATE SET {', '.join(set_clauses)}
"""
# Combine values for the query: insert_values + update_values
all_values = tuple(insert_values + update_values)
result = self.execute_query(query, all_values)
if result.success:
logger.info(f"✅ Updated {table} for {key_field}={key_value} with {len(fields)} fields")
logger.debug(f"Updated fields: {fields}")
return True
else:
logger.error(f"❌ Failed to update {table}: {result.error}")
return False
except Exception as e:
logger.error(f"❌ Failed to execute update on {table}: {e}")
return False
def update_car_info(self,
session_id: str,
brand: str,
model: str,
body_type: str) -> bool:
"""
Update car information for a session.
Args:
session_id: Session identifier
brand: Car brand
model: Car model
body_type: Car body type
Returns:
True if successful, False otherwise
"""
schema_name = self.config.schema
query = f"""
INSERT INTO {schema_name}.car_frontal_info (session_id, car_brand, car_model, car_body_type, updated_at)
VALUES (%s, %s, %s, %s, NOW())
ON CONFLICT (session_id)
DO UPDATE SET
car_brand = EXCLUDED.car_brand,
car_model = EXCLUDED.car_model,
car_body_type = EXCLUDED.car_body_type,
updated_at = NOW()
"""
result = self.execute_query(query, (session_id, brand, model, body_type))
if result.success:
logger.info(f"Updated car info for session {session_id}: {brand} {model} ({body_type})")
return result.success
def get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
"""
Get session data by session ID.
Args:
session_id: Session identifier
Returns:
Session data dictionary or None if not found
"""
schema_name = self.config.schema
query = f"SELECT * FROM {schema_name}.car_frontal_info WHERE session_id = %s"
result = self.execute_query(query, (session_id,), fetch_results=True)
if result.success and result.data:
return result.data[0]
return None
def get_connection_stats(self) -> Dict[str, Any]:
"""
Get database connection statistics.
Returns:
Dictionary with connection statistics
"""
stats = {
"connected": self.is_connected(),
"config": self.config.to_dict(),
"connection_closed": self.connection.closed if self.connection else True
}
if self.is_connected():
try:
# Get database version and basic stats
version_result = self.execute_query("SELECT version()", fetch_results=True)
if version_result.success and version_result.data:
stats["database_version"] = version_result.data[0]["version"]
# Get current database name
db_result = self.execute_query("SELECT current_database()", fetch_results=True)
if db_result.success and db_result.data:
stats["current_database"] = db_result.data[0]["current_database"]
except Exception as e:
stats["stats_error"] = str(e)
return stats
# ===== CONVENIENCE FUNCTIONS =====
# These provide compatibility with the original database.py interface
def validate_postgresql_config(config: Dict[str, Any]) -> bool:
"""
Validate PostgreSQL configuration.
Args:
config: Configuration dictionary
Returns:
True if configuration is valid
"""
required_fields = ["host", "port", "database", "password"]
for field in required_fields:
if field not in config:
logger.error(f"Missing required PostgreSQL config field: {field}")
return False
if not config[field]:
logger.error(f"Empty PostgreSQL config field: {field}")
return False
# Check for username (could be 'username' or 'user')
if not config.get("username") and not config.get("user"):
logger.error("Missing PostgreSQL username (provide either 'username' or 'user')")
return False
# Validate port is numeric
try:
port = int(config["port"])
if port <= 0 or port > 65535:
logger.error(f"Invalid PostgreSQL port: {port}")
return False
except (ValueError, TypeError):
logger.error(f"PostgreSQL port must be numeric: {config['port']}")
return False
return True
def create_database_manager(config: Dict[str, Any]) -> Optional[DatabaseManager]:
"""
Create database manager with configuration validation.
Args:
config: Database configuration
Returns:
DatabaseManager instance or None if invalid config
"""
if not validate_postgresql_config(config):
return None
try:
db_manager = DatabaseManager(config)
if db_manager.connect():
logger.info(f"Successfully created database manager for {config['host']}:{config['port']}")
return db_manager
else:
logger.error("Failed to establish initial database connection")
return None
except Exception as e:
logger.error(f"Error creating database manager: {e}")
return None

View file

@ -0,0 +1,733 @@
"""
Redis client management and operations.
This module provides comprehensive Redis functionality including:
- Connection management with retries
- Key-value operations
- Pub/Sub messaging
- Image storage with compression
- Pipeline operations
"""
import redis
import json
import time
import logging
from typing import Dict, Any, Optional, List, Union, Tuple
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from ..core.constants import (
REDIS_CONNECTION_TIMEOUT,
REDIS_SOCKET_TIMEOUT,
REDIS_IMAGE_DEFAULT_QUALITY,
REDIS_IMAGE_DEFAULT_FORMAT
)
from ..core.exceptions import RedisError
logger = logging.getLogger(__name__)
@dataclass
class RedisConfig:
"""Redis connection configuration."""
host: str
port: int
password: Optional[str] = None
db: int = 0
connection_timeout: int = REDIS_CONNECTION_TIMEOUT
socket_timeout: int = REDIS_SOCKET_TIMEOUT
retry_on_timeout: bool = True
health_check_interval: int = 30
max_connections: int = 10
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"host": self.host,
"port": self.port,
"password": self.password,
"db": self.db,
"connection_timeout": self.connection_timeout,
"socket_timeout": self.socket_timeout,
"retry_on_timeout": self.retry_on_timeout,
"health_check_interval": self.health_check_interval,
"max_connections": self.max_connections
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'RedisConfig':
"""Create from dictionary data."""
return cls(
host=data["host"],
port=data["port"],
password=data.get("password"),
db=data.get("db", 0),
connection_timeout=data.get("connection_timeout", REDIS_CONNECTION_TIMEOUT),
socket_timeout=data.get("socket_timeout", REDIS_SOCKET_TIMEOUT),
retry_on_timeout=data.get("retry_on_timeout", True),
health_check_interval=data.get("health_check_interval", 30),
max_connections=data.get("max_connections", 10)
)
@dataclass
class RedisOperationResult:
"""Result from Redis operation."""
success: bool
data: Optional[Any] = None
error: Optional[str] = None
operation: Optional[str] = None
key: Optional[str] = None
execution_time: Optional[float] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
result = {
"success": self.success,
"operation": self.operation
}
if self.data is not None:
result["data"] = self.data
if self.error:
result["error"] = self.error
if self.key:
result["key"] = self.key
if self.execution_time:
result["execution_time"] = self.execution_time
return result
class RedisClientManager:
"""
Comprehensive Redis client manager with connection pooling and retry logic.
This class provides high-level Redis operations with automatic reconnection,
connection pooling, and comprehensive error handling.
"""
def __init__(self, config: Union[Dict[str, Any], RedisConfig]):
"""
Initialize Redis client manager.
Args:
config: Redis configuration dictionary or RedisConfig object
"""
if isinstance(config, dict):
self.config = RedisConfig.from_dict(config)
else:
self.config = config
self.client: Optional[redis.Redis] = None
self.connection_pool: Optional[redis.ConnectionPool] = None
self._lock = None
self._last_health_check = 0.0
def _ensure_thread_safety(self):
"""Initialize thread safety if not already done."""
if self._lock is None:
import threading
self._lock = threading.RLock()
def _create_connection_pool(self) -> redis.ConnectionPool:
"""Create Redis connection pool."""
pool_kwargs = {
"host": self.config.host,
"port": self.config.port,
"db": self.config.db,
"socket_timeout": self.config.socket_timeout,
"socket_connect_timeout": self.config.connection_timeout,
"retry_on_timeout": self.config.retry_on_timeout,
"health_check_interval": self.config.health_check_interval,
"max_connections": self.config.max_connections
}
if self.config.password:
pool_kwargs["password"] = self.config.password
return redis.ConnectionPool(**pool_kwargs)
def connect(self) -> bool:
"""
Establish Redis connection.
Returns:
True if connection successful, False otherwise
"""
self._ensure_thread_safety()
with self._lock:
try:
# Create connection pool
self.connection_pool = self._create_connection_pool()
# Create Redis client
self.client = redis.Redis(connection_pool=self.connection_pool)
# Test connection
self.client.ping()
self._last_health_check = time.time()
logger.info(f"Redis connection established successfully to {self.config.host}:{self.config.port}")
return True
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
self.client = None
self.connection_pool = None
return False
def disconnect(self) -> None:
"""Close Redis connection and cleanup resources."""
self._ensure_thread_safety()
with self._lock:
if self.connection_pool:
try:
self.connection_pool.disconnect()
logger.info("Redis connection pool disconnected")
except Exception as e:
logger.error(f"Error disconnecting Redis pool: {e}")
finally:
self.connection_pool = None
self.client = None
def is_connected(self) -> bool:
"""
Check if Redis connection is active.
Returns:
True if connected, False otherwise
"""
self._ensure_thread_safety()
with self._lock:
try:
if self.client:
# Perform periodic health check
current_time = time.time()
if current_time - self._last_health_check > self.config.health_check_interval:
self.client.ping()
self._last_health_check = current_time
return True
except Exception as e:
logger.debug(f"Redis connection check failed: {e}")
return False
def _ensure_connected(self) -> bool:
"""Ensure Redis connection is active."""
if not self.is_connected():
return self.connect()
return True
def _execute_operation(self,
operation_name: str,
operation_func,
key: Optional[str] = None) -> RedisOperationResult:
"""Execute Redis operation with error handling and timing."""
if not self._ensure_connected():
return RedisOperationResult(
success=False,
error="Failed to establish Redis connection",
operation=operation_name,
key=key
)
start_time = time.time()
try:
result = operation_func()
execution_time = time.time() - start_time
return RedisOperationResult(
success=True,
data=result,
operation=operation_name,
key=key,
execution_time=execution_time
)
except Exception as e:
execution_time = time.time() - start_time
error_msg = f"Redis {operation_name} operation failed: {e}"
logger.error(error_msg)
return RedisOperationResult(
success=False,
error=error_msg,
operation=operation_name,
key=key,
execution_time=execution_time
)
# ===== KEY-VALUE OPERATIONS =====
def set(self, key: str, value: Any, ex: Optional[int] = None) -> RedisOperationResult:
"""
Set key-value pair with optional expiration.
Args:
key: Redis key
value: Value to store
ex: Expiration time in seconds
Returns:
RedisOperationResult with operation status
"""
def operation():
return self.client.set(key, value, ex=ex)
result = self._execute_operation("SET", operation, key)
if result.success:
expire_msg = f" (expires in {ex}s)" if ex else ""
logger.debug(f"Set Redis key '{key}'{expire_msg}")
return result
def setex(self, key: str, time: int, value: Any) -> RedisOperationResult:
"""
Set key-value pair with expiration time.
Args:
key: Redis key
time: Expiration time in seconds
value: Value to store
Returns:
RedisOperationResult with operation status
"""
def operation():
return self.client.setex(key, time, value)
result = self._execute_operation("SETEX", operation, key)
if result.success:
logger.debug(f"Set Redis key '{key}' with {time}s expiration")
return result
def get(self, key: str) -> RedisOperationResult:
"""
Get value by key.
Args:
key: Redis key
Returns:
RedisOperationResult with value data
"""
def operation():
return self.client.get(key)
return self._execute_operation("GET", operation, key)
def delete(self, *keys: str) -> RedisOperationResult:
"""
Delete one or more keys.
Args:
*keys: Redis keys to delete
Returns:
RedisOperationResult with number of deleted keys
"""
def operation():
return self.client.delete(*keys)
result = self._execute_operation("DELETE", operation)
if result.success:
logger.debug(f"Deleted {result.data} Redis key(s): {keys}")
return result
def exists(self, *keys: str) -> RedisOperationResult:
"""
Check if keys exist.
Args:
*keys: Redis keys to check
Returns:
RedisOperationResult with count of existing keys
"""
def operation():
return self.client.exists(*keys)
return self._execute_operation("EXISTS", operation)
def expire(self, key: str, time: int) -> RedisOperationResult:
"""
Set expiration time for a key.
Args:
key: Redis key
time: Expiration time in seconds
Returns:
RedisOperationResult with operation status
"""
def operation():
return self.client.expire(key, time)
return self._execute_operation("EXPIRE", operation, key)
def ttl(self, key: str) -> RedisOperationResult:
"""
Get time to live for a key.
Args:
key: Redis key
Returns:
RedisOperationResult with TTL in seconds
"""
def operation():
return self.client.ttl(key)
return self._execute_operation("TTL", operation, key)
# ===== PUB/SUB OPERATIONS =====
def publish(self, channel: str, message: Union[str, Dict[str, Any]]) -> RedisOperationResult:
"""
Publish message to channel.
Args:
channel: Channel name
message: Message to publish (string or dict)
Returns:
RedisOperationResult with number of subscribers
"""
def operation():
# Convert dict to JSON string
if isinstance(message, dict):
message_str = json.dumps(message)
else:
message_str = str(message)
return self.client.publish(channel, message_str)
result = self._execute_operation("PUBLISH", operation)
if result.success:
logger.debug(f"Published to channel '{channel}', subscribers: {result.data}")
if result.data == 0:
logger.warning(f"No subscribers listening to channel '{channel}'")
return result
def subscribe(self, *channels: str):
"""
Subscribe to channels (returns pubsub object).
Args:
*channels: Channel names to subscribe to
Returns:
PubSub object for listening to messages
"""
if not self._ensure_connected():
raise RedisError("Failed to establish Redis connection for subscription")
pubsub = self.client.pubsub()
pubsub.subscribe(*channels)
logger.info(f"Subscribed to channels: {channels}")
return pubsub
# ===== HASH OPERATIONS =====
def hset(self, key: str, field: str, value: Any) -> RedisOperationResult:
"""
Set hash field.
Args:
key: Redis key
field: Hash field name
value: Field value
Returns:
RedisOperationResult with operation status
"""
def operation():
return self.client.hset(key, field, value)
return self._execute_operation("HSET", operation, key)
def hget(self, key: str, field: str) -> RedisOperationResult:
"""
Get hash field value.
Args:
key: Redis key
field: Hash field name
Returns:
RedisOperationResult with field value
"""
def operation():
return self.client.hget(key, field)
return self._execute_operation("HGET", operation, key)
def hgetall(self, key: str) -> RedisOperationResult:
"""
Get all hash fields and values.
Args:
key: Redis key
Returns:
RedisOperationResult with hash dictionary
"""
def operation():
return self.client.hgetall(key)
return self._execute_operation("HGETALL", operation, key)
# ===== LIST OPERATIONS =====
def lpush(self, key: str, *values: Any) -> RedisOperationResult:
"""
Push values to left of list.
Args:
key: Redis key
*values: Values to push
Returns:
RedisOperationResult with list length
"""
def operation():
return self.client.lpush(key, *values)
return self._execute_operation("LPUSH", operation, key)
def rpush(self, key: str, *values: Any) -> RedisOperationResult:
"""
Push values to right of list.
Args:
key: Redis key
*values: Values to push
Returns:
RedisOperationResult with list length
"""
def operation():
return self.client.rpush(key, *values)
return self._execute_operation("RPUSH", operation, key)
def lrange(self, key: str, start: int, end: int) -> RedisOperationResult:
"""
Get range of list elements.
Args:
key: Redis key
start: Start index
end: End index
Returns:
RedisOperationResult with list elements
"""
def operation():
return self.client.lrange(key, start, end)
return self._execute_operation("LRANGE", operation, key)
# ===== UTILITY OPERATIONS =====
def ping(self) -> RedisOperationResult:
"""
Ping Redis server.
Returns:
RedisOperationResult with ping response
"""
def operation():
return self.client.ping()
return self._execute_operation("PING", operation)
def info(self, section: Optional[str] = None) -> RedisOperationResult:
"""
Get Redis server information.
Args:
section: Optional info section
Returns:
RedisOperationResult with server info
"""
def operation():
return self.client.info(section)
return self._execute_operation("INFO", operation)
def flushdb(self) -> RedisOperationResult:
"""
Flush current database.
Returns:
RedisOperationResult with operation status
"""
def operation():
return self.client.flushdb()
result = self._execute_operation("FLUSHDB", operation)
if result.success:
logger.warning(f"Flushed Redis database {self.config.db}")
return result
def keys(self, pattern: str = "*") -> RedisOperationResult:
"""
Get keys matching pattern.
Args:
pattern: Key pattern (default: all keys)
Returns:
RedisOperationResult with matching keys
"""
def operation():
return self.client.keys(pattern)
return self._execute_operation("KEYS", operation)
# ===== BATCH OPERATIONS =====
def pipeline(self):
"""
Create Redis pipeline for batch operations.
Returns:
Redis pipeline object
"""
if not self._ensure_connected():
raise RedisError("Failed to establish Redis connection for pipeline")
return self.client.pipeline()
def execute_pipeline(self, pipeline) -> RedisOperationResult:
"""
Execute Redis pipeline.
Args:
pipeline: Redis pipeline object
Returns:
RedisOperationResult with pipeline results
"""
def operation():
return pipeline.execute()
return self._execute_operation("PIPELINE", operation)
# ===== CONNECTION MANAGEMENT =====
def get_connection_stats(self) -> Dict[str, Any]:
"""
Get Redis connection statistics.
Returns:
Dictionary with connection statistics
"""
stats = {
"connected": self.is_connected(),
"config": self.config.to_dict(),
"last_health_check": self._last_health_check,
"connection_pool_created": self.connection_pool is not None
}
if self.connection_pool:
stats["connection_pool_stats"] = {
"created_connections": self.connection_pool.created_connections,
"available_connections": len(self.connection_pool._available_connections),
"in_use_connections": len(self.connection_pool._in_use_connections)
}
# Get Redis server info if connected
if self.is_connected():
try:
info_result = self.info()
if info_result.success:
redis_info = info_result.data
stats["server_info"] = {
"redis_version": redis_info.get("redis_version"),
"connected_clients": redis_info.get("connected_clients"),
"used_memory": redis_info.get("used_memory"),
"used_memory_human": redis_info.get("used_memory_human"),
"total_commands_processed": redis_info.get("total_commands_processed"),
"uptime_in_seconds": redis_info.get("uptime_in_seconds")
}
except Exception as e:
stats["server_info_error"] = str(e)
return stats
# ===== CONVENIENCE FUNCTIONS =====
# These provide compatibility and simplified access
def create_redis_client(config: Dict[str, Any]) -> Optional[RedisClientManager]:
"""
Create Redis client with configuration validation.
Args:
config: Redis configuration
Returns:
RedisClientManager instance or None if connection failed
"""
try:
client_manager = RedisClientManager(config)
if client_manager.connect():
logger.info(f"Successfully created Redis client for {config['host']}:{config['port']}")
return client_manager
else:
logger.error("Failed to establish initial Redis connection")
return None
except Exception as e:
logger.error(f"Error creating Redis client: {e}")
return None
def validate_redis_config(config: Dict[str, Any]) -> bool:
"""
Validate Redis configuration.
Args:
config: Configuration dictionary
Returns:
True if configuration is valid
"""
required_fields = ["host", "port"]
for field in required_fields:
if field not in config:
logger.error(f"Missing required Redis config field: {field}")
return False
if not config[field]:
logger.error(f"Empty Redis config field: {field}")
return False
# Validate port is numeric
try:
port = int(config["port"])
if port <= 0 or port > 65535:
logger.error(f"Invalid Redis port: {port}")
return False
except (ValueError, TypeError):
logger.error(f"Redis port must be numeric: {config['port']}")
return False
return True

View file

@ -0,0 +1,688 @@
"""
Session and cache management for detection workflows.
This module provides comprehensive session and cache management including:
- Session state tracking and lifecycle management
- Detection result caching with TTL
- Pipeline mode state management
- Session cleanup and garbage collection
"""
import time
import logging
from typing import Dict, List, Any, Optional, Set, Tuple
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from ..core.constants import (
SESSION_CACHE_TTL_MINUTES,
DETECTION_CACHE_CLEANUP_INTERVAL,
SESSION_TIMEOUT_SECONDS
)
from ..core.exceptions import SessionError
logger = logging.getLogger(__name__)
class PipelineMode(Enum):
"""Pipeline execution modes."""
VALIDATION_DETECTING = "validation_detecting"
SEND_DETECTIONS = "send_detections"
WAITING_FOR_SESSION_ID = "waiting_for_session_id"
FULL_PIPELINE = "full_pipeline"
LIGHTWEIGHT = "lightweight"
CAR_GONE_WAITING = "car_gone_waiting"
@dataclass
class SessionState:
"""Session state information."""
session_id: Optional[str] = None
backend_session_id: Optional[str] = None
mode: PipelineMode = PipelineMode.VALIDATION_DETECTING
session_id_received: bool = False
created_at: float = field(default_factory=time.time)
last_updated: float = field(default_factory=time.time)
last_detection: Optional[float] = None
detection_count: int = 0
# Mode-specific state
validation_frames_processed: int = 0
stable_track_achieved: bool = False
waiting_start_time: Optional[float] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"session_id": self.session_id,
"backend_session_id": self.backend_session_id,
"mode": self.mode.value,
"session_id_received": self.session_id_received,
"created_at": self.created_at,
"last_updated": self.last_updated,
"last_detection": self.last_detection,
"detection_count": self.detection_count,
"validation_frames_processed": self.validation_frames_processed,
"stable_track_achieved": self.stable_track_achieved,
"waiting_start_time": self.waiting_start_time
}
def update_activity(self) -> None:
"""Update last activity timestamp."""
self.last_updated = time.time()
def record_detection(self) -> None:
"""Record a detection occurrence."""
current_time = time.time()
self.last_detection = current_time
self.detection_count += 1
self.update_activity()
def is_expired(self, ttl_seconds: int) -> bool:
"""Check if session has expired based on TTL."""
return time.time() - self.last_updated > ttl_seconds
@dataclass
class CachedDetection:
"""Cached detection result."""
detection_data: Dict[str, Any]
camera_id: str
session_id: Optional[str] = None
created_at: float = field(default_factory=time.time)
access_count: int = 0
last_accessed: float = field(default_factory=time.time)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"detection_data": self.detection_data,
"camera_id": self.camera_id,
"session_id": self.session_id,
"created_at": self.created_at,
"access_count": self.access_count,
"last_accessed": self.last_accessed
}
def access(self) -> None:
"""Record access to this cached detection."""
self.access_count += 1
self.last_accessed = time.time()
def is_expired(self, ttl_seconds: int) -> bool:
"""Check if cached detection has expired."""
return time.time() - self.created_at > ttl_seconds
class SessionManager:
"""
Session lifecycle and state management.
This class provides comprehensive session management including:
- Session creation and cleanup
- Pipeline mode transitions
- Session timeout handling
"""
def __init__(self, session_timeout_seconds: int = SESSION_TIMEOUT_SECONDS):
"""
Initialize session manager.
Args:
session_timeout_seconds: Default session timeout
"""
self.session_timeout_seconds = session_timeout_seconds
self._sessions: Dict[str, SessionState] = {}
self._session_ids: Dict[str, str] = {} # display_id -> session_id mapping
self._lock = None
def _ensure_thread_safety(self):
"""Initialize thread safety if not already done."""
if self._lock is None:
import threading
self._lock = threading.RLock()
def create_session(self, display_id: str, session_id: Optional[str] = None) -> str:
"""
Create a new session or get existing one.
Args:
display_id: Display identifier
session_id: Optional session ID
Returns:
Session ID
"""
self._ensure_thread_safety()
with self._lock:
# Check if session already exists for this display
if display_id in self._session_ids:
existing_session_id = self._session_ids[display_id]
if existing_session_id in self._sessions:
session_state = self._sessions[existing_session_id]
session_state.update_activity()
logger.debug(f"Using existing session for display {display_id}: {existing_session_id}")
return existing_session_id
# Create new session
if not session_id:
import uuid
session_id = str(uuid.uuid4())
session_state = SessionState(session_id=session_id)
self._sessions[session_id] = session_state
self._session_ids[display_id] = session_id
logger.info(f"Created new session for display {display_id}: {session_id}")
return session_id
def get_session(self, session_id: str) -> Optional[SessionState]:
"""
Get session state by session ID.
Args:
session_id: Session identifier
Returns:
SessionState or None if not found
"""
self._ensure_thread_safety()
with self._lock:
session = self._sessions.get(session_id)
if session:
session.update_activity()
return session
def get_session_by_display(self, display_id: str) -> Optional[SessionState]:
"""
Get session state by display ID.
Args:
display_id: Display identifier
Returns:
SessionState or None if not found
"""
self._ensure_thread_safety()
with self._lock:
session_id = self._session_ids.get(display_id)
if session_id:
return self.get_session(session_id)
return None
def update_session_mode(self,
session_id: str,
new_mode: PipelineMode,
backend_session_id: Optional[str] = None) -> bool:
"""
Update session pipeline mode.
Args:
session_id: Session identifier
new_mode: New pipeline mode
backend_session_id: Optional backend session ID
Returns:
True if updated successfully
"""
self._ensure_thread_safety()
with self._lock:
session = self.get_session(session_id)
if not session:
logger.warning(f"Session not found for mode update: {session_id}")
return False
old_mode = session.mode
session.mode = new_mode
if backend_session_id:
session.backend_session_id = backend_session_id
session.session_id_received = True
# Handle mode-specific state changes
if new_mode == PipelineMode.WAITING_FOR_SESSION_ID:
session.waiting_start_time = time.time()
elif old_mode == PipelineMode.WAITING_FOR_SESSION_ID:
session.waiting_start_time = None
session.update_activity()
logger.info(f"Session {session_id}: Mode changed from {old_mode.value} to {new_mode.value}")
return True
def record_detection(self, session_id: str) -> bool:
"""
Record a detection for a session.
Args:
session_id: Session identifier
Returns:
True if recorded successfully
"""
self._ensure_thread_safety()
with self._lock:
session = self.get_session(session_id)
if session:
session.record_detection()
return True
return False
def cleanup_expired_sessions(self, ttl_seconds: Optional[int] = None) -> int:
"""
Clean up expired sessions.
Args:
ttl_seconds: TTL in seconds (uses default if not provided)
Returns:
Number of sessions cleaned up
"""
self._ensure_thread_safety()
if ttl_seconds is None:
ttl_seconds = SESSION_CACHE_TTL_MINUTES * 60
with self._lock:
expired_sessions = []
expired_displays = []
# Find expired sessions
for session_id, session in self._sessions.items():
if session.is_expired(ttl_seconds):
expired_sessions.append(session_id)
# Find displays pointing to expired sessions
for display_id, session_id in self._session_ids.items():
if session_id in expired_sessions:
expired_displays.append(display_id)
# Remove expired sessions and mappings
for session_id in expired_sessions:
del self._sessions[session_id]
for display_id in expired_displays:
del self._session_ids[display_id]
if expired_sessions:
logger.info(f"Cleaned up {len(expired_sessions)} expired sessions")
return len(expired_sessions)
def get_session_stats(self) -> Dict[str, Any]:
"""
Get session management statistics.
Returns:
Dictionary with session statistics
"""
self._ensure_thread_safety()
with self._lock:
current_time = time.time()
mode_counts = {}
total_detections = 0
for session in self._sessions.values():
mode = session.mode.value
mode_counts[mode] = mode_counts.get(mode, 0) + 1
total_detections += session.detection_count
return {
"total_sessions": len(self._sessions),
"total_display_mappings": len(self._session_ids),
"mode_distribution": mode_counts,
"total_detections_processed": total_detections,
"session_timeout_seconds": self.session_timeout_seconds
}
class DetectionCache:
"""
Detection result caching with TTL and access tracking.
This class provides caching for detection results with automatic
expiration and access pattern tracking.
"""
def __init__(self, ttl_minutes: int = SESSION_CACHE_TTL_MINUTES):
"""
Initialize detection cache.
Args:
ttl_minutes: Time to live for cached detections in minutes
"""
self.ttl_seconds = ttl_minutes * 60
self._cache: Dict[str, CachedDetection] = {}
self._camera_index: Dict[str, Set[str]] = {} # camera_id -> set of cache keys
self._session_index: Dict[str, Set[str]] = {} # session_id -> set of cache keys
self._lock = None
self._last_cleanup = time.time()
def _ensure_thread_safety(self):
"""Initialize thread safety if not already done."""
if self._lock is None:
import threading
self._lock = threading.RLock()
def _generate_cache_key(self, camera_id: str, detection_type: str = "default") -> str:
"""Generate cache key for detection."""
return f"{camera_id}:{detection_type}:{time.time()}"
def store_detection(self,
camera_id: str,
detection_data: Dict[str, Any],
session_id: Optional[str] = None,
detection_type: str = "default") -> str:
"""
Store detection in cache.
Args:
camera_id: Camera identifier
detection_data: Detection result data
session_id: Optional session identifier
detection_type: Type of detection for categorization
Returns:
Cache key for the stored detection
"""
self._ensure_thread_safety()
with self._lock:
cache_key = self._generate_cache_key(camera_id, detection_type)
cached_detection = CachedDetection(
detection_data=detection_data,
camera_id=camera_id,
session_id=session_id
)
self._cache[cache_key] = cached_detection
# Update indexes
if camera_id not in self._camera_index:
self._camera_index[camera_id] = set()
self._camera_index[camera_id].add(cache_key)
if session_id:
if session_id not in self._session_index:
self._session_index[session_id] = set()
self._session_index[session_id].add(cache_key)
logger.debug(f"Stored detection in cache: {cache_key}")
return cache_key
def get_detection(self, cache_key: str) -> Optional[Dict[str, Any]]:
"""
Get detection from cache by key.
Args:
cache_key: Cache key
Returns:
Detection data or None if not found/expired
"""
self._ensure_thread_safety()
with self._lock:
cached_detection = self._cache.get(cache_key)
if not cached_detection:
return None
if cached_detection.is_expired(self.ttl_seconds):
self._remove_from_cache(cache_key)
return None
cached_detection.access()
return cached_detection.detection_data
def get_latest_detection(self, camera_id: str) -> Optional[Dict[str, Any]]:
"""
Get latest detection for a camera.
Args:
camera_id: Camera identifier
Returns:
Latest detection data or None if not found
"""
self._ensure_thread_safety()
with self._lock:
camera_keys = self._camera_index.get(camera_id, set())
if not camera_keys:
return None
# Find the most recent non-expired detection
latest_detection = None
latest_time = 0
for cache_key in camera_keys:
cached_detection = self._cache.get(cache_key)
if cached_detection and not cached_detection.is_expired(self.ttl_seconds):
if cached_detection.created_at > latest_time:
latest_time = cached_detection.created_at
latest_detection = cached_detection
if latest_detection:
latest_detection.access()
return latest_detection.detection_data
return None
def get_session_detections(self, session_id: str) -> List[Dict[str, Any]]:
"""
Get all detections for a session.
Args:
session_id: Session identifier
Returns:
List of detection data dictionaries
"""
self._ensure_thread_safety()
with self._lock:
session_keys = self._session_index.get(session_id, set())
detections = []
for cache_key in session_keys:
cached_detection = self._cache.get(cache_key)
if cached_detection and not cached_detection.is_expired(self.ttl_seconds):
cached_detection.access()
detections.append(cached_detection.detection_data)
# Sort by creation time (newest first)
detections.sort(key=lambda x: x.get("timestamp", 0), reverse=True)
return detections
def _remove_from_cache(self, cache_key: str) -> None:
"""Remove detection from cache and indexes."""
cached_detection = self._cache.get(cache_key)
if cached_detection:
# Remove from indexes
camera_id = cached_detection.camera_id
if camera_id in self._camera_index:
self._camera_index[camera_id].discard(cache_key)
if not self._camera_index[camera_id]:
del self._camera_index[camera_id]
session_id = cached_detection.session_id
if session_id and session_id in self._session_index:
self._session_index[session_id].discard(cache_key)
if not self._session_index[session_id]:
del self._session_index[session_id]
# Remove from main cache
if cache_key in self._cache:
del self._cache[cache_key]
def cleanup_expired_detections(self) -> int:
"""
Clean up expired cached detections.
Returns:
Number of detections cleaned up
"""
self._ensure_thread_safety()
with self._lock:
expired_keys = []
# Find expired detections
for cache_key, cached_detection in self._cache.items():
if cached_detection.is_expired(self.ttl_seconds):
expired_keys.append(cache_key)
# Remove expired detections
for cache_key in expired_keys:
self._remove_from_cache(cache_key)
if expired_keys:
logger.info(f"Cleaned up {len(expired_keys)} expired cached detections")
self._last_cleanup = time.time()
return len(expired_keys)
def get_cache_stats(self) -> Dict[str, Any]:
"""
Get cache statistics.
Returns:
Dictionary with cache statistics
"""
self._ensure_thread_safety()
with self._lock:
total_access_count = sum(detection.access_count for detection in self._cache.values())
return {
"total_cached_detections": len(self._cache),
"cameras_with_cache": len(self._camera_index),
"sessions_with_cache": len(self._session_index),
"total_access_count": total_access_count,
"ttl_seconds": self.ttl_seconds,
"last_cleanup": self._last_cleanup
}
class SessionCacheManager:
"""
Combined session and cache management.
This class provides unified management of sessions and detection caching
with automatic cleanup and comprehensive statistics.
"""
def __init__(self,
session_timeout_seconds: int = SESSION_TIMEOUT_SECONDS,
cache_ttl_minutes: int = SESSION_CACHE_TTL_MINUTES):
"""
Initialize session cache manager.
Args:
session_timeout_seconds: Session timeout in seconds
cache_ttl_minutes: Cache TTL in minutes
"""
self.session_manager = SessionManager(session_timeout_seconds)
self.detection_cache = DetectionCache(cache_ttl_minutes)
self._last_cleanup = time.time()
def cleanup_expired_data(self, force: bool = False) -> Dict[str, int]:
"""
Clean up expired sessions and cached detections.
Args:
force: Force cleanup regardless of interval
Returns:
Dictionary with cleanup statistics
"""
current_time = time.time()
# Check if cleanup is needed
if not force and current_time - self._last_cleanup < DETECTION_CACHE_CLEANUP_INTERVAL:
return {"sessions_cleaned": 0, "detections_cleaned": 0, "cleanup_skipped": True}
sessions_cleaned = self.session_manager.cleanup_expired_sessions()
detections_cleaned = self.detection_cache.cleanup_expired_detections()
self._last_cleanup = current_time
return {
"sessions_cleaned": sessions_cleaned,
"detections_cleaned": detections_cleaned,
"cleanup_skipped": False
}
def get_comprehensive_stats(self) -> Dict[str, Any]:
"""
Get comprehensive statistics for sessions and cache.
Returns:
Dictionary with all statistics
"""
return {
"session_stats": self.session_manager.get_session_stats(),
"cache_stats": self.detection_cache.get_cache_stats(),
"last_cleanup": self._last_cleanup,
"cleanup_interval_seconds": DETECTION_CACHE_CLEANUP_INTERVAL
}
# Global session cache manager instance
session_cache_manager = SessionCacheManager()
# ===== CONVENIENCE FUNCTIONS =====
# These provide simplified access to session and cache functionality
def create_session(display_id: str, session_id: Optional[str] = None) -> str:
"""Create a new session using global manager."""
return session_cache_manager.session_manager.create_session(display_id, session_id)
def get_session_state(session_id: str) -> Optional[Dict[str, Any]]:
"""Get session state by session ID."""
session = session_cache_manager.session_manager.get_session(session_id)
return session.to_dict() if session else None
def update_pipeline_mode(session_id: str,
new_mode: str,
backend_session_id: Optional[str] = None) -> bool:
"""Update session pipeline mode."""
try:
mode = PipelineMode(new_mode)
return session_cache_manager.session_manager.update_session_mode(session_id, mode, backend_session_id)
except ValueError:
logger.error(f"Invalid pipeline mode: {new_mode}")
return False
def cache_detection(camera_id: str,
detection_data: Dict[str, Any],
session_id: Optional[str] = None) -> str:
"""Cache detection data using global manager."""
return session_cache_manager.detection_cache.store_detection(camera_id, detection_data, session_id)
def get_cached_detection(camera_id: str) -> Optional[Dict[str, Any]]:
"""Get latest cached detection for a camera."""
return session_cache_manager.detection_cache.get_latest_detection(camera_id)
def cleanup_expired_sessions_and_cache() -> Dict[str, int]:
"""Clean up expired sessions and cached data."""
return session_cache_manager.cleanup_expired_data()
def get_session_and_cache_stats() -> Dict[str, Any]:
"""Get comprehensive session and cache statistics."""
return session_cache_manager.get_comprehensive_stats()