Done brand and body type detection with postgresql integration
This commit is contained in:
		
							parent
							
								
									18c62a2370
								
							
						
					
					
						commit
						8c429cc8f6
					
				
					 2 changed files with 282 additions and 10 deletions
				
			
		| 
						 | 
				
			
			@ -2,6 +2,7 @@ import psycopg2
 | 
			
		|||
import psycopg2.extras
 | 
			
		||||
from typing import Optional, Dict, Any
 | 
			
		||||
import logging
 | 
			
		||||
import uuid
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -90,8 +91,11 @@ class DatabaseManager:
 | 
			
		|||
                    set_clauses.append(f"{field} = %s")
 | 
			
		||||
                    values.append(value)
 | 
			
		||||
            
 | 
			
		||||
            # Add schema prefix if table doesn't already have it
 | 
			
		||||
            full_table_name = table if '.' in table else f"gas_station_1.{table}"
 | 
			
		||||
            
 | 
			
		||||
            query = f"""
 | 
			
		||||
            INSERT INTO {table} ({key_field}, {', '.join(fields.keys())})
 | 
			
		||||
            INSERT INTO {full_table_name} ({key_field}, {', '.join(fields.keys())})
 | 
			
		||||
            VALUES (%s, {', '.join(['%s'] * len(fields))})
 | 
			
		||||
            ON CONFLICT ({key_field})
 | 
			
		||||
            DO UPDATE SET {', '.join(set_clauses)}
 | 
			
		||||
| 
						 | 
				
			
			@ -109,4 +113,80 @@ class DatabaseManager:
 | 
			
		|||
            logger.error(f"Failed to execute update on {table}: {e}")
 | 
			
		||||
            if self.connection:
 | 
			
		||||
                self.connection.rollback()
 | 
			
		||||
            return False
 | 
			
		||||
            return False
 | 
			
		||||
    
 | 
			
		||||
    def create_car_frontal_info_table(self) -> bool:
 | 
			
		||||
        """Create the car_frontal_info table in gas_station_1 schema if it doesn't exist."""
 | 
			
		||||
        if not self.is_connected():
 | 
			
		||||
            if not self.connect():
 | 
			
		||||
                return False
 | 
			
		||||
        
 | 
			
		||||
        try:
 | 
			
		||||
            cur = self.connection.cursor()
 | 
			
		||||
            
 | 
			
		||||
            # Create schema if it doesn't exist
 | 
			
		||||
            cur.execute("CREATE SCHEMA IF NOT EXISTS gas_station_1")
 | 
			
		||||
            
 | 
			
		||||
            # Create table if it doesn't exist
 | 
			
		||||
            create_table_query = """
 | 
			
		||||
            CREATE TABLE IF NOT EXISTS gas_station_1.car_frontal_info (
 | 
			
		||||
                camera_id VARCHAR(255),
 | 
			
		||||
                captured_timestamp VARCHAR(255),
 | 
			
		||||
                session_id VARCHAR(255) PRIMARY KEY,
 | 
			
		||||
                license_character VARCHAR(255) DEFAULT NULL,
 | 
			
		||||
                license_type VARCHAR(255) DEFAULT 'No model available',
 | 
			
		||||
                car_brand VARCHAR(255) DEFAULT NULL,
 | 
			
		||||
                car_model VARCHAR(255) DEFAULT NULL,
 | 
			
		||||
                car_body_type VARCHAR(255) DEFAULT NULL,
 | 
			
		||||
                created_at TIMESTAMP DEFAULT NOW(),
 | 
			
		||||
                updated_at TIMESTAMP DEFAULT NOW()
 | 
			
		||||
            )
 | 
			
		||||
            """
 | 
			
		||||
            
 | 
			
		||||
            cur.execute(create_table_query)
 | 
			
		||||
            self.connection.commit()
 | 
			
		||||
            cur.close()
 | 
			
		||||
            logger.info("Successfully created/verified car_frontal_info table in gas_station_1 schema")
 | 
			
		||||
            return True
 | 
			
		||||
            
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Failed to create car_frontal_info table: {e}")
 | 
			
		||||
            if self.connection:
 | 
			
		||||
                self.connection.rollback()
 | 
			
		||||
            return False
 | 
			
		||||
    
 | 
			
		||||
    def insert_initial_detection(self, camera_id: str, captured_timestamp: str, session_id: str = None) -> str:
 | 
			
		||||
        """Insert initial detection record and return the session_id."""
 | 
			
		||||
        if not self.is_connected():
 | 
			
		||||
            if not self.connect():
 | 
			
		||||
                return None
 | 
			
		||||
        
 | 
			
		||||
        # Generate session_id if not provided
 | 
			
		||||
        if not session_id:
 | 
			
		||||
            session_id = str(uuid.uuid4())
 | 
			
		||||
        
 | 
			
		||||
        try:
 | 
			
		||||
            # Ensure table exists
 | 
			
		||||
            if not self.create_car_frontal_info_table():
 | 
			
		||||
                logger.error("Failed to create/verify table before insertion")
 | 
			
		||||
                return None
 | 
			
		||||
            
 | 
			
		||||
            cur = self.connection.cursor()
 | 
			
		||||
            insert_query = """
 | 
			
		||||
            INSERT INTO gas_station_1.car_frontal_info 
 | 
			
		||||
            (camera_id, captured_timestamp, session_id, license_character, license_type, car_brand, car_model, car_body_type)
 | 
			
		||||
            VALUES (%s, %s, %s, NULL, 'No model available', NULL, NULL, NULL)
 | 
			
		||||
            ON CONFLICT (session_id) DO NOTHING
 | 
			
		||||
            """
 | 
			
		||||
            
 | 
			
		||||
            cur.execute(insert_query, (camera_id, captured_timestamp, session_id))
 | 
			
		||||
            self.connection.commit()
 | 
			
		||||
            cur.close()
 | 
			
		||||
            logger.info(f"Inserted initial detection record with session_id: {session_id}")
 | 
			
		||||
            return session_id
 | 
			
		||||
            
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Failed to insert initial detection record: {e}")
 | 
			
		||||
            if self.connection:
 | 
			
		||||
                self.connection.rollback()
 | 
			
		||||
            return None
 | 
			
		||||
