"""
|
|
Detection Pipeline Module.
|
|
Main detection pipeline orchestration that coordinates detection flow and execution.
|
|
"""
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional, Any
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
import numpy as np
|
|
|
|
from ..models.inference import YOLOWrapper
|
|
from ..models.pipeline import PipelineParser
|
|
from .branches import BranchProcessor
|
|
from ..storage.redis import RedisManager
|
|
from ..storage.database import DatabaseManager
|
|
|
|
logger = logging.getLogger(__name__)


class DetectionPipeline:
    """
    Main detection pipeline that orchestrates the complete detection flow.
    Handles detection execution, branch coordination, and result aggregation.
    """

    def __init__(self, pipeline_parser: PipelineParser, model_manager: Any, message_sender=None):
        """
        Initialize detection pipeline.

        Args:
            pipeline_parser: Pipeline parser with loaded configuration
            model_manager: Model manager for loading models
            message_sender: Optional callback function for sending WebSocket messages
        """
        self.pipeline_parser = pipeline_parser
        self.model_manager = model_manager
        self.message_sender = message_sender

        # Initialize components
        self.branch_processor = BranchProcessor(model_manager)
        self.redis_manager = None
        self.db_manager = None

        # Main detection model
        self.detection_model: Optional[YOLOWrapper] = None
        self.detection_model_id = None

        # Thread pool for parallel processing
        self.executor = ThreadPoolExecutor(max_workers=4)

        # Pipeline configuration
        self.pipeline_config = pipeline_parser.pipeline_config

        # Statistics
        self.stats = {
            'detections_processed': 0,
            'branches_executed': 0,
            'actions_executed': 0,
            'total_processing_time': 0.0
        }

        logger.info("DetectionPipeline initialized")

    async def initialize(self) -> bool:
        """
        Initialize all pipeline components including models, Redis, and database.

        Returns:
            True if successful, False otherwise
        """
        try:
            # Initialize Redis connection
            if self.pipeline_parser.redis_config:
                self.redis_manager = RedisManager(self.pipeline_parser.redis_config.__dict__)
                if not await self.redis_manager.initialize():
                    logger.error("Failed to initialize Redis connection")
                    return False
                logger.info("Redis connection initialized")

            # Initialize database connection
            if self.pipeline_parser.postgresql_config:
                self.db_manager = DatabaseManager(self.pipeline_parser.postgresql_config.__dict__)
                if not self.db_manager.connect():
                    logger.error("Failed to initialize database connection")
                    return False
                # Create required tables
                if not self.db_manager.create_car_frontal_info_table():
                    logger.warning("Failed to create car_frontal_info table")
                logger.info("Database connection initialized")

            # Initialize main detection model
            if not await self._initialize_detection_model():
                logger.error("Failed to initialize detection model")
                return False

            # Initialize branch processor
            if not await self.branch_processor.initialize(
                self.pipeline_config,
                self.redis_manager,
                self.db_manager
            ):
                logger.error("Failed to initialize branch processor")
                return False

            logger.info("Detection pipeline initialization completed successfully")
            return True

        except Exception as e:
            logger.error(f"Error initializing detection pipeline: {e}", exc_info=True)
            return False
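
    # --- Illustrative usage (sketch, not executed) ---------------------------
    # A minimal startup flow, assuming the caller already holds a
    # PipelineParser and a model manager; `send_ws_message` is a hypothetical
    # coroutine that forwards messages over the worker's WebSocket:
    #
    #   pipeline = DetectionPipeline(parser, model_manager,
    #                                message_sender=send_ws_message)
    #   if not await pipeline.initialize():
    #       raise RuntimeError("Detection pipeline failed to initialize")
    # --------------------------------------------------------------------------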

    async def _initialize_detection_model(self) -> bool:
        """
        Load and initialize the main detection model.

        Returns:
            True if successful, False otherwise
        """
        try:
            if not self.pipeline_config:
                logger.warning("No pipeline configuration found")
                return False

            model_file = getattr(self.pipeline_config, 'model_file', None)
            model_id = getattr(self.pipeline_config, 'model_id', None)

            if not model_file:
                logger.warning("No detection model file specified")
                return False

            # Load detection model
            logger.info(f"Loading detection model: {model_id} ({model_file})")
            # Get the model ID from the ModelManager context
            pipeline_models = list(self.model_manager.get_all_downloaded_models())
            if pipeline_models:
                actual_model_id = pipeline_models[0]  # Use the first available model
                self.detection_model = self.model_manager.get_yolo_model(actual_model_id, model_file)
            else:
                logger.error("No models available in ModelManager")
                return False

            self.detection_model_id = model_id

            if self.detection_model:
                logger.info(f"Detection model {model_id} loaded successfully")
                return True
            else:
                logger.error(f"Failed to load detection model {model_id}")
                return False

        except Exception as e:
            logger.error(f"Error initializing detection model: {e}", exc_info=True)
            return False

    async def execute_detection_phase(self,
                                      frame: np.ndarray,
                                      display_id: str,
                                      subscription_id: str) -> Dict[str, Any]:
        """
        Execute only the detection phase - run main detection and send imageDetection message.
        This is the first phase that runs when a vehicle is validated.

        Args:
            frame: Input frame to process
            display_id: Display identifier
            subscription_id: Subscription identifier

        Returns:
            Dictionary with detection phase results
        """
        start_time = time.time()
        result = {
            'status': 'success',
            'detections': [],
            'message_sent': False,
            'processing_time': 0.0,
            'timestamp': datetime.now().isoformat()
        }

        try:
            # Run main detection model
            if not self.detection_model:
                result['status'] = 'error'
                result['message'] = 'Detection model not available'
                return result

            # Create detection context
            detection_context = {
                'display_id': display_id,
                'subscription_id': subscription_id,
                'timestamp': datetime.now().strftime("%Y-%m-%dT%H-%M-%S"),
                'timestamp_ms': int(time.time() * 1000)
            }

            # Run inference on a single snapshot using the .predict() method
            detection_results = self.detection_model.model.predict(
                frame,
                conf=getattr(self.pipeline_config, 'min_confidence', 0.6),
                verbose=False
            )

            # Process detection results
            valid_detections = []
            detected_regions = {}

            if detection_results and len(detection_results) > 0:
                result_obj = detection_results[0]
                trigger_classes = getattr(self.pipeline_config, 'trigger_classes', [])

                # .predict() results expose .boxes for detection models
                if hasattr(result_obj, 'boxes') and result_obj.boxes is not None:
                    logger.info(f"[DETECTION PHASE] Found {len(result_obj.boxes)} raw detections from {getattr(self.pipeline_config, 'model_id', 'unknown')}")

                    for i, box in enumerate(result_obj.boxes):
                        class_id = int(box.cls[0])
                        confidence = float(box.conf[0])
                        bbox = box.xyxy[0].cpu().numpy().tolist()  # [x1, y1, x2, y2]
                        class_name = self.detection_model.model.names[class_id]

                        logger.info(f"[DETECTION PHASE {i+1}] {class_name}: bbox={bbox}, conf={confidence:.3f}")

                        # Check if the detection matches the trigger classes
                        if trigger_classes and class_name not in trigger_classes:
                            logger.debug(f"[DETECTION PHASE] Filtered '{class_name}' - not in trigger_classes {trigger_classes}")
                            continue

                        logger.info(f"[DETECTION PHASE] Accepted '{class_name}' - matches trigger_classes")

                        # Store detection info
                        detection_info = {
                            'class_name': class_name,
                            'confidence': confidence,
                            'bbox': bbox
                        }
                        valid_detections.append(detection_info)

                        # Store region for the processing phase
                        detected_regions[class_name] = {
                            'bbox': bbox,
                            'confidence': confidence
                        }
                else:
                    logger.warning("[DETECTION PHASE] No boxes found in detection results")

            # Store detected_regions in the result for the processing phase
            result['detected_regions'] = detected_regions
            result['detections'] = valid_detections

            # If we have valid detections, send an imageDetection message (with empty detection data)
            if valid_detections:
                logger.info(f"Found {len(valid_detections)} valid detections, sending imageDetection message")

                # Send imageDetection with empty detection data
                message_sent = await self._send_image_detection_message(
                    subscription_id=subscription_id,
                    detection_context=detection_context
                )
                result['message_sent'] = message_sent

                if message_sent:
                    logger.info(f"Detection phase completed - imageDetection message sent for {display_id}")
                else:
                    logger.warning(f"Failed to send imageDetection message for {display_id}")
            else:
                logger.debug("No valid detections found in detection phase")

        except Exception as e:
            logger.error(f"Error in detection phase: {e}", exc_info=True)
            result['status'] = 'error'
            result['message'] = str(e)

        result['processing_time'] = time.time() - start_time
        return result
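
    # Shape of the dictionary returned by execute_detection_phase (a sketch;
    # the values are illustrative):
    #
    #   {
    #       'status': 'success',
    #       'detections': [{'class_name': 'Car', 'confidence': 0.91,
    #                       'bbox': [120.0, 80.0, 640.0, 420.0]}],
    #       'detected_regions': {'Car': {'bbox': [120.0, 80.0, 640.0, 420.0],
    #                                    'confidence': 0.91}},
    #       'message_sent': True,
    #       'processing_time': 0.042,
    #       'timestamp': '2024-01-01T00:00:00.000000',
    #   }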

    async def execute_processing_phase(self,
                                       frame: np.ndarray,
                                       display_id: str,
                                       session_id: str,
                                       subscription_id: str,
                                       detected_regions: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Execute the processing phase - run branches and database operations after receiving sessionId.
        This is the second phase that runs after the backend sends setSessionId.

        Args:
            frame: Input frame to process
            display_id: Display identifier
            session_id: Session ID from backend
            subscription_id: Subscription identifier
            detected_regions: Pre-detected regions from the detection phase

        Returns:
            Dictionary with processing phase results
        """
        start_time = time.time()
        result = {
            'status': 'success',
            'branch_results': {},
            'actions_executed': [],
            'session_id': session_id,
            'processing_time': 0.0,
            'timestamp': datetime.now().isoformat()
        }

        try:
            # Create enhanced detection context with session_id
            detection_context = {
                'display_id': display_id,
                'session_id': session_id,
                'subscription_id': subscription_id,
                'timestamp': datetime.now().strftime("%Y-%m-%dT%H-%M-%S"),
                'timestamp_ms': int(time.time() * 1000),
                'uuid': str(uuid.uuid4()),
                'filename': f"{uuid.uuid4()}.jpg"
            }

            # If no detected_regions were provided, re-run detection to get them
            if not detected_regions:
                # Use the .predict() method for detection
                detection_results = self.detection_model.model.predict(
                    frame,
                    conf=getattr(self.pipeline_config, 'min_confidence', 0.6),
                    verbose=False
                )

                detected_regions = {}
                if detection_results and len(detection_results) > 0:
                    result_obj = detection_results[0]
                    if hasattr(result_obj, 'boxes') and result_obj.boxes is not None:
                        for box in result_obj.boxes:
                            class_id = int(box.cls[0])
                            confidence = float(box.conf[0])
                            bbox = box.xyxy[0].cpu().numpy().tolist()  # [x1, y1, x2, y2]
                            class_name = self.detection_model.model.names[class_id]

                            detected_regions[class_name] = {
                                'bbox': bbox,
                                'confidence': confidence
                            }

            # Initialize database record with session_id
            if session_id and self.db_manager:
                success = self.db_manager.insert_initial_detection(
                    display_id=display_id,
                    captured_timestamp=detection_context['timestamp'],
                    session_id=session_id
                )
                if success:
                    logger.info(f"Created initial database record with session {session_id}")
                else:
                    logger.warning(f"Failed to create initial database record for session {session_id}")

            # Execute branches in parallel
            if hasattr(self.pipeline_config, 'branches') and self.pipeline_config.branches:
                branch_results = await self.branch_processor.execute_branches(
                    frame=frame,
                    branches=self.pipeline_config.branches,
                    detected_regions=detected_regions,
                    detection_context=detection_context
                )
                result['branch_results'] = branch_results
                logger.info(f"Executed {len(branch_results)} branches for session {session_id}")

            # Execute immediate actions (non-parallel)
            immediate_actions = getattr(self.pipeline_config, 'actions', [])
            if immediate_actions:
                executed_actions = await self._execute_immediate_actions(
                    actions=immediate_actions,
                    frame=frame,
                    detected_regions=detected_regions,
                    detection_context=detection_context
                )
                result['actions_executed'].extend(executed_actions)

            # Execute parallel actions (after all branches complete)
            parallel_actions = getattr(self.pipeline_config, 'parallel_actions', [])
            if parallel_actions:
                # Add branch results to the context
                enhanced_context = {**detection_context}
                if result['branch_results']:
                    enhanced_context['branch_results'] = result['branch_results']

                executed_parallel_actions = await self._execute_parallel_actions(
                    actions=parallel_actions,
                    frame=frame,
                    detected_regions=detected_regions,
                    context=enhanced_context
                )
                result['actions_executed'].extend(executed_parallel_actions)

            logger.info(f"Processing phase completed for session {session_id}: "
                        f"{len(result['branch_results'])} branches, {len(result['actions_executed'])} actions")

        except Exception as e:
            logger.error(f"Error in processing phase: {e}", exc_info=True)
            result['status'] = 'error'
            result['message'] = str(e)

        result['processing_time'] = time.time() - start_time
        return result
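
    # Two-phase flow (illustrative sketch): the detection phase runs when a
    # vehicle is validated; once the backend answers with setSessionId, the
    # caller re-enters with the cached regions. `frame`, `display_id`,
    # `sub_id`, and `session_id` are assumed to come from the caller's
    # tracking loop:
    #
    #   phase1 = await pipeline.execute_detection_phase(frame, display_id, sub_id)
    #   if phase1['message_sent']:
    #       # ... backend later sends setSessionId -> session_id ...
    #       phase2 = await pipeline.execute_processing_phase(
    #           frame, display_id, session_id, sub_id,
    #           detected_regions=phase1.get('detected_regions'))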

    async def _send_image_detection_message(self,
                                            subscription_id: str,
                                            detection_context: Dict[str, Any]) -> bool:
        """
        Send an imageDetection message with empty detection data to the backend.

        Args:
            subscription_id: Subscription identifier
            detection_context: Detection context data

        Returns:
            True if the message was sent successfully, False otherwise
        """
        try:
            if not self.message_sender:
                logger.warning("No message sender available for imageDetection")
                return False

            # Import here to avoid circular imports
            from ..communication.messages import create_image_detection

            # Create empty detection data as specified
            detection_data = {}

            # Get model info from pipeline configuration
            model_id = 52  # Default model ID
            model_name = "yolo11m"  # Default

            if self.pipeline_config:
                model_name = getattr(self.pipeline_config, 'model_id', 'yolo11m')
                # Try to extract a numeric model ID from the pipeline context, falling back to the default
                if hasattr(self.pipeline_config, 'model_id'):
                    # For now, use the default model ID since the pipeline config stores string identifiers
                    model_id = 52

            # Create imageDetection message
            detection_message = create_image_detection(
                subscription_identifier=subscription_id,
                detection_data=detection_data,
                model_id=model_id,
                model_name=model_name
            )

            # Send to the backend via WebSocket
            await self.message_sender(detection_message)
            logger.info(f"[DETECTION PHASE] Sent imageDetection with empty detection: {detection_data}")
            return True

        except Exception as e:
            logger.error(f"Error sending imageDetection message: {e}", exc_info=True)
            return False
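
    # Wire-format sketch: the detection-phase message deliberately carries an
    # empty detection payload; the backend is expected to reply with
    # setSessionId. The field names below are assumptions inferred from
    # create_image_detection and DetectionData, shown for illustration only:
    #
    #   {
    #       "subscriptionIdentifier": "display-001;cam-001",
    #       "data": {"detection": {}, "modelId": 52, "modelName": "yolo11m"}
    #   }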

    async def execute_detection(self,
                                frame: np.ndarray,
                                display_id: str,
                                session_id: Optional[str] = None,
                                subscription_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Execute the main detection pipeline on a frame.

        Args:
            frame: Input frame to process
            display_id: Display identifier
            session_id: Optional session ID
            subscription_id: Optional subscription identifier

        Returns:
            Dictionary with detection results
        """
        start_time = time.time()
        result = {
            'status': 'success',
            'detections': [],
            'branch_results': {},
            'actions_executed': [],
            'session_id': session_id,
            'processing_time': 0.0,
            'timestamp': datetime.now().isoformat()
        }

        try:
            # Update stats
            self.stats['detections_processed'] += 1

            # Run main detection model
            if not self.detection_model:
                result['status'] = 'error'
                result['message'] = 'Detection model not available'
                return result

            # Create detection context
            detection_context = {
                'display_id': display_id,
                'session_id': session_id,
                'subscription_id': subscription_id,
                'timestamp': datetime.now().strftime("%Y-%m-%dT%H-%M-%S"),
                'timestamp_ms': int(time.time() * 1000),
                'uuid': str(uuid.uuid4()),
                'filename': f"{uuid.uuid4()}.jpg"
            }

            # Save the full frame for debugging (note: hardcoded development path)
            import cv2
            debug_dir = "/Users/ziesorx/Documents/Work/Adsist/Bangchak/worker/python-detector-worker/debug_frames"
            timestamp = detection_context.get('timestamp', 'unknown')
            # Use a separate name so the session_id parameter is not shadowed
            debug_session_id = detection_context.get('session_id') or 'unknown'
            debug_filename = f"{debug_dir}/pipeline_full_frame_{debug_session_id}_{timestamp}.jpg"
            try:
                cv2.imwrite(debug_filename, frame)
                logger.info(f"[DEBUG PIPELINE] Saved full input frame: {debug_filename} ({frame.shape[1]}x{frame.shape[0]})")
            except Exception as e:
                logger.warning(f"[DEBUG PIPELINE] Failed to save debug frame: {e}")

            # Run inference on a single snapshot using the .predict() method
            detection_results = self.detection_model.model.predict(
                frame,
                conf=getattr(self.pipeline_config, 'min_confidence', 0.6),
                verbose=False
            )

            # Process detection results
            detected_regions = {}
            valid_detections = []

            if detection_results and len(detection_results) > 0:
                result_obj = detection_results[0]
                trigger_classes = getattr(self.pipeline_config, 'trigger_classes', [])

                # .predict() results expose .boxes for detection models
                if hasattr(result_obj, 'boxes') and result_obj.boxes is not None:
                    logger.info(f"[PIPELINE RAW] Found {len(result_obj.boxes)} raw detections from {getattr(self.pipeline_config, 'model_id', 'unknown')}")

                    for i, box in enumerate(result_obj.boxes):
                        class_id = int(box.cls[0])
                        confidence = float(box.conf[0])
                        bbox = box.xyxy[0].cpu().numpy().tolist()  # [x1, y1, x2, y2]
                        class_name = self.detection_model.model.names[class_id]

                        logger.info(f"[PIPELINE RAW {i+1}] {class_name}: bbox={bbox}, conf={confidence:.3f}")

                        # Check if the detection matches the trigger classes
                        if trigger_classes and class_name not in trigger_classes:
                            continue

                        # Store detection info
                        detection_info = {
                            'class_name': class_name,
                            'confidence': confidence,
                            'bbox': bbox
                        }
                        valid_detections.append(detection_info)

                        # Store region for cropping
                        detected_regions[class_name] = {
                            'bbox': bbox,
                            'confidence': confidence
                        }
                        logger.info(f"[PIPELINE DETECTION] {class_name}: bbox={bbox}, conf={confidence:.3f}")

            result['detections'] = valid_detections

            # If we have valid detections, proceed with branches and actions
            if valid_detections:
                logger.info(f"Found {len(valid_detections)} valid detections for pipeline processing")

                # Initialize database record if a session_id is provided
                if session_id and self.db_manager:
                    success = self.db_manager.insert_initial_detection(
                        display_id=display_id,
                        captured_timestamp=detection_context['timestamp'],
                        session_id=session_id
                    )
                    if not success:
                        logger.warning(f"Failed to create initial database record for session {session_id}")

                # Execute branches in parallel
                if hasattr(self.pipeline_config, 'branches') and self.pipeline_config.branches:
                    branch_results = await self.branch_processor.execute_branches(
                        frame=frame,
                        branches=self.pipeline_config.branches,
                        detected_regions=detected_regions,
                        detection_context=detection_context
                    )
                    result['branch_results'] = branch_results
                    self.stats['branches_executed'] += len(branch_results)

                # Execute immediate actions (non-parallel)
                immediate_actions = getattr(self.pipeline_config, 'actions', [])
                if immediate_actions:
                    executed_actions = await self._execute_immediate_actions(
                        actions=immediate_actions,
                        frame=frame,
                        detected_regions=detected_regions,
                        detection_context=detection_context
                    )
                    result['actions_executed'].extend(executed_actions)

                # Execute parallel actions (after all branches complete)
                parallel_actions = getattr(self.pipeline_config, 'parallel_actions', [])
                if parallel_actions:
                    # Add branch results to the context
                    enhanced_context = {**detection_context}
                    if result['branch_results']:
                        enhanced_context['branch_results'] = result['branch_results']

                    executed_parallel_actions = await self._execute_parallel_actions(
                        actions=parallel_actions,
                        frame=frame,
                        detected_regions=detected_regions,
                        context=enhanced_context
                    )
                    result['actions_executed'].extend(executed_parallel_actions)

                self.stats['actions_executed'] += len(result['actions_executed'])
            else:
                logger.debug("No valid detections found for pipeline processing")

        except Exception as e:
            logger.error(f"Error in detection pipeline execution: {e}", exc_info=True)
            result['status'] = 'error'
            result['message'] = str(e)

        # Update timing
        processing_time = time.time() - start_time
        result['processing_time'] = processing_time
        self.stats['total_processing_time'] += processing_time

        return result

    async def _execute_immediate_actions(self,
                                         actions: List[Any],
                                         frame: np.ndarray,
                                         detected_regions: Dict[str, Any],
                                         detection_context: Dict[str, Any]) -> List[Dict]:
        """
        Execute immediate actions (non-parallel).

        Args:
            actions: List of action configurations
            frame: Input frame
            detected_regions: Dictionary of detected regions
            detection_context: Detection context data

        Returns:
            List of executed action results
        """
        executed_actions = []

        for action in actions:
            # Pre-assign so the except handler never sees an unbound name
            action_type = 'unknown'
            try:
                action_type = action.type.value
                logger.debug(f"Executing immediate action: {action_type}")

                if action_type == 'redis_save_image':
                    result = await self._execute_redis_save_image(
                        action, frame, detected_regions, detection_context
                    )
                elif action_type == 'redis_publish':
                    result = await self._execute_redis_publish(
                        action, detection_context
                    )
                else:
                    logger.warning(f"Unknown immediate action type: {action_type}")
                    result = {'status': 'error', 'message': f'Unknown action type: {action_type}'}

                executed_actions.append({
                    'action_type': action_type,
                    'result': result
                })

            except Exception as e:
                logger.error(f"Error executing immediate action {action_type}: {e}", exc_info=True)
                executed_actions.append({
                    'action_type': action_type,
                    'result': {'status': 'error', 'message': str(e)}
                })

        return executed_actions

    async def _execute_parallel_actions(self,
                                        actions: List[Any],
                                        frame: np.ndarray,
                                        detected_regions: Dict[str, Any],
                                        context: Dict[str, Any]) -> List[Dict]:
        """
        Execute parallel actions (after branches complete).

        Args:
            actions: List of parallel action configurations
            frame: Input frame
            detected_regions: Dictionary of detected regions
            context: Enhanced context with branch results

        Returns:
            List of executed action results
        """
        executed_actions = []

        for action in actions:
            # Pre-assign so the except handler never sees an unbound name
            action_type = 'unknown'
            try:
                action_type = action.type.value
                logger.debug(f"Executing parallel action: {action_type}")

                if action_type == 'postgresql_update_combined':
                    result = await self._execute_postgresql_update_combined(action, context)

                    # Send an imageDetection message with the actual processing results after the database update
                    if result.get('status') == 'success':
                        await self._send_processing_results_message(context)
                else:
                    logger.warning(f"Unknown parallel action type: {action_type}")
                    result = {'status': 'error', 'message': f'Unknown action type: {action_type}'}

                executed_actions.append({
                    'action_type': action_type,
                    'result': result
                })

            except Exception as e:
                logger.error(f"Error executing parallel action {action_type}: {e}", exc_info=True)
                executed_actions.append({
                    'action_type': action_type,
                    'result': {'status': 'error', 'message': str(e)}
                })

        return executed_actions

    async def _execute_redis_save_image(self,
                                        action: Any,
                                        frame: np.ndarray,
                                        detected_regions: Dict[str, Any],
                                        context: Dict[str, Any]) -> Dict[str, Any]:
        """Execute redis_save_image action."""
        if not self.redis_manager:
            return {'status': 'error', 'message': 'Redis not available'}

        try:
            # Get the image to save (cropped region or full frame)
            image_to_save = frame
            region_name = action.params.get('region')

            if region_name and region_name in detected_regions:
                # Crop the specified region
                bbox = detected_regions[region_name]['bbox']
                x1, y1, x2, y2 = [int(coord) for coord in bbox]
                cropped = frame[y1:y2, x1:x2]
                if cropped.size > 0:
                    image_to_save = cropped
                    logger.debug(f"Cropped region '{region_name}' for redis_save_image")
                else:
                    logger.warning(f"Empty crop for region '{region_name}', using full frame")

            # Format the key with the context
            key = action.params['key'].format(**context)

            # Save the image to Redis
            result = await self.redis_manager.save_image(
                key=key,
                image=image_to_save,
                expire_seconds=action.params.get('expire_seconds'),
                image_format=action.params.get('format', 'jpeg'),
                quality=action.params.get('quality', 90)
            )

            if result:
                # Add image_key to the context for subsequent actions
                context['image_key'] = key
                return {'status': 'success', 'key': key}
            else:
                return {'status': 'error', 'message': 'Failed to save image to Redis'}

        except Exception as e:
            logger.error(f"Error in redis_save_image action: {e}", exc_info=True)
            return {'status': 'error', 'message': str(e)}
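
    # Example action parameters consumed by _execute_redis_save_image (a
    # sketch; the key template and region name are illustrative, and the
    # placeholders are filled in from the detection context):
    #
    #   params = {
    #       'key': 'inference:{display_id}:{timestamp}:{uuid}:frontal',
    #       'region': 'Frontal',
    #       'expire_seconds': 600,
    #       'format': 'jpeg',
    #       'quality': 90,
    #   }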

    async def _execute_redis_publish(self, action: Any, context: Dict[str, Any]) -> Dict[str, Any]:
        """Execute redis_publish action."""
        if not self.redis_manager:
            return {'status': 'error', 'message': 'Redis not available'}

        try:
            channel = action.params['channel']
            message_template = action.params['message']

            # Format the message with the context
            message = message_template.format(**context)

            # Publish the message
            result = await self.redis_manager.publish_message(channel, message)

            if result >= 0:  # Redis publish returns the number of subscribers
                return {'status': 'success', 'subscribers': result, 'channel': channel}
            else:
                return {'status': 'error', 'message': 'Failed to publish message to Redis'}

        except Exception as e:
            logger.error(f"Error in redis_publish action: {e}", exc_info=True)
            return {'status': 'error', 'message': str(e)}

    async def _execute_postgresql_update_combined(self,
                                                  action: Any,
                                                  context: Dict[str, Any]) -> Dict[str, Any]:
        """Execute postgresql_update_combined action."""
        if not self.db_manager:
            return {'status': 'error', 'message': 'Database not available'}

        try:
            # Wait for required branches if specified
            wait_for_branches = action.params.get('waitForBranches', [])
            branch_results = context.get('branch_results', {})

            # Check that all required branches have completed
            for branch_id in wait_for_branches:
                if branch_id not in branch_results:
                    logger.warning(f"Branch {branch_id} result not available for database update")
                    return {'status': 'error', 'message': f'Missing branch result: {branch_id}'}

            # Prepare fields for the database update
            table = action.params.get('table', 'car_frontal_info')
            key_field = action.params.get('key_field', 'session_id')
            key_value = action.params.get('key_value', '{session_id}').format(**context)
            field_mappings = action.params.get('fields', {})

            # Resolve field values using branch results
            resolved_fields = {}
            for field_name, field_template in field_mappings.items():
                try:
                    # Replace template variables with actual values from branch results
                    resolved_value = self._resolve_field_template(field_template, branch_results, context)
                    resolved_fields[field_name] = resolved_value
                except Exception as e:
                    logger.warning(f"Failed to resolve field {field_name}: {e}")
                    resolved_fields[field_name] = None

            # Execute the database update
            success = self.db_manager.execute_update(
                table=table,
                key_field=key_field,
                key_value=key_value,
                fields=resolved_fields
            )

            if success:
                return {'status': 'success', 'table': table, 'key': f'{key_field}={key_value}', 'fields': resolved_fields}
            else:
                return {'status': 'error', 'message': 'Database update failed'}

        except Exception as e:
            logger.error(f"Error in postgresql_update_combined action: {e}", exc_info=True)
            return {'status': 'error', 'message': str(e)}
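
    # Example parallel-action parameters for postgresql_update_combined (a
    # sketch; the branch IDs and column names are illustrative):
    #
    #   params = {
    #       'waitForBranches': ['car_brand_cls_v2', 'car_bodytype_cls_v1'],
    #       'table': 'car_frontal_info',
    #       'key_field': 'session_id',
    #       'key_value': '{session_id}',
    #       'fields': {
    #           'car_brand': '{car_brand_cls_v2.brand}',
    #           'body_type': '{car_bodytype_cls_v1.body_type}',
    #       },
    #   }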

    def _resolve_field_template(self, template: str, branch_results: Dict, context: Dict) -> Optional[str]:
        """
        Resolve a field template using branch results and context.

        Args:
            template: Template string like "{car_brand_cls_v2.brand}"
            branch_results: Dictionary of branch execution results
            context: Detection context

        Returns:
            Resolved field value, or None if the template cannot be resolved
        """
        try:
            # Handle simple context variables first
            if template.startswith('{') and template.endswith('}'):
                var_name = template[1:-1]

                # Check for a branch result reference (e.g., "car_brand_cls_v2.brand")
                if '.' in var_name:
                    branch_id, field_name = var_name.split('.', 1)
                    if branch_id in branch_results:
                        branch_data = branch_results[branch_id]
                        # Look for the field in the branch results
                        if isinstance(branch_data, dict) and 'result' in branch_data:
                            result_data = branch_data['result']
                            if isinstance(result_data, dict) and field_name in result_data:
                                return str(result_data[field_name])
                        logger.warning(f"Field {field_name} not found in branch {branch_id} results")
                        return None
                    else:
                        logger.warning(f"Branch {branch_id} not found in results")
                        return None

                # Simple context variable
                elif var_name in context:
                    return str(context[var_name])

                logger.warning(f"Template variable {var_name} not found in context or branch results")
                return None

            # Return the template as-is if it is not a template variable
            return template

        except Exception as e:
            logger.error(f"Error resolving field template {template}: {e}")
            return None
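
    # Resolution example (sketch; the branch ID and field are illustrative):
    #
    #   branch_results = {'car_brand_cls_v2': {'result': {'brand': 'Toyota'}}}
    #   self._resolve_field_template('{car_brand_cls_v2.brand}', branch_results, {})
    #   # -> 'Toyota'
    #   self._resolve_field_template('{session_id}', {}, {'session_id': 'abc123'})
    #   # -> 'abc123'
    #   self._resolve_field_template('literal-value', {}, {})
    #   # -> 'literal-value'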

    async def _send_processing_results_message(self, context: Dict[str, Any]):
        """
        Send an imageDetection message with the actual processing results after the database update.

        Args:
            context: Detection context containing branch results and subscription info
        """
        try:
            branch_results = context.get('branch_results', {})

            # Extract detection results from branch results
            detection_data = {
                "carBrand": None,
                "carModel": None,
                "bodyType": None,
                "licensePlateText": None,
                "licensePlateConfidence": None
            }

            # Extract car brand from car_brand_cls_v2 results
            if 'car_brand_cls_v2' in branch_results:
                brand_result = branch_results['car_brand_cls_v2'].get('result', {})
                detection_data["carBrand"] = brand_result.get('brand')

            # Extract body type from car_bodytype_cls_v1 results
            if 'car_bodytype_cls_v1' in branch_results:
                bodytype_result = branch_results['car_bodytype_cls_v1'].get('result', {})
                detection_data["bodyType"] = bodytype_result.get('body_type')

            # Create the detection message
            subscription_id = context.get('subscription_id', '')
            # Get the actual numeric model ID from the context
            model_id_value = context.get('model_id', 52)
            if isinstance(model_id_value, str):
                try:
                    model_id_value = int(model_id_value)
                except (ValueError, TypeError):
                    model_id_value = 52
            model_name = str(getattr(self.pipeline_config, 'model_id', 'unknown'))

            logger.debug(f"Creating DetectionData with modelId={model_id_value}, modelName='{model_name}'")

            from core.communication.models import ImageDetectionMessage, DetectionData
            detection_data_obj = DetectionData(
                detection=detection_data,
                modelId=model_id_value,
                modelName=model_name
            )
            detection_message = ImageDetectionMessage(
                subscriptionIdentifier=subscription_id,
                data=detection_data_obj
            )

            # Send to the backend via WebSocket
            if self.message_sender:
                await self.message_sender(detection_message)
                logger.info(f"[RESULTS] Sent imageDetection with processing results: {detection_data}")
            else:
                logger.warning("No message sender available for processing results")

        except Exception as e:
            logger.error(f"Error sending processing results message: {e}", exc_info=True)

    def get_statistics(self) -> Dict[str, Any]:
        """Get detection pipeline statistics."""
        branch_stats = self.branch_processor.get_statistics() if self.branch_processor else {}

        return {
            'pipeline': self.stats,
            'branches': branch_stats,
            'redis_available': self.redis_manager is not None,
            'database_available': self.db_manager is not None,
            'detection_model_loaded': self.detection_model is not None
        }

    def cleanup(self):
        """Cleanup resources."""
        if self.executor:
            self.executor.shutdown(wait=False)

        if self.redis_manager:
            self.redis_manager.cleanup()

        if self.db_manager:
            self.db_manager.disconnect()

        if self.branch_processor:
            self.branch_processor.cleanup()

        logger.info("Detection pipeline cleaned up")