Refactor: PHASE 3: Action & Storage Extraction
commit cdeaaf4a4f
parent 4e9ae6bcc4

5 changed files with 3048 additions and 0 deletions
617	detector_worker/storage/database_manager.py	Normal file

@@ -0,0 +1,617 @@
"""
|
||||
Database management and operations.
|
||||
|
||||
This module provides comprehensive database functionality including:
|
||||
- PostgreSQL connection management
|
||||
- Table schema management
|
||||
- Dynamic query execution
|
||||
- Transaction handling
|
||||
"""
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from ..core.constants import (
|
||||
DB_CONNECTION_TIMEOUT,
|
||||
DB_OPERATION_TIMEOUT,
|
||||
DB_RETRY_ATTEMPTS
|
||||
)
|
||||
from ..core.exceptions import DatabaseError, create_detection_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
"""Database connection configuration."""
|
||||
host: str
|
||||
port: int
|
||||
database: str
|
||||
username: str
|
||||
password: str
|
||||
schema: str = "gas_station_1"
|
||||
connection_timeout: int = DB_CONNECTION_TIMEOUT
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"database": self.database,
|
||||
"username": self.username,
|
||||
"password": self.password,
|
||||
"schema": self.schema,
|
||||
"connection_timeout": self.connection_timeout
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'DatabaseConfig':
|
||||
"""Create from dictionary data."""
|
||||
return cls(
|
||||
host=data["host"],
|
||||
port=data["port"],
|
||||
database=data["database"],
|
||||
username=data.get("username", data.get("user", "")),
|
||||
password=data["password"],
|
||||
schema=data.get("schema", "gas_station_1"),
|
||||
connection_timeout=data.get("connection_timeout", DB_CONNECTION_TIMEOUT)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryResult:
|
||||
"""Result from database query execution."""
|
||||
success: bool
|
||||
affected_rows: int = 0
|
||||
data: Optional[List[Dict[str, Any]]] = None
|
||||
error: Optional[str] = None
|
||||
query: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
result = {
|
||||
"success": self.success,
|
||||
"affected_rows": self.affected_rows
|
||||
}
|
||||
if self.data:
|
||||
result["data"] = self.data
|
||||
if self.error:
|
||||
result["error"] = self.error
|
||||
if self.query:
|
||||
result["query"] = self.query
|
||||
return result
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""
|
||||
Comprehensive database manager for PostgreSQL operations.
|
||||
|
||||
This class provides connection management, schema operations,
|
||||
and dynamic query execution with transaction support.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""
|
||||
Initialize database manager.
|
||||
|
||||
Args:
|
||||
config: Database configuration dictionary
|
||||
"""
|
||||
if isinstance(config, dict):
|
||||
self.config = DatabaseConfig.from_dict(config)
|
||||
else:
|
||||
self.config = config
|
||||
|
||||
self.connection: Optional[psycopg2.extensions.connection] = None
|
||||
self._lock = None
|
||||
|
||||
def _ensure_thread_safety(self):
|
||||
"""Initialize thread safety if not already done."""
|
||||
if self._lock is None:
|
||||
import threading
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""
|
||||
Establish database connection.
|
||||
|
||||
Returns:
|
||||
True if connection successful, False otherwise
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
try:
|
||||
if self.connection and not self.connection.closed:
|
||||
# Connection already exists and is open
|
||||
return True
|
||||
|
||||
self.connection = psycopg2.connect(
|
||||
host=self.config.host,
|
||||
port=self.config.port,
|
||||
database=self.config.database,
|
||||
user=self.config.username,
|
||||
password=self.config.password,
|
||||
connect_timeout=self.config.connection_timeout
|
||||
)
|
||||
|
||||
# Set connection properties
|
||||
self.connection.set_client_encoding('UTF8')
|
||||
|
||||
logger.info(f"PostgreSQL connection established successfully to {self.config.host}:{self.config.port}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to PostgreSQL: {e}")
|
||||
self.connection = None
|
||||
return False
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Close database connection."""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
if self.connection:
|
||||
try:
|
||||
self.connection.close()
|
||||
logger.info("PostgreSQL connection closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing PostgreSQL connection: {e}")
|
||||
finally:
|
||||
self.connection = None
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""
|
||||
Check if database connection is active.
|
||||
|
||||
Returns:
|
||||
True if connected, False otherwise
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
try:
|
||||
if self.connection and not self.connection.closed:
|
||||
cur = self.connection.cursor()
|
||||
cur.execute("SELECT 1")
|
||||
cur.fetchone()
|
||||
cur.close()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"Connection check failed: {e}")
|
||||
|
||||
return False
|
||||
|
||||
def _ensure_connected(self) -> bool:
|
||||
"""Ensure database connection is active."""
|
||||
if not self.is_connected():
|
||||
return self.connect()
|
||||
return True
|
||||
|
||||
def execute_query(self,
|
||||
query: str,
|
||||
params: Optional[Tuple] = None,
|
||||
fetch_results: bool = False) -> QueryResult:
|
||||
"""
|
||||
Execute a database query.
|
||||
|
||||
Args:
|
||||
query: SQL query string
|
||||
params: Query parameters
|
||||
fetch_results: Whether to fetch and return results
|
||||
|
||||
Returns:
|
||||
QueryResult with execution details
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
if not self._ensure_connected():
|
||||
return QueryResult(
|
||||
success=False,
|
||||
error="Failed to establish database connection",
|
||||
query=query
|
||||
)
|
||||
|
||||
cursor = None
|
||||
try:
|
||||
cursor = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
|
||||
logger.debug(f"Executing query: {query}")
|
||||
if params:
|
||||
logger.debug(f"Query parameters: {params}")
|
||||
|
||||
cursor.execute(query, params)
|
||||
affected_rows = cursor.rowcount
|
||||
|
||||
data = None
|
||||
if fetch_results:
|
||||
data = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
self.connection.commit()
|
||||
|
||||
logger.debug(f"Query executed successfully, affected rows: {affected_rows}")
|
||||
|
||||
return QueryResult(
|
||||
success=True,
|
||||
affected_rows=affected_rows,
|
||||
data=data,
|
||||
query=query
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Database query execution failed: {e}"
|
||||
logger.error(error_msg)
|
||||
logger.debug(f"Failed query: {query}")
|
||||
if params:
|
||||
logger.debug(f"Query parameters: {params}")
|
||||
|
||||
if self.connection:
|
||||
try:
|
||||
self.connection.rollback()
|
||||
except Exception as rollback_error:
|
||||
logger.error(f"Rollback failed: {rollback_error}")
|
||||
|
||||
return QueryResult(
|
||||
success=False,
|
||||
error=error_msg,
|
||||
query=query
|
||||
)
|
||||
finally:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
|
||||
def create_schema_if_not_exists(self, schema_name: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Create database schema if it doesn't exist.
|
||||
|
||||
Args:
|
||||
schema_name: Schema name (uses config schema if not provided)
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
schema_name = schema_name or self.config.schema
|
||||
query = f"CREATE SCHEMA IF NOT EXISTS {schema_name}"
|
||||
result = self.execute_query(query)
|
||||
|
||||
if result.success:
|
||||
logger.info(f"Schema '{schema_name}' created or verified successfully")
|
||||
|
||||
return result.success
|
||||
|
||||
def create_car_frontal_info_table(self) -> bool:
|
||||
"""
|
||||
Create the car_frontal_info table in the configured schema if it doesn't exist.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
schema_name = self.config.schema
|
||||
|
||||
# Ensure schema exists
|
||||
if not self.create_schema_if_not_exists(schema_name):
|
||||
return False
|
||||
|
||||
try:
|
||||
# Create table if it doesn't exist
|
||||
create_table_query = f"""
|
||||
CREATE TABLE IF NOT EXISTS {schema_name}.car_frontal_info (
|
||||
display_id VARCHAR(255),
|
||||
captured_timestamp VARCHAR(255),
|
||||
session_id VARCHAR(255) PRIMARY KEY,
|
||||
license_character VARCHAR(255) DEFAULT NULL,
|
||||
license_type VARCHAR(255) DEFAULT 'No model available',
|
||||
car_brand VARCHAR(255) DEFAULT NULL,
|
||||
car_model VARCHAR(255) DEFAULT NULL,
|
||||
car_body_type VARCHAR(255) DEFAULT NULL,
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
updated_at TIMESTAMP DEFAULT NOW()
|
||||
)
|
||||
"""
|
||||
|
||||
result = self.execute_query(create_table_query)
|
||||
if not result.success:
|
||||
return False
|
||||
|
||||
# Add columns if they don't exist (for existing tables)
|
||||
alter_queries = [
|
||||
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_brand VARCHAR(255) DEFAULT NULL",
|
||||
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_model VARCHAR(255) DEFAULT NULL",
|
||||
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_body_type VARCHAR(255) DEFAULT NULL",
|
||||
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS created_at TIMESTAMP DEFAULT NOW()",
|
||||
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT NOW()"
|
||||
]
|
||||
|
||||
for alter_query in alter_queries:
|
||||
try:
|
||||
alter_result = self.execute_query(alter_query)
|
||||
if alter_result.success:
|
||||
logger.debug(f"Executed: {alter_query}")
|
||||
else:
|
||||
# Check if it's just a "column already exists" error
|
||||
if "already exists" not in (alter_result.error or "").lower():
|
||||
logger.warning(f"ALTER TABLE failed: {alter_result.error}")
|
||||
except Exception as e:
|
||||
# Ignore errors if column already exists (for older PostgreSQL versions)
|
||||
if "already exists" in str(e).lower():
|
||||
logger.debug(f"Column already exists, skipping: {alter_query}")
|
||||
else:
|
||||
logger.warning(f"Error in ALTER TABLE: {e}")
|
||||
|
||||
logger.info("Successfully created/verified car_frontal_info table with all required columns")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create car_frontal_info table: {e}")
|
||||
return False
|
||||
|
||||
def insert_initial_detection(self,
|
||||
display_id: str,
|
||||
captured_timestamp: str,
|
||||
session_id: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Insert initial detection record and return the session_id.
|
||||
|
||||
Args:
|
||||
display_id: Display identifier
|
||||
captured_timestamp: Timestamp of capture
|
||||
session_id: Optional session ID (generated if not provided)
|
||||
|
||||
Returns:
|
||||
Session ID if successful, None otherwise
|
||||
"""
|
||||
# Generate session_id if not provided
|
||||
if not session_id:
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
# Ensure table exists
|
||||
if not self.create_car_frontal_info_table():
|
||||
logger.error("Failed to create/verify table before insertion")
|
||||
return None
|
||||
|
||||
schema_name = self.config.schema
|
||||
insert_query = f"""
|
||||
INSERT INTO {schema_name}.car_frontal_info
|
||||
(display_id, captured_timestamp, session_id, license_character, license_type, car_brand, car_model, car_body_type, created_at)
|
||||
VALUES (%s, %s, %s, NULL, 'No model available', NULL, NULL, NULL, NOW())
|
||||
ON CONFLICT (session_id) DO NOTHING
|
||||
"""
|
||||
|
||||
result = self.execute_query(insert_query, (display_id, captured_timestamp, session_id))
|
||||
|
||||
if result.success:
|
||||
logger.info(f"Inserted initial detection record with session_id: {session_id}")
|
||||
return session_id
|
||||
else:
|
||||
logger.error(f"Failed to insert initial detection record: {result.error}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to insert initial detection record: {e}")
|
||||
return None
|
||||
|
||||
def execute_update(self,
|
||||
table: str,
|
||||
key_field: str,
|
||||
key_value: str,
|
||||
fields: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Execute dynamic update/insert operation.
|
||||
|
||||
Args:
|
||||
table: Table name
|
||||
key_field: Primary key field name
|
||||
key_value: Primary key value
|
||||
fields: Dictionary of fields to update
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Add schema prefix if table doesn't already have it
|
||||
if '.' not in table:
|
||||
table = f"{self.config.schema}.{table}"
|
||||
|
||||
# Build the INSERT and UPDATE query dynamically
|
||||
insert_placeholders = []
|
||||
insert_values = [key_value] # Start with key_value
|
||||
|
||||
set_clauses = []
|
||||
update_values = []
|
||||
|
||||
for field, value in fields.items():
|
||||
if value == "NOW()":
|
||||
# Special handling for NOW()
|
||||
insert_placeholders.append("NOW()")
|
||||
set_clauses.append(f"{field} = NOW()")
|
||||
else:
|
||||
insert_placeholders.append("%s")
|
||||
insert_values.append(value)
|
||||
set_clauses.append(f"{field} = %s")
|
||||
update_values.append(value)
|
||||
|
||||
# Build the complete query
|
||||
query = f"""
|
||||
INSERT INTO {table} ({key_field}, {', '.join(fields.keys())})
|
||||
VALUES (%s, {', '.join(insert_placeholders)})
|
||||
ON CONFLICT ({key_field})
|
||||
DO UPDATE SET {', '.join(set_clauses)}
|
||||
"""
|
||||
|
||||
# Combine values for the query: insert_values + update_values
|
||||
all_values = tuple(insert_values + update_values)
|
||||
|
||||
result = self.execute_query(query, all_values)
|
||||
|
||||
if result.success:
|
||||
logger.info(f"✅ Updated {table} for {key_field}={key_value} with {len(fields)} fields")
|
||||
logger.debug(f"Updated fields: {fields}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"❌ Failed to update {table}: {result.error}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to execute update on {table}: {e}")
|
||||
return False
|
||||
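    # Illustrative trace (not part of the module): for a hypothetical call
    #   manager.execute_update("car_frontal_info", "session_id", "abc-123",
    #                          {"car_brand": "Toyota", "updated_at": "NOW()"})
    # the builder above produces an upsert of the form
    #   INSERT INTO gas_station_1.car_frontal_info (session_id, car_brand, updated_at)
    #   VALUES (%s, %s, NOW())
    #   ON CONFLICT (session_id)
    #   DO UPDATE SET car_brand = %s, updated_at = NOW()
    # executed with parameters ("abc-123", "Toyota", "Toyota"): the key value
    # first, then the insert values, then the update values repeated.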

    def update_car_info(self,
                        session_id: str,
                        brand: str,
                        model: str,
                        body_type: str) -> bool:
        """
        Update car information for a session.

        Args:
            session_id: Session identifier
            brand: Car brand
            model: Car model
            body_type: Car body type

        Returns:
            True if successful, False otherwise
        """
        schema_name = self.config.schema
        query = f"""
            INSERT INTO {schema_name}.car_frontal_info (session_id, car_brand, car_model, car_body_type, updated_at)
            VALUES (%s, %s, %s, %s, NOW())
            ON CONFLICT (session_id)
            DO UPDATE SET
                car_brand = EXCLUDED.car_brand,
                car_model = EXCLUDED.car_model,
                car_body_type = EXCLUDED.car_body_type,
                updated_at = NOW()
        """

        result = self.execute_query(query, (session_id, brand, model, body_type))

        if result.success:
            logger.info(f"Updated car info for session {session_id}: {brand} {model} ({body_type})")

        return result.success

    def get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Get session data by session ID.

        Args:
            session_id: Session identifier

        Returns:
            Session data dictionary or None if not found
        """
        schema_name = self.config.schema
        query = f"SELECT * FROM {schema_name}.car_frontal_info WHERE session_id = %s"

        result = self.execute_query(query, (session_id,), fetch_results=True)

        if result.success and result.data:
            return result.data[0]

        return None

    def get_connection_stats(self) -> Dict[str, Any]:
        """
        Get database connection statistics.

        Returns:
            Dictionary with connection statistics
        """
        stats = {
            "connected": self.is_connected(),
            "config": self.config.to_dict(),
            "connection_closed": self.connection.closed if self.connection else True
        }

        if self.is_connected():
            try:
                # Get database version and basic stats
                version_result = self.execute_query("SELECT version()", fetch_results=True)
                if version_result.success and version_result.data:
                    stats["database_version"] = version_result.data[0]["version"]

                # Get current database name
                db_result = self.execute_query("SELECT current_database()", fetch_results=True)
                if db_result.success and db_result.data:
                    stats["current_database"] = db_result.data[0]["current_database"]

            except Exception as e:
                stats["stats_error"] = str(e)

        return stats


# ===== CONVENIENCE FUNCTIONS =====
# These provide compatibility with the original database.py interface

def validate_postgresql_config(config: Dict[str, Any]) -> bool:
    """
    Validate PostgreSQL configuration.

    Args:
        config: Configuration dictionary

    Returns:
        True if configuration is valid
    """
    required_fields = ["host", "port", "database", "password"]

    for field_name in required_fields:
        if field_name not in config:
            logger.error(f"Missing required PostgreSQL config field: {field_name}")
            return False

        if not config[field_name]:
            logger.error(f"Empty PostgreSQL config field: {field_name}")
            return False

    # Check for username (could be 'username' or 'user')
    if not config.get("username") and not config.get("user"):
        logger.error("Missing PostgreSQL username (provide either 'username' or 'user')")
        return False

    # Validate port is numeric
    try:
        port = int(config["port"])
        if port <= 0 or port > 65535:
            logger.error(f"Invalid PostgreSQL port: {port}")
            return False
    except (ValueError, TypeError):
        logger.error(f"PostgreSQL port must be numeric: {config['port']}")
        return False

    return True


def create_database_manager(config: Dict[str, Any]) -> Optional[DatabaseManager]:
    """
    Create database manager with configuration validation.

    Args:
        config: Database configuration

    Returns:
        DatabaseManager instance or None if invalid config
    """
    if not validate_postgresql_config(config):
        return None

    try:
        db_manager = DatabaseManager(config)
        if db_manager.connect():
            logger.info(f"Successfully created database manager for {config['host']}:{config['port']}")
            return db_manager
        else:
            logger.error("Failed to establish initial database connection")
            return None
    except Exception as e:
        logger.error(f"Error creating database manager: {e}")
        return None
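A minimal usage sketch for the manager above (illustrative only; the connection values are placeholders, not part of the commit):

    from detector_worker.storage.database_manager import create_database_manager

    config = {
        "host": "localhost",      # placeholder
        "port": 5432,
        "database": "detector",   # placeholder
        "user": "worker",         # placeholder
        "password": "secret",     # placeholder
        "schema": "gas_station_1",
    }

    db = create_database_manager(config)  # returns None on invalid config or failed connect
    if db:
        session_id = db.insert_initial_detection("display-1", "2024-01-01T00:00:00Z")
        if session_id:
            db.update_car_info(session_id, "Toyota", "Camry", "Sedan")
            print(db.get_session_data(session_id))
        db.disconnect()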
733	detector_worker/storage/redis_client.py	Normal file

@@ -0,0 +1,733 @@
"""
|
||||
Redis client management and operations.
|
||||
|
||||
This module provides comprehensive Redis functionality including:
|
||||
- Connection management with retries
|
||||
- Key-value operations
|
||||
- Pub/Sub messaging
|
||||
- Image storage with compression
|
||||
- Pipeline operations
|
||||
"""
|
||||
|
||||
import redis
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Union, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from ..core.constants import (
|
||||
REDIS_CONNECTION_TIMEOUT,
|
||||
REDIS_SOCKET_TIMEOUT,
|
||||
REDIS_IMAGE_DEFAULT_QUALITY,
|
||||
REDIS_IMAGE_DEFAULT_FORMAT
|
||||
)
|
||||
from ..core.exceptions import RedisError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RedisConfig:
|
||||
"""Redis connection configuration."""
|
||||
host: str
|
||||
port: int
|
||||
password: Optional[str] = None
|
||||
db: int = 0
|
||||
connection_timeout: int = REDIS_CONNECTION_TIMEOUT
|
||||
socket_timeout: int = REDIS_SOCKET_TIMEOUT
|
||||
retry_on_timeout: bool = True
|
||||
health_check_interval: int = 30
|
||||
max_connections: int = 10
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"password": self.password,
|
||||
"db": self.db,
|
||||
"connection_timeout": self.connection_timeout,
|
||||
"socket_timeout": self.socket_timeout,
|
||||
"retry_on_timeout": self.retry_on_timeout,
|
||||
"health_check_interval": self.health_check_interval,
|
||||
"max_connections": self.max_connections
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'RedisConfig':
|
||||
"""Create from dictionary data."""
|
||||
return cls(
|
||||
host=data["host"],
|
||||
port=data["port"],
|
||||
password=data.get("password"),
|
||||
db=data.get("db", 0),
|
||||
connection_timeout=data.get("connection_timeout", REDIS_CONNECTION_TIMEOUT),
|
||||
socket_timeout=data.get("socket_timeout", REDIS_SOCKET_TIMEOUT),
|
||||
retry_on_timeout=data.get("retry_on_timeout", True),
|
||||
health_check_interval=data.get("health_check_interval", 30),
|
||||
max_connections=data.get("max_connections", 10)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RedisOperationResult:
|
||||
"""Result from Redis operation."""
|
||||
success: bool
|
||||
data: Optional[Any] = None
|
||||
error: Optional[str] = None
|
||||
operation: Optional[str] = None
|
||||
key: Optional[str] = None
|
||||
execution_time: Optional[float] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
result = {
|
||||
"success": self.success,
|
||||
"operation": self.operation
|
||||
}
|
||||
if self.data is not None:
|
||||
result["data"] = self.data
|
||||
if self.error:
|
||||
result["error"] = self.error
|
||||
if self.key:
|
||||
result["key"] = self.key
|
||||
if self.execution_time:
|
||||
result["execution_time"] = self.execution_time
|
||||
return result
|
||||
|
||||
|
||||
class RedisClientManager:
|
||||
"""
|
||||
Comprehensive Redis client manager with connection pooling and retry logic.
|
||||
|
||||
This class provides high-level Redis operations with automatic reconnection,
|
||||
connection pooling, and comprehensive error handling.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Union[Dict[str, Any], RedisConfig]):
|
||||
"""
|
||||
Initialize Redis client manager.
|
||||
|
||||
Args:
|
||||
config: Redis configuration dictionary or RedisConfig object
|
||||
"""
|
||||
if isinstance(config, dict):
|
||||
self.config = RedisConfig.from_dict(config)
|
||||
else:
|
||||
self.config = config
|
||||
|
||||
self.client: Optional[redis.Redis] = None
|
||||
self.connection_pool: Optional[redis.ConnectionPool] = None
|
||||
self._lock = None
|
||||
self._last_health_check = 0.0
|
||||
|
||||
def _ensure_thread_safety(self):
|
||||
"""Initialize thread safety if not already done."""
|
||||
if self._lock is None:
|
||||
import threading
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _create_connection_pool(self) -> redis.ConnectionPool:
|
||||
"""Create Redis connection pool."""
|
||||
pool_kwargs = {
|
||||
"host": self.config.host,
|
||||
"port": self.config.port,
|
||||
"db": self.config.db,
|
||||
"socket_timeout": self.config.socket_timeout,
|
||||
"socket_connect_timeout": self.config.connection_timeout,
|
||||
"retry_on_timeout": self.config.retry_on_timeout,
|
||||
"health_check_interval": self.config.health_check_interval,
|
||||
"max_connections": self.config.max_connections
|
||||
}
|
||||
|
||||
if self.config.password:
|
||||
pool_kwargs["password"] = self.config.password
|
||||
|
||||
return redis.ConnectionPool(**pool_kwargs)
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""
|
||||
Establish Redis connection.
|
||||
|
||||
Returns:
|
||||
True if connection successful, False otherwise
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
try:
|
||||
# Create connection pool
|
||||
self.connection_pool = self._create_connection_pool()
|
||||
|
||||
# Create Redis client
|
||||
self.client = redis.Redis(connection_pool=self.connection_pool)
|
||||
|
||||
# Test connection
|
||||
self.client.ping()
|
||||
self._last_health_check = time.time()
|
||||
|
||||
logger.info(f"Redis connection established successfully to {self.config.host}:{self.config.port}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Redis: {e}")
|
||||
self.client = None
|
||||
self.connection_pool = None
|
||||
return False
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Close Redis connection and cleanup resources."""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
if self.connection_pool:
|
||||
try:
|
||||
self.connection_pool.disconnect()
|
||||
logger.info("Redis connection pool disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"Error disconnecting Redis pool: {e}")
|
||||
finally:
|
||||
self.connection_pool = None
|
||||
|
||||
self.client = None
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""
|
||||
Check if Redis connection is active.
|
||||
|
||||
Returns:
|
||||
True if connected, False otherwise
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
try:
|
||||
if self.client:
|
||||
# Perform periodic health check
|
||||
current_time = time.time()
|
||||
if current_time - self._last_health_check > self.config.health_check_interval:
|
||||
self.client.ping()
|
||||
self._last_health_check = current_time
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"Redis connection check failed: {e}")
|
||||
|
||||
return False
|
||||
|
||||
def _ensure_connected(self) -> bool:
|
||||
"""Ensure Redis connection is active."""
|
||||
if not self.is_connected():
|
||||
return self.connect()
|
||||
return True
|
||||
|
||||
def _execute_operation(self,
|
||||
operation_name: str,
|
||||
operation_func,
|
||||
key: Optional[str] = None) -> RedisOperationResult:
|
||||
"""Execute Redis operation with error handling and timing."""
|
||||
if not self._ensure_connected():
|
||||
return RedisOperationResult(
|
||||
success=False,
|
||||
error="Failed to establish Redis connection",
|
||||
operation=operation_name,
|
||||
key=key
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = operation_func()
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return RedisOperationResult(
|
||||
success=True,
|
||||
data=result,
|
||||
operation=operation_name,
|
||||
key=key,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
error_msg = f"Redis {operation_name} operation failed: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
return RedisOperationResult(
|
||||
success=False,
|
||||
error=error_msg,
|
||||
operation=operation_name,
|
||||
key=key,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
# ===== KEY-VALUE OPERATIONS =====
|
||||
|
||||
def set(self, key: str, value: Any, ex: Optional[int] = None) -> RedisOperationResult:
|
||||
"""
|
||||
Set key-value pair with optional expiration.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
value: Value to store
|
||||
ex: Expiration time in seconds
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with operation status
|
||||
"""
|
||||
def operation():
|
||||
return self.client.set(key, value, ex=ex)
|
||||
|
||||
result = self._execute_operation("SET", operation, key)
|
||||
|
||||
if result.success:
|
||||
expire_msg = f" (expires in {ex}s)" if ex else ""
|
||||
logger.debug(f"Set Redis key '{key}'{expire_msg}")
|
||||
|
||||
return result
|
||||
|
||||
def setex(self, key: str, time: int, value: Any) -> RedisOperationResult:
|
||||
"""
|
||||
Set key-value pair with expiration time.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
time: Expiration time in seconds
|
||||
value: Value to store
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with operation status
|
||||
"""
|
||||
def operation():
|
||||
return self.client.setex(key, time, value)
|
||||
|
||||
result = self._execute_operation("SETEX", operation, key)
|
||||
|
||||
if result.success:
|
||||
logger.debug(f"Set Redis key '{key}' with {time}s expiration")
|
||||
|
||||
return result
|
||||
|
||||
def get(self, key: str) -> RedisOperationResult:
|
||||
"""
|
||||
Get value by key.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with value data
|
||||
"""
|
||||
def operation():
|
||||
return self.client.get(key)
|
||||
|
||||
return self._execute_operation("GET", operation, key)
|
||||
|
||||
def delete(self, *keys: str) -> RedisOperationResult:
|
||||
"""
|
||||
Delete one or more keys.
|
||||
|
||||
Args:
|
||||
*keys: Redis keys to delete
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with number of deleted keys
|
||||
"""
|
||||
def operation():
|
||||
return self.client.delete(*keys)
|
||||
|
||||
result = self._execute_operation("DELETE", operation)
|
||||
|
||||
if result.success:
|
||||
logger.debug(f"Deleted {result.data} Redis key(s): {keys}")
|
||||
|
||||
return result
|
||||
|
||||
def exists(self, *keys: str) -> RedisOperationResult:
|
||||
"""
|
||||
Check if keys exist.
|
||||
|
||||
Args:
|
||||
*keys: Redis keys to check
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with count of existing keys
|
||||
"""
|
||||
def operation():
|
||||
return self.client.exists(*keys)
|
||||
|
||||
return self._execute_operation("EXISTS", operation)
|
||||
|
||||
def expire(self, key: str, time: int) -> RedisOperationResult:
|
||||
"""
|
||||
Set expiration time for a key.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
time: Expiration time in seconds
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with operation status
|
||||
"""
|
||||
def operation():
|
||||
return self.client.expire(key, time)
|
||||
|
||||
return self._execute_operation("EXPIRE", operation, key)
|
||||
|
||||
def ttl(self, key: str) -> RedisOperationResult:
|
||||
"""
|
||||
Get time to live for a key.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with TTL in seconds
|
||||
"""
|
||||
def operation():
|
||||
return self.client.ttl(key)
|
||||
|
||||
return self._execute_operation("TTL", operation, key)
|
||||
|
||||
# ===== PUB/SUB OPERATIONS =====
|
||||
|
||||
def publish(self, channel: str, message: Union[str, Dict[str, Any]]) -> RedisOperationResult:
|
||||
"""
|
||||
Publish message to channel.
|
||||
|
||||
Args:
|
||||
channel: Channel name
|
||||
message: Message to publish (string or dict)
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with number of subscribers
|
||||
"""
|
||||
def operation():
|
||||
# Convert dict to JSON string
|
||||
if isinstance(message, dict):
|
||||
message_str = json.dumps(message)
|
||||
else:
|
||||
message_str = str(message)
|
||||
|
||||
return self.client.publish(channel, message_str)
|
||||
|
||||
result = self._execute_operation("PUBLISH", operation)
|
||||
|
||||
if result.success:
|
||||
logger.debug(f"Published to channel '{channel}', subscribers: {result.data}")
|
||||
if result.data == 0:
|
||||
logger.warning(f"No subscribers listening to channel '{channel}'")
|
||||
|
||||
return result
|
||||
|
||||
def subscribe(self, *channels: str):
|
||||
"""
|
||||
Subscribe to channels (returns pubsub object).
|
||||
|
||||
Args:
|
||||
*channels: Channel names to subscribe to
|
||||
|
||||
Returns:
|
||||
PubSub object for listening to messages
|
||||
"""
|
||||
if not self._ensure_connected():
|
||||
raise RedisError("Failed to establish Redis connection for subscription")
|
||||
|
||||
pubsub = self.client.pubsub()
|
||||
pubsub.subscribe(*channels)
|
||||
|
||||
logger.info(f"Subscribed to channels: {channels}")
|
||||
return pubsub
|
||||
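    # Illustrative round-trip (not part of the module) using the two methods
    # above; the channel name and payload are hypothetical:
    #
    #   pubsub = manager.subscribe("detections")
    #   manager.publish("detections", {"camera_id": "cam-1", "class": "car"})
    #   for message in pubsub.listen():
    #       if message["type"] == "message":   # skip subscribe confirmations
    #           print(message["data"])
    #           break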

    # ===== HASH OPERATIONS =====

    def hset(self, key: str, field: str, value: Any) -> RedisOperationResult:
        """
        Set hash field.

        Args:
            key: Redis key
            field: Hash field name
            value: Field value

        Returns:
            RedisOperationResult with operation status
        """
        def operation():
            return self.client.hset(key, field, value)

        return self._execute_operation("HSET", operation, key)

    def hget(self, key: str, field: str) -> RedisOperationResult:
        """
        Get hash field value.

        Args:
            key: Redis key
            field: Hash field name

        Returns:
            RedisOperationResult with field value
        """
        def operation():
            return self.client.hget(key, field)

        return self._execute_operation("HGET", operation, key)

    def hgetall(self, key: str) -> RedisOperationResult:
        """
        Get all hash fields and values.

        Args:
            key: Redis key

        Returns:
            RedisOperationResult with hash dictionary
        """
        def operation():
            return self.client.hgetall(key)

        return self._execute_operation("HGETALL", operation, key)

    # ===== LIST OPERATIONS =====

    def lpush(self, key: str, *values: Any) -> RedisOperationResult:
        """
        Push values to left of list.

        Args:
            key: Redis key
            *values: Values to push

        Returns:
            RedisOperationResult with list length
        """
        def operation():
            return self.client.lpush(key, *values)

        return self._execute_operation("LPUSH", operation, key)

    def rpush(self, key: str, *values: Any) -> RedisOperationResult:
        """
        Push values to right of list.

        Args:
            key: Redis key
            *values: Values to push

        Returns:
            RedisOperationResult with list length
        """
        def operation():
            return self.client.rpush(key, *values)

        return self._execute_operation("RPUSH", operation, key)

    def lrange(self, key: str, start: int, end: int) -> RedisOperationResult:
        """
        Get range of list elements.

        Args:
            key: Redis key
            start: Start index
            end: End index

        Returns:
            RedisOperationResult with list elements
        """
        def operation():
            return self.client.lrange(key, start, end)

        return self._execute_operation("LRANGE", operation, key)

    # ===== UTILITY OPERATIONS =====

    def ping(self) -> RedisOperationResult:
        """
        Ping Redis server.

        Returns:
            RedisOperationResult with ping response
        """
        def operation():
            return self.client.ping()

        return self._execute_operation("PING", operation)

    def info(self, section: Optional[str] = None) -> RedisOperationResult:
        """
        Get Redis server information.

        Args:
            section: Optional info section

        Returns:
            RedisOperationResult with server info
        """
        def operation():
            return self.client.info(section)

        return self._execute_operation("INFO", operation)

    def flushdb(self) -> RedisOperationResult:
        """
        Flush current database.

        Returns:
            RedisOperationResult with operation status
        """
        def operation():
            return self.client.flushdb()

        result = self._execute_operation("FLUSHDB", operation)

        if result.success:
            logger.warning(f"Flushed Redis database {self.config.db}")

        return result

    def keys(self, pattern: str = "*") -> RedisOperationResult:
        """
        Get keys matching pattern.

        Args:
            pattern: Key pattern (default: all keys)

        Returns:
            RedisOperationResult with matching keys
        """
        def operation():
            return self.client.keys(pattern)

        return self._execute_operation("KEYS", operation)

    # ===== BATCH OPERATIONS =====

    def pipeline(self):
        """
        Create Redis pipeline for batch operations.

        Returns:
            Redis pipeline object
        """
        if not self._ensure_connected():
            raise RedisError("Failed to establish Redis connection for pipeline")

        return self.client.pipeline()

    def execute_pipeline(self, pipeline) -> RedisOperationResult:
        """
        Execute Redis pipeline.

        Args:
            pipeline: Redis pipeline object

        Returns:
            RedisOperationResult with pipeline results
        """
        def operation():
            return pipeline.execute()

        return self._execute_operation("PIPELINE", operation)

    # ===== CONNECTION MANAGEMENT =====

    def get_connection_stats(self) -> Dict[str, Any]:
        """
        Get Redis connection statistics.

        Returns:
            Dictionary with connection statistics
        """
        stats = {
            "connected": self.is_connected(),
            "config": self.config.to_dict(),
            "last_health_check": self._last_health_check,
            "connection_pool_created": self.connection_pool is not None
        }

        if self.connection_pool:
            stats["connection_pool_stats"] = {
                "created_connections": self.connection_pool.created_connections,
                "available_connections": len(self.connection_pool._available_connections),
                "in_use_connections": len(self.connection_pool._in_use_connections)
            }

        # Get Redis server info if connected
        if self.is_connected():
            try:
                info_result = self.info()
                if info_result.success:
                    redis_info = info_result.data
                    stats["server_info"] = {
                        "redis_version": redis_info.get("redis_version"),
                        "connected_clients": redis_info.get("connected_clients"),
                        "used_memory": redis_info.get("used_memory"),
                        "used_memory_human": redis_info.get("used_memory_human"),
                        "total_commands_processed": redis_info.get("total_commands_processed"),
                        "uptime_in_seconds": redis_info.get("uptime_in_seconds")
                    }
            except Exception as e:
                stats["server_info_error"] = str(e)

        return stats


# ===== CONVENIENCE FUNCTIONS =====
# These provide compatibility and simplified access

def create_redis_client(config: Dict[str, Any]) -> Optional[RedisClientManager]:
    """
    Create Redis client with configuration validation.

    Args:
        config: Redis configuration

    Returns:
        RedisClientManager instance or None if connection failed
    """
    try:
        client_manager = RedisClientManager(config)
        if client_manager.connect():
            logger.info(f"Successfully created Redis client for {config['host']}:{config['port']}")
            return client_manager
        else:
            logger.error("Failed to establish initial Redis connection")
            return None
    except Exception as e:
        logger.error(f"Error creating Redis client: {e}")
        return None


def validate_redis_config(config: Dict[str, Any]) -> bool:
    """
    Validate Redis configuration.

    Args:
        config: Configuration dictionary

    Returns:
        True if configuration is valid
    """
    required_fields = ["host", "port"]

    for field_name in required_fields:
        if field_name not in config:
            logger.error(f"Missing required Redis config field: {field_name}")
            return False

        if not config[field_name]:
            logger.error(f"Empty Redis config field: {field_name}")
            return False

    # Validate port is numeric
    try:
        port = int(config["port"])
        if port <= 0 or port > 65535:
            logger.error(f"Invalid Redis port: {port}")
            return False
    except (ValueError, TypeError):
        logger.error(f"Redis port must be numeric: {config['port']}")
        return False

    return True
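A minimal usage sketch for the Redis manager above (illustrative; host, port, and key names are placeholders):

    from detector_worker.storage.redis_client import create_redis_client, validate_redis_config

    config = {"host": "localhost", "port": 6379}
    if validate_redis_config(config):
        client = create_redis_client(config)
        if client:
            client.setex("inference:cam-1:latest", 600, '{"class": "car"}')
            result = client.get("inference:cam-1:latest")
            if result.success:
                print(result.data)
            client.disconnect()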
688	detector_worker/storage/session_cache.py	Normal file

@@ -0,0 +1,688 @@
"""
|
||||
Session and cache management for detection workflows.
|
||||
|
||||
This module provides comprehensive session and cache management including:
|
||||
- Session state tracking and lifecycle management
|
||||
- Detection result caching with TTL
|
||||
- Pipeline mode state management
|
||||
- Session cleanup and garbage collection
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Set, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
|
||||
from ..core.constants import (
|
||||
SESSION_CACHE_TTL_MINUTES,
|
||||
DETECTION_CACHE_CLEANUP_INTERVAL,
|
||||
SESSION_TIMEOUT_SECONDS
|
||||
)
|
||||
from ..core.exceptions import SessionError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineMode(Enum):
|
||||
"""Pipeline execution modes."""
|
||||
VALIDATION_DETECTING = "validation_detecting"
|
||||
SEND_DETECTIONS = "send_detections"
|
||||
WAITING_FOR_SESSION_ID = "waiting_for_session_id"
|
||||
FULL_PIPELINE = "full_pipeline"
|
||||
LIGHTWEIGHT = "lightweight"
|
||||
CAR_GONE_WAITING = "car_gone_waiting"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionState:
|
||||
"""Session state information."""
|
||||
session_id: Optional[str] = None
|
||||
backend_session_id: Optional[str] = None
|
||||
mode: PipelineMode = PipelineMode.VALIDATION_DETECTING
|
||||
session_id_received: bool = False
|
||||
created_at: float = field(default_factory=time.time)
|
||||
last_updated: float = field(default_factory=time.time)
|
||||
last_detection: Optional[float] = None
|
||||
detection_count: int = 0
|
||||
|
||||
# Mode-specific state
|
||||
validation_frames_processed: int = 0
|
||||
stable_track_achieved: bool = False
|
||||
waiting_start_time: Optional[float] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"backend_session_id": self.backend_session_id,
|
||||
"mode": self.mode.value,
|
||||
"session_id_received": self.session_id_received,
|
||||
"created_at": self.created_at,
|
||||
"last_updated": self.last_updated,
|
||||
"last_detection": self.last_detection,
|
||||
"detection_count": self.detection_count,
|
||||
"validation_frames_processed": self.validation_frames_processed,
|
||||
"stable_track_achieved": self.stable_track_achieved,
|
||||
"waiting_start_time": self.waiting_start_time
|
||||
}
|
||||
|
||||
def update_activity(self) -> None:
|
||||
"""Update last activity timestamp."""
|
||||
self.last_updated = time.time()
|
||||
|
||||
def record_detection(self) -> None:
|
||||
"""Record a detection occurrence."""
|
||||
current_time = time.time()
|
||||
self.last_detection = current_time
|
||||
self.detection_count += 1
|
||||
self.update_activity()
|
||||
|
||||
def is_expired(self, ttl_seconds: int) -> bool:
|
||||
"""Check if session has expired based on TTL."""
|
||||
return time.time() - self.last_updated > ttl_seconds
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedDetection:
|
||||
"""Cached detection result."""
|
||||
detection_data: Dict[str, Any]
|
||||
camera_id: str
|
||||
session_id: Optional[str] = None
|
||||
created_at: float = field(default_factory=time.time)
|
||||
access_count: int = 0
|
||||
last_accessed: float = field(default_factory=time.time)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"detection_data": self.detection_data,
|
||||
"camera_id": self.camera_id,
|
||||
"session_id": self.session_id,
|
||||
"created_at": self.created_at,
|
||||
"access_count": self.access_count,
|
||||
"last_accessed": self.last_accessed
|
||||
}
|
||||
|
||||
def access(self) -> None:
|
||||
"""Record access to this cached detection."""
|
||||
self.access_count += 1
|
||||
self.last_accessed = time.time()
|
||||
|
||||
def is_expired(self, ttl_seconds: int) -> bool:
|
||||
"""Check if cached detection has expired."""
|
||||
return time.time() - self.created_at > ttl_seconds
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
Session lifecycle and state management.
|
||||
|
||||
This class provides comprehensive session management including:
|
||||
- Session creation and cleanup
|
||||
- Pipeline mode transitions
|
||||
- Session timeout handling
|
||||
"""
|
||||
|
||||
def __init__(self, session_timeout_seconds: int = SESSION_TIMEOUT_SECONDS):
|
||||
"""
|
||||
Initialize session manager.
|
||||
|
||||
Args:
|
||||
session_timeout_seconds: Default session timeout
|
||||
"""
|
||||
self.session_timeout_seconds = session_timeout_seconds
|
||||
self._sessions: Dict[str, SessionState] = {}
|
||||
self._session_ids: Dict[str, str] = {} # display_id -> session_id mapping
|
||||
self._lock = None
|
||||
|
||||
def _ensure_thread_safety(self):
|
||||
"""Initialize thread safety if not already done."""
|
||||
if self._lock is None:
|
||||
import threading
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def create_session(self, display_id: str, session_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Create a new session or get existing one.
|
||||
|
||||
Args:
|
||||
display_id: Display identifier
|
||||
session_id: Optional session ID
|
||||
|
||||
Returns:
|
||||
Session ID
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
# Check if session already exists for this display
|
||||
if display_id in self._session_ids:
|
||||
existing_session_id = self._session_ids[display_id]
|
||||
if existing_session_id in self._sessions:
|
||||
session_state = self._sessions[existing_session_id]
|
||||
session_state.update_activity()
|
||||
logger.debug(f"Using existing session for display {display_id}: {existing_session_id}")
|
||||
return existing_session_id
|
||||
|
||||
# Create new session
|
||||
if not session_id:
|
||||
import uuid
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
session_state = SessionState(session_id=session_id)
|
||||
self._sessions[session_id] = session_state
|
||||
self._session_ids[display_id] = session_id
|
||||
|
||||
logger.info(f"Created new session for display {display_id}: {session_id}")
|
||||
return session_id
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[SessionState]:
|
||||
"""
|
||||
Get session state by session ID.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
SessionState or None if not found
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
session = self._sessions.get(session_id)
|
||||
if session:
|
||||
session.update_activity()
|
||||
return session
|
||||
|
||||
def get_session_by_display(self, display_id: str) -> Optional[SessionState]:
|
||||
"""
|
||||
Get session state by display ID.
|
||||
|
||||
Args:
|
||||
display_id: Display identifier
|
||||
|
||||
Returns:
|
||||
SessionState or None if not found
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
session_id = self._session_ids.get(display_id)
|
||||
if session_id:
|
||||
return self.get_session(session_id)
|
||||
return None
|
||||
|
||||
def update_session_mode(self,
|
||||
session_id: str,
|
||||
new_mode: PipelineMode,
|
||||
backend_session_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Update session pipeline mode.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier
|
||||
new_mode: New pipeline mode
|
||||
backend_session_id: Optional backend session ID
|
||||
|
||||
Returns:
|
||||
True if updated successfully
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
session = self.get_session(session_id)
|
||||
if not session:
|
||||
logger.warning(f"Session not found for mode update: {session_id}")
|
||||
return False
|
||||
|
||||
old_mode = session.mode
|
||||
session.mode = new_mode
|
||||
|
||||
if backend_session_id:
|
||||
session.backend_session_id = backend_session_id
|
||||
session.session_id_received = True
|
||||
|
||||
# Handle mode-specific state changes
|
||||
if new_mode == PipelineMode.WAITING_FOR_SESSION_ID:
|
||||
session.waiting_start_time = time.time()
|
||||
elif old_mode == PipelineMode.WAITING_FOR_SESSION_ID:
|
||||
session.waiting_start_time = None
|
||||
|
||||
session.update_activity()
|
||||
|
||||
logger.info(f"Session {session_id}: Mode changed from {old_mode.value} to {new_mode.value}")
|
||||
return True
|
||||
|
||||
def record_detection(self, session_id: str) -> bool:
|
||||
"""
|
||||
Record a detection for a session.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
True if recorded successfully
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
session = self.get_session(session_id)
|
||||
if session:
|
||||
session.record_detection()
|
||||
return True
|
||||
return False
|
||||
|
||||
def cleanup_expired_sessions(self, ttl_seconds: Optional[int] = None) -> int:
|
||||
"""
|
||||
Clean up expired sessions.
|
||||
|
||||
Args:
|
||||
ttl_seconds: TTL in seconds (uses default if not provided)
|
||||
|
||||
Returns:
|
||||
Number of sessions cleaned up
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
if ttl_seconds is None:
|
||||
ttl_seconds = SESSION_CACHE_TTL_MINUTES * 60
|
||||
|
||||
with self._lock:
|
||||
expired_sessions = []
|
||||
expired_displays = []
|
||||
|
||||
# Find expired sessions
|
||||
for session_id, session in self._sessions.items():
|
||||
if session.is_expired(ttl_seconds):
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
# Find displays pointing to expired sessions
|
||||
for display_id, session_id in self._session_ids.items():
|
||||
if session_id in expired_sessions:
|
||||
expired_displays.append(display_id)
|
||||
|
||||
# Remove expired sessions and mappings
|
||||
for session_id in expired_sessions:
|
||||
del self._sessions[session_id]
|
||||
|
||||
for display_id in expired_displays:
|
||||
del self._session_ids[display_id]
|
||||
|
||||
if expired_sessions:
|
||||
logger.info(f"Cleaned up {len(expired_sessions)} expired sessions")
|
||||
|
||||
return len(expired_sessions)
|
||||
|
||||
def get_session_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get session management statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with session statistics
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
current_time = time.time()
|
||||
mode_counts = {}
|
||||
total_detections = 0
|
||||
|
||||
for session in self._sessions.values():
|
||||
mode = session.mode.value
|
||||
mode_counts[mode] = mode_counts.get(mode, 0) + 1
|
||||
total_detections += session.detection_count
|
||||
|
||||
return {
|
||||
"total_sessions": len(self._sessions),
|
||||
"total_display_mappings": len(self._session_ids),
|
||||
"mode_distribution": mode_counts,
|
||||
"total_detections_processed": total_detections,
|
||||
"session_timeout_seconds": self.session_timeout_seconds
|
||||
}
|
||||
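# Illustrative lifecycle (not part of the module) driven by the manager above;
# the display and backend IDs are placeholders:
#
#   manager = SessionManager()
#   sid = manager.create_session("display-1")
#   manager.record_detection(sid)
#   manager.update_session_mode(sid, PipelineMode.WAITING_FOR_SESSION_ID)
#   manager.update_session_mode(sid, PipelineMode.FULL_PIPELINE,
#                               backend_session_id="backend-42")
#   manager.cleanup_expired_sessions()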
|
||||
|
||||
class DetectionCache:
|
||||
"""
|
||||
Detection result caching with TTL and access tracking.
|
||||
|
||||
This class provides caching for detection results with automatic
|
||||
expiration and access pattern tracking.
|
||||
"""
|
||||
|
||||
def __init__(self, ttl_minutes: int = SESSION_CACHE_TTL_MINUTES):
|
||||
"""
|
||||
Initialize detection cache.
|
||||
|
||||
Args:
|
||||
ttl_minutes: Time to live for cached detections in minutes
|
||||
"""
|
||||
self.ttl_seconds = ttl_minutes * 60
|
||||
self._cache: Dict[str, CachedDetection] = {}
|
||||
self._camera_index: Dict[str, Set[str]] = {} # camera_id -> set of cache keys
|
||||
self._session_index: Dict[str, Set[str]] = {} # session_id -> set of cache keys
|
||||
self._lock = None
|
||||
self._last_cleanup = time.time()
|
||||
|
||||
def _ensure_thread_safety(self):
|
||||
"""Initialize thread safety if not already done."""
|
||||
if self._lock is None:
|
||||
import threading
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _generate_cache_key(self, camera_id: str, detection_type: str = "default") -> str:
|
||||
"""Generate cache key for detection."""
|
||||
return f"{camera_id}:{detection_type}:{time.time()}"
|
||||
|
||||
def store_detection(self,
|
||||
camera_id: str,
|
||||
detection_data: Dict[str, Any],
|
||||
session_id: Optional[str] = None,
|
||||
detection_type: str = "default") -> str:
|
||||
"""
|
||||
Store detection in cache.
|
||||
|
||||
Args:
|
||||
camera_id: Camera identifier
|
||||
detection_data: Detection result data
|
||||
session_id: Optional session identifier
|
||||
detection_type: Type of detection for categorization
|
||||
|
||||
Returns:
|
||||
Cache key for the stored detection
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
cache_key = self._generate_cache_key(camera_id, detection_type)
|
||||
|
||||
cached_detection = CachedDetection(
|
||||
detection_data=detection_data,
|
||||
camera_id=camera_id,
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
self._cache[cache_key] = cached_detection
|
||||
|
||||
# Update indexes
|
||||
if camera_id not in self._camera_index:
|
||||
self._camera_index[camera_id] = set()
|
||||
self._camera_index[camera_id].add(cache_key)
|
||||
|
||||
if session_id:
|
||||
if session_id not in self._session_index:
|
||||
self._session_index[session_id] = set()
|
||||
self._session_index[session_id].add(cache_key)
|
||||
|
||||
logger.debug(f"Stored detection in cache: {cache_key}")
|
||||
return cache_key
|
||||
|
||||
def get_detection(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get detection from cache by key.
|
||||
|
||||
Args:
|
||||
cache_key: Cache key
|
||||
|
||||
Returns:
|
||||
Detection data or None if not found/expired
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
cached_detection = self._cache.get(cache_key)
|
||||
if not cached_detection:
|
||||
return None
|
||||
|
||||
if cached_detection.is_expired(self.ttl_seconds):
|
||||
self._remove_from_cache(cache_key)
|
||||
return None
|
||||
|
||||
cached_detection.access()
|
||||
return cached_detection.detection_data
|
||||
|
||||
def get_latest_detection(self, camera_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get latest detection for a camera.
|
||||
|
||||
Args:
|
||||
camera_id: Camera identifier
|
||||
|
||||
Returns:
|
||||
Latest detection data or None if not found
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
camera_keys = self._camera_index.get(camera_id, set())
|
||||
if not camera_keys:
|
||||
return None
|
||||
|
||||
# Find the most recent non-expired detection
|
||||
latest_detection = None
|
||||
latest_time = 0
|
||||
|
||||
for cache_key in camera_keys:
|
||||
cached_detection = self._cache.get(cache_key)
|
||||
if cached_detection and not cached_detection.is_expired(self.ttl_seconds):
|
||||
if cached_detection.created_at > latest_time:
|
||||
latest_time = cached_detection.created_at
|
||||
latest_detection = cached_detection
|
||||
|
||||
if latest_detection:
|
||||
latest_detection.access()
|
||||
return latest_detection.detection_data
|
||||
|
||||
return None
|
||||
|
||||

    def get_session_detections(self, session_id: str) -> List[Dict[str, Any]]:
        """
        Get all detections for a session.

        Args:
            session_id: Session identifier

        Returns:
            List of detection data dictionaries
        """
        self._ensure_thread_safety()

        with self._lock:
            session_keys = self._session_index.get(session_id, set())
            detections = []

            for cache_key in session_keys:
                cached_detection = self._cache.get(cache_key)
                if cached_detection and not cached_detection.is_expired(self.ttl_seconds):
                    cached_detection.access()
                    detections.append(cached_detection.detection_data)

            # Sort by the detections' own timestamp field (newest first)
            detections.sort(key=lambda x: x.get("timestamp", 0), reverse=True)
            return detections
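
    # Note: ordering relies on a "timestamp" key inside each detection_data
    # dict (missing values sort as 0); cache expiry, by contrast, is driven
    # by CachedDetection.created_at.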

    def _remove_from_cache(self, cache_key: str) -> None:
        """Remove detection from cache and indexes. Caller must hold self._lock."""
        cached_detection = self._cache.get(cache_key)
        if cached_detection:
            # Remove from indexes
            camera_id = cached_detection.camera_id
            if camera_id in self._camera_index:
                self._camera_index[camera_id].discard(cache_key)
                if not self._camera_index[camera_id]:
                    del self._camera_index[camera_id]

            session_id = cached_detection.session_id
            if session_id and session_id in self._session_index:
                self._session_index[session_id].discard(cache_key)
                if not self._session_index[session_id]:
                    del self._session_index[session_id]

            # Remove from main cache
            if cache_key in self._cache:
                del self._cache[cache_key]

    def cleanup_expired_detections(self) -> int:
        """
        Clean up expired cached detections.

        Returns:
            Number of detections cleaned up
        """
        self._ensure_thread_safety()

        with self._lock:
            expired_keys = []

            # Find expired detections
            for cache_key, cached_detection in self._cache.items():
                if cached_detection.is_expired(self.ttl_seconds):
                    expired_keys.append(cache_key)

            # Remove expired detections
            for cache_key in expired_keys:
                self._remove_from_cache(cache_key)

            if expired_keys:
                logger.info(f"Cleaned up {len(expired_keys)} expired cached detections")

            self._last_cleanup = time.time()
            return len(expired_keys)

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache statistics
        """
        self._ensure_thread_safety()

        with self._lock:
            total_access_count = sum(detection.access_count for detection in self._cache.values())

            return {
                "total_cached_detections": len(self._cache),
                "cameras_with_cache": len(self._camera_index),
                "sessions_with_cache": len(self._session_index),
                "total_access_count": total_access_count,
                "ttl_seconds": self.ttl_seconds,
                "last_cleanup": self._last_cleanup
            }
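

# Usage sketch for the cache above (illustrative only; all literal values are
# made up). The constructor takes the TTL in minutes, matching the
# SessionCacheManager call below:
#
#   cache = DetectionCache(5)
#   cache.store_detection("camera-01",
#                         {"class": "car", "confidence": 0.92, "timestamp": time.time()},
#                         session_id="session-abc")
#   cache.get_latest_detection("camera-01")      # newest non-expired entry
#   cache.get_session_detections("session-abc")  # all entries for the session
#   cache.cleanup_expired_detections()           # returns the number removed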


class SessionCacheManager:
    """
    Combined session and cache management.

    This class provides unified management of sessions and detection caching
    with automatic cleanup and comprehensive statistics.
    """

    def __init__(self,
                 session_timeout_seconds: int = SESSION_TIMEOUT_SECONDS,
                 cache_ttl_minutes: int = SESSION_CACHE_TTL_MINUTES):
        """
        Initialize session cache manager.

        Args:
            session_timeout_seconds: Session timeout in seconds
            cache_ttl_minutes: Cache TTL in minutes
        """
        self.session_manager = SessionManager(session_timeout_seconds)
        self.detection_cache = DetectionCache(cache_ttl_minutes)
        self._last_cleanup = time.time()

    def cleanup_expired_data(self, force: bool = False) -> Dict[str, int]:
        """
        Clean up expired sessions and cached detections.

        Args:
            force: Force cleanup regardless of interval

        Returns:
            Dictionary with cleanup statistics
        """
        current_time = time.time()

        # Check if cleanup is needed
        if not force and current_time - self._last_cleanup < DETECTION_CACHE_CLEANUP_INTERVAL:
            return {"sessions_cleaned": 0, "detections_cleaned": 0, "cleanup_skipped": True}

        sessions_cleaned = self.session_manager.cleanup_expired_sessions()
        detections_cleaned = self.detection_cache.cleanup_expired_detections()

        self._last_cleanup = current_time

        return {
            "sessions_cleaned": sessions_cleaned,
            "detections_cleaned": detections_cleaned,
            "cleanup_skipped": False
        }
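
    # Note: calls are rate-limited by DETECTION_CACHE_CLEANUP_INTERVAL, so it
    # is safe to invoke this from a hot path; pass force=True to bypass the
    # interval check (the "cleanup_skipped" flag reports which path ran).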

    def get_comprehensive_stats(self) -> Dict[str, Any]:
        """
        Get comprehensive statistics for sessions and cache.

        Returns:
            Dictionary with all statistics
        """
        return {
            "session_stats": self.session_manager.get_session_stats(),
            "cache_stats": self.detection_cache.get_cache_stats(),
            "last_cleanup": self._last_cleanup,
            "cleanup_interval_seconds": DETECTION_CACHE_CLEANUP_INTERVAL
        }


# Global session cache manager instance
session_cache_manager = SessionCacheManager()
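
# Note: the singleton above is created at import time with default timeouts;
# code that needs isolated state (e.g. tests) can construct its own
# SessionCacheManager instead of going through the helpers below.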


# ===== CONVENIENCE FUNCTIONS =====
# These provide simplified access to session and cache functionality

def create_session(display_id: str, session_id: Optional[str] = None) -> str:
    """Create a new session using global manager."""
    return session_cache_manager.session_manager.create_session(display_id, session_id)


def get_session_state(session_id: str) -> Optional[Dict[str, Any]]:
    """Get session state by session ID."""
    session = session_cache_manager.session_manager.get_session(session_id)
    return session.to_dict() if session else None


def update_pipeline_mode(session_id: str,
                         new_mode: str,
                         backend_session_id: Optional[str] = None) -> bool:
    """Update session pipeline mode."""
    try:
        mode = PipelineMode(new_mode)
        return session_cache_manager.session_manager.update_session_mode(session_id, mode, backend_session_id)
    except ValueError:
        logger.error(f"Invalid pipeline mode: {new_mode}")
        return False


def cache_detection(camera_id: str,
                    detection_data: Dict[str, Any],
                    session_id: Optional[str] = None) -> str:
    """Cache detection data using global manager."""
    return session_cache_manager.detection_cache.store_detection(camera_id, detection_data, session_id)


def get_cached_detection(camera_id: str) -> Optional[Dict[str, Any]]:
    """Get latest cached detection for a camera."""
    return session_cache_manager.detection_cache.get_latest_detection(camera_id)


def cleanup_expired_sessions_and_cache() -> Dict[str, int]:
    """Clean up expired sessions and cached data."""
    return session_cache_manager.cleanup_expired_data()


def get_session_and_cache_stats() -> Dict[str, Any]:
    """Get comprehensive session and cache statistics."""
    return session_cache_manager.get_comprehensive_stats()
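

# Minimal smoke-test sketch for the convenience layer above. The display and
# camera identifiers and the detection payload are made-up illustration values.
if __name__ == "__main__":
    sid = create_session("display-001")
    cache_detection("camera-01",
                    {"class": "car", "confidence": 0.92, "timestamp": time.time()},
                    session_id=sid)
    print(get_cached_detection("camera-01"))        # latest cached detection
    print(cleanup_expired_sessions_and_cache())     # cleanup statistics
    print(get_session_and_cache_stats())            # combined stats snapshot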