| 
						 | 
				
			
			@ -386,6 +386,134 @@ def execute_actions(node, frame, detection_result, regions_dict=None):
 | 
			
		|||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Error executing action {action['type']}: {e}")
 | 
			
		||||
 | 
			
		||||
def execute_parallel_actions(node, frame, detection_result, regions_dict):
 | 
			
		||||
    """Execute parallel actions after all required branches have completed."""
 | 
			
		||||
    if not node.get("parallelActions"):
 | 
			
		||||
        return
 | 
			
		||||
    
 | 
			
		||||
    logger.debug("Executing parallel actions...")
 | 
			
		||||
    branch_results = detection_result.get("branch_results", {})
 | 
			
		||||
    
 | 
			
		||||
    for action in node["parallelActions"]:
 | 
			
		||||
        try:
 | 
			
		||||
            action_type = action.get("type")
 | 
			
		||||
            logger.debug(f"Processing parallel action: {action_type}")
 | 
			
		||||
            
 | 
			
		||||
            if action_type == "postgresql_update_combined":
 | 
			
		||||
                # Check if all required branches have completed
 | 
			
		||||
                wait_for_branches = action.get("waitForBranches", [])
 | 
			
		||||
                missing_branches = [branch for branch in wait_for_branches if branch not in branch_results]
 | 
			
		||||
                
 | 
			
		||||
                if missing_branches:
 | 
			
		||||
                    logger.warning(f"Cannot execute postgresql_update_combined: missing branch results for {missing_branches}")
 | 
			
		||||
                    continue
 | 
			
		||||
                
 | 
			
		||||
                logger.info(f"All required branches completed: {wait_for_branches}")
 | 
			
		||||
                
 | 
			
		||||
                # Execute the database update
 | 
			
		||||
                execute_postgresql_update_combined(node, action, detection_result, branch_results)
 | 
			
		||||
            else:
 | 
			
		||||
                logger.warning(f"Unknown parallel action type: {action_type}")
 | 
			
		||||
                
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Error executing parallel action {action.get('type', 'unknown')}: {e}")
 | 
			
		||||
            import traceback
 | 
			
		||||
            logger.debug(f"Full traceback: {traceback.format_exc()}")
 | 
			
		||||
 | 
			
		||||
def execute_postgresql_update_combined(node, action, detection_result, branch_results):
 | 
			
		||||
    """Execute a PostgreSQL update with combined branch results."""
 | 
			
		||||
    if not node.get("db_manager"):
 | 
			
		||||
        logger.error("No database manager available for postgresql_update_combined action")
 | 
			
		||||
        return
 | 
			
		||||
        
 | 
			
		||||
    try:
 | 
			
		||||
        table = action["table"]
 | 
			
		||||
        key_field = action["key_field"]
 | 
			
		||||
        key_value_template = action["key_value"]
 | 
			
		||||
        fields = action["fields"]
 | 
			
		||||
        
 | 
			
		||||
        # Create context for key value formatting
 | 
			
		||||
        action_context = {**detection_result}
 | 
			
		||||
        key_value = key_value_template.format(**action_context)
 | 
			
		||||
        
 | 
			
		||||
        logger.info(f"Executing database update: table={table}, {key_field}={key_value}")
 | 
			
		||||
        
 | 
			
		||||
        # Process field mappings
 | 
			
		||||
        mapped_fields = {}
 | 
			
		||||
        for db_field, value_template in fields.items():
 | 
			
		||||
            try:
 | 
			
		||||
                mapped_value = resolve_field_mapping(value_template, branch_results, action_context)
 | 
			
		||||
                if mapped_value is not None:
 | 
			
		||||
                    mapped_fields[db_field] = mapped_value
 | 
			
		||||
                    logger.debug(f"Mapped field: {db_field} = {mapped_value}")
 | 
			
		||||
                else:
 | 
			
		||||
                    logger.warning(f"Could not resolve field mapping for {db_field}: {value_template}")
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                logger.error(f"Error mapping field {db_field} with template '{value_template}': {e}")
 | 
			
		||||
        
 | 
			
		||||
        if not mapped_fields:
 | 
			
		||||
            logger.warning("No fields mapped successfully, skipping database update")
 | 
			
		||||
            return
 | 
			
		||||
            
 | 
			
		||||
        # Execute the database update
 | 
			
		||||
        success = node["db_manager"].execute_update(table, key_field, key_value, mapped_fields)
 | 
			
		||||
        
 | 
			
		||||
        if success:
 | 
			
		||||
            logger.info(f"Successfully updated database: {table} with {len(mapped_fields)} fields")
 | 
			
		||||
        else:
 | 
			
		||||
            logger.error(f"Failed to update database: {table}")
 | 
			
		||||
            
 | 
			
		||||
    except KeyError as e:
 | 
			
		||||
        logger.error(f"Missing required field in postgresql_update_combined action: {e}")
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        logger.error(f"Error in postgresql_update_combined action: {e}")
 | 
			
		||||
        import traceback
 | 
			
		||||
        logger.debug(f"Full traceback: {traceback.format_exc()}")
 | 
			
		||||
 | 
			
		||||
def resolve_field_mapping(value_template, branch_results, action_context):
 | 
			
		||||
    """Resolve field mapping templates like {car_brand_cls_v1.brand}."""
 | 
			
		||||
    try:
 | 
			
		||||
        # Handle simple context variables first (non-branch references)
 | 
			
		||||
        if not '.' in value_template:
 | 
			
		||||
            return value_template.format(**action_context)
 | 
			
		||||
        
 | 
			
		||||
        # Handle branch result references like {model_id.field}
 | 
			
		||||
        import re
 | 
			
		||||
        branch_refs = re.findall(r'\{([^}]+\.[^}]+)\}', value_template)
 | 
			
		||||
        
 | 
			
		||||
        resolved_template = value_template
 | 
			
		||||
        for ref in branch_refs:
 | 
			
		||||
            try:
 | 
			
		||||
                model_id, field_name = ref.split('.', 1)
 | 
			
		||||
                
 | 
			
		||||
                if model_id in branch_results:
 | 
			
		||||
                    branch_data = branch_results[model_id]
 | 
			
		||||
                    if field_name in branch_data:
 | 
			
		||||
                        field_value = branch_data[field_name]
 | 
			
		||||
                        resolved_template = resolved_template.replace(f'{{{ref}}}', str(field_value))
 | 
			
		||||
                        logger.debug(f"Resolved {ref} to {field_value}")
 | 
			
		||||
                    else:
 | 
			
		||||
                        logger.warning(f"Field '{field_name}' not found in branch '{model_id}' results. Available fields: {list(branch_data.keys())}")
 | 
			
		||||
                        return None
 | 
			
		||||
                else:
 | 
			
		||||
                    logger.warning(f"Branch '{model_id}' not found in results. Available branches: {list(branch_results.keys())}")
 | 
			
		||||
                    return None
 | 
			
		||||
            except ValueError as e:
 | 
			
		||||
                logger.error(f"Invalid branch reference format: {ref}")
 | 
			
		||||
                return None
 | 
			
		||||
        
 | 
			
		||||
        # Format any remaining simple variables
 | 
			
		||||
        try:
 | 
			
		||||
            final_value = resolved_template.format(**action_context)
 | 
			
		||||
            return final_value
 | 
			
		||||
        except KeyError as e:
 | 
			
		||||
            logger.warning(f"Could not resolve context variable in template: {e}")
 | 
			
		||||
            return resolved_template
 | 
			
		||||
            
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        logger.error(f"Error resolving field mapping '{value_template}': {e}")
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 | 
			
		||||
    """
 | 
			
		||||
    Enhanced pipeline that supports:
 | 
			
		||||
| 
						 | 
				
			
			@ -410,13 +538,24 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 | 
			
		|||
 | 
			
		||||
            top1_idx = int(probs.top1)
 | 
			
		||||
            top1_conf = float(probs.top1conf)
 | 
			
		||||
            class_name = node["model"].names[top1_idx]
 | 
			
		||||
 | 
			
		||||
            det = {
 | 
			
		||||
                "class": node["model"].names[top1_idx],
 | 
			
		||||
                "class": class_name,
 | 
			
		||||
                "confidence": top1_conf,
 | 
			
		||||
                "id": None,
 | 
			
		||||
                node["model"].names[top1_idx]: node["model"].names[top1_idx]  # Add class name as key
 | 
			
		||||
                class_name: class_name  # Add class name as key for backward compatibility
 | 
			
		||||
            }
 | 
			
		||||
            
 | 
			
		||||
            # Add specific field mappings for database operations based on model type
 | 
			
		||||
            model_id = node.get("modelId", "").lower()
 | 
			
		||||
            if "brand" in model_id or "brand_cls" in model_id:
 | 
			
		||||
                det["brand"] = class_name
 | 
			
		||||
            elif "bodytype" in model_id or "body" in model_id:
 | 
			
		||||
                det["body_type"] = class_name
 | 
			
		||||
            elif "color" in model_id:
 | 
			
		||||
                det["color"] = class_name
 | 
			
		||||
            
 | 
			
		||||
            execute_actions(node, frame, det)
 | 
			
		||||
            return (det, None) if return_bbox else det
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -486,6 +625,30 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 | 
			
		|||
            "regions": regions_dict,
 | 
			
		||||
            **(context or {})
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        # ─── Create initial database record when Car+Frontal detected ────
 | 
			
		||||
        if node.get("db_manager") and node.get("multiClass", False):
 | 
			
		||||
            # Generate UUID session_id since client session is None for now
 | 
			
		||||
            import uuid as uuid_lib
 | 
			
		||||
            from datetime import datetime
 | 
			
		||||
            generated_session_id = str(uuid_lib.uuid4())
 | 
			
		||||
            
 | 
			
		||||
            # Insert initial detection record
 | 
			
		||||
            camera_id = detection_result.get("camera_id", "unknown")
 | 
			
		||||
            timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
 | 
			
		||||
            
 | 
			
		||||
            inserted_session_id = node["db_manager"].insert_initial_detection(
 | 
			
		||||
                camera_id=camera_id,
 | 
			
		||||
                captured_timestamp=timestamp,
 | 
			
		||||
                session_id=generated_session_id
 | 
			
		||||
            )
 | 
			
		||||
            
 | 
			
		||||
            if inserted_session_id:
 | 
			
		||||
                # Update detection_result with the generated session_id for actions and branches
 | 
			
		||||
                detection_result["session_id"] = inserted_session_id
 | 
			
		||||
                detection_result["timestamp"] = timestamp  # Update with proper timestamp
 | 
			
		||||
                logger.info(f"Created initial database record with session_id: {inserted_session_id}")
 | 
			
		||||
        
 | 
			
		||||
        execute_actions(node, frame, detection_result, regions_dict)
 | 
			
		||||
 | 
			
		||||
        # ─── Parallel branch processing ─────────────────────────────
 | 
			
		||||
| 
						 | 
				
			
			@ -498,12 +661,22 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 | 
			
		|||
                trigger_classes = br.get("triggerClasses", [])
 | 
			
		||||
                min_conf = br.get("minConfidence", 0)
 | 
			
		||||
                
 | 
			
		||||
                logger.debug(f"Evaluating branch {br['modelId']}: trigger_classes={trigger_classes}, min_conf={min_conf}")
 | 
			
		||||
                
 | 
			
		||||
                # Check if any detected class matches branch trigger
 | 
			
		||||
                branch_triggered = False
 | 
			
		||||
                for det_class in regions_dict:
 | 
			
		||||
                    if (det_class in trigger_classes and 
 | 
			
		||||
                        regions_dict[det_class]["confidence"] >= min_conf):
 | 
			
		||||
                    det_confidence = regions_dict[det_class]["confidence"]
 | 
			
		||||
                    logger.debug(f"  Checking detected class '{det_class}' (confidence={det_confidence:.3f}) against triggers {trigger_classes}")
 | 
			
		||||
                    
 | 
			
		||||
                    if (det_class in trigger_classes and det_confidence >= min_conf):
 | 
			
		||||
                        active_branches.append(br)
 | 
			
		||||
                        branch_triggered = True
 | 
			
		||||
                        logger.info(f"Branch {br['modelId']} activated by class '{det_class}' (conf={det_confidence:.3f} >= {min_conf})")
 | 
			
		||||
                        break
 | 
			
		||||
                
 | 
			
		||||
                if not branch_triggered:
 | 
			
		||||
                    logger.debug(f"Branch {br['modelId']} not triggered - no matching classes or insufficient confidence")
 | 
			
		||||
            
 | 
			
		||||
            if active_branches:
 | 
			
		||||
                if node.get("parallel", False) or any(br.get("parallel", False) for br in active_branches):
 | 
			
		||||
| 
						 | 
				
			
			@ -515,11 +688,15 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 | 
			
		|||
                            crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None)
 | 
			
		||||
                            sub_frame = frame
 | 
			
		||||
                            
 | 
			
		||||
                            logger.info(f"Starting parallel branch: {br['modelId']}, crop_class: {crop_class}")
 | 
			
		||||
                            
 | 
			
		||||
                            if br.get("crop", False) and crop_class:
 | 
			
		||||
                                cropped = crop_region_by_class(frame, regions_dict, crop_class)
 | 
			
		||||
                                if cropped is not None:
 | 
			
		||||
                                    sub_frame = cv2.resize(cropped, (224, 224))
 | 
			
		||||
                                    logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}")
 | 
			
		||||
                                else:
 | 
			
		||||
                                    logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
 | 
			
		||||
                                    continue
 | 
			
		||||
                            
 | 
			
		||||
                            future = executor.submit(run_pipeline, sub_frame, br, True, context)
 | 
			
		||||
| 
						 | 
				
			
			@ -541,21 +718,36 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 | 
			
		|||
                        crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None)
 | 
			
		||||
                        sub_frame = frame
 | 
			
		||||
                        
 | 
			
		||||
                        logger.info(f"Starting sequential branch: {br['modelId']}, crop_class: {crop_class}")
 | 
			
		||||
                        
 | 
			
		||||
                        if br.get("crop", False) and crop_class:
 | 
			
		||||
                            cropped = crop_region_by_class(frame, regions_dict, crop_class)
 | 
			
		||||
                            if cropped is not None:
 | 
			
		||||
                                sub_frame = cv2.resize(cropped, (224, 224))
 | 
			
		||||
                                logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}")
 | 
			
		||||
                            else:
 | 
			
		||||
                                logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
 | 
			
		||||
                                continue
 | 
			
		||||
                        
 | 
			
		||||
                        result, _ = run_pipeline(sub_frame, br, True, context)
 | 
			
		||||
                        if result:
 | 
			
		||||
                            branch_results[br["modelId"]] = result
 | 
			
		||||
                            logger.info(f"Branch {br['modelId']} completed: {result}")
 | 
			
		||||
                        try:
 | 
			
		||||
                            result, _ = run_pipeline(sub_frame, br, True, context)
 | 
			
		||||
                            if result:
 | 
			
		||||
                                branch_results[br["modelId"]] = result
 | 
			
		||||
                                logger.info(f"Branch {br['modelId']} completed: {result}")
 | 
			
		||||
                            else:
 | 
			
		||||
                                logger.warning(f"Branch {br['modelId']} returned no result")
 | 
			
		||||
                        except Exception as e:
 | 
			
		||||
                            logger.error(f"Error in sequential branch {br['modelId']}: {e}")
 | 
			
		||||
                            import traceback
 | 
			
		||||
                            logger.debug(f"Branch error traceback: {traceback.format_exc()}")
 | 
			
		||||
 | 
			
		||||
            # Store branch results in detection_result for parallel actions
 | 
			
		||||
            detection_result["branch_results"] = branch_results
 | 
			
		||||
 | 
			
		||||
        # ─── Execute Parallel Actions ───────────────────────────────
 | 
			
		||||
        if node.get("parallelActions") and "branch_results" in detection_result:
 | 
			
		||||
            execute_parallel_actions(node, frame, detection_result, regions_dict)
 | 
			
		||||
 | 
			
		||||
        # ─── Return detection result ────────────────────────────────
 | 
			
		||||
        primary_detection = max(all_detections, key=lambda x: x["confidence"])
 | 
			
		||||
        primary_bbox = primary_detection["bbox"]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue