Add brand and body type detection with PostgreSQL integration
parent 18c62a2370
commit 8c429cc8f6
2 changed files with 282 additions and 10 deletions
@@ -2,6 +2,7 @@ import psycopg2
 import psycopg2.extras
 from typing import Optional, Dict, Any
 import logging
+import uuid

 logger = logging.getLogger(__name__)

@@ -90,8 +91,11 @@ class DatabaseManager:
             set_clauses.append(f"{field} = %s")
             values.append(value)

+            # Add schema prefix if table doesn't already have it
+            full_table_name = table if '.' in table else f"gas_station_1.{table}"
+
             query = f"""
-                INSERT INTO {table} ({key_field}, {', '.join(fields.keys())})
+                INSERT INTO {full_table_name} ({key_field}, {', '.join(fields.keys())})
                 VALUES (%s, {', '.join(['%s'] * len(fields))})
                 ON CONFLICT ({key_field})
                 DO UPDATE SET {', '.join(set_clauses)}
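
The schema-prefix guard keeps call sites unchanged: bare table names get the gas_station_1 prefix, while already-qualified names pass through untouched. A minimal sketch of the ternary's behavior (table names here are illustrative):

    # illustrative inputs only
    for table in ("car_frontal_info", "gas_station_1.car_frontal_info"):
        full_table_name = table if '.' in table else f"gas_station_1.{table}"
        print(table, "->", full_table_name)
    # car_frontal_info -> gas_station_1.car_frontal_info
    # gas_station_1.car_frontal_info -> gas_station_1.car_frontal_info
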
@@ -109,4 +113,80 @@ class DatabaseManager:
             logger.error(f"Failed to execute update on {table}: {e}")
             if self.connection:
                 self.connection.rollback()
             return False
         return False
+
+    def create_car_frontal_info_table(self) -> bool:
+        """Create the car_frontal_info table in gas_station_1 schema if it doesn't exist."""
+        if not self.is_connected():
+            if not self.connect():
+                return False
+
+        try:
+            cur = self.connection.cursor()
+
+            # Create schema if it doesn't exist
+            cur.execute("CREATE SCHEMA IF NOT EXISTS gas_station_1")
+
+            # Create table if it doesn't exist
+            create_table_query = """
+                CREATE TABLE IF NOT EXISTS gas_station_1.car_frontal_info (
+                    camera_id VARCHAR(255),
+                    captured_timestamp VARCHAR(255),
+                    session_id VARCHAR(255) PRIMARY KEY,
+                    license_character VARCHAR(255) DEFAULT NULL,
+                    license_type VARCHAR(255) DEFAULT 'No model available',
+                    car_brand VARCHAR(255) DEFAULT NULL,
+                    car_model VARCHAR(255) DEFAULT NULL,
+                    car_body_type VARCHAR(255) DEFAULT NULL,
+                    created_at TIMESTAMP DEFAULT NOW(),
+                    updated_at TIMESTAMP DEFAULT NOW()
+                )
+            """
+
+            cur.execute(create_table_query)
+            self.connection.commit()
+            cur.close()
+            logger.info("Successfully created/verified car_frontal_info table in gas_station_1 schema")
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to create car_frontal_info table: {e}")
+            if self.connection:
+                self.connection.rollback()
+            return False
+
+    def insert_initial_detection(self, camera_id: str, captured_timestamp: str, session_id: str = None) -> str:
+        """Insert initial detection record and return the session_id."""
+        if not self.is_connected():
+            if not self.connect():
+                return None
+
+        # Generate session_id if not provided
+        if not session_id:
+            session_id = str(uuid.uuid4())
+
+        try:
+            # Ensure table exists
+            if not self.create_car_frontal_info_table():
+                logger.error("Failed to create/verify table before insertion")
+                return None
+
+            cur = self.connection.cursor()
+            insert_query = """
+                INSERT INTO gas_station_1.car_frontal_info
+                (camera_id, captured_timestamp, session_id, license_character, license_type, car_brand, car_model, car_body_type)
+                VALUES (%s, %s, %s, NULL, 'No model available', NULL, NULL, NULL)
+                ON CONFLICT (session_id) DO NOTHING
+            """
+
+            cur.execute(insert_query, (camera_id, captured_timestamp, session_id))
+            self.connection.commit()
+            cur.close()
+            logger.info(f"Inserted initial detection record with session_id: {session_id}")
+            return session_id
+
+        except Exception as e:
+            logger.error(f"Failed to insert initial detection record: {e}")
+            if self.connection:
+                self.connection.rollback()
+            return None
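
Taken together, the two new DatabaseManager methods give the pipeline an idempotent bootstrap: create the schema and table if absent, then upsert a stub row keyed by session_id. A hedged usage sketch (the constructor and its connection settings are not part of this diff and are assumed):

    # assumed constructor; only the two methods below appear in this diff
    db = DatabaseManager()
    if db.create_car_frontal_info_table():
        session_id = db.insert_initial_detection(
            camera_id="cam-01",                        # hypothetical camera id
            captured_timestamp="2024-01-01T12-00-00",  # matches the %Y-%m-%dT%H-%M-%S format used later
        )
        # session_id is a fresh uuid4 string (or the one passed in), or None on failure

The remaining hunks belong to the second changed file, the pipeline module (execute_actions / run_pipeline).
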
@@ -386,6 +386,134 @@ def execute_actions(node, frame, detection_result, regions_dict=None):
         except Exception as e:
             logger.error(f"Error executing action {action['type']}: {e}")

+def execute_parallel_actions(node, frame, detection_result, regions_dict):
+    """Execute parallel actions after all required branches have completed."""
+    if not node.get("parallelActions"):
+        return
+
+    logger.debug("Executing parallel actions...")
+    branch_results = detection_result.get("branch_results", {})
+
+    for action in node["parallelActions"]:
+        try:
+            action_type = action.get("type")
+            logger.debug(f"Processing parallel action: {action_type}")
+
+            if action_type == "postgresql_update_combined":
+                # Check if all required branches have completed
+                wait_for_branches = action.get("waitForBranches", [])
+                missing_branches = [branch for branch in wait_for_branches if branch not in branch_results]
+
+                if missing_branches:
+                    logger.warning(f"Cannot execute postgresql_update_combined: missing branch results for {missing_branches}")
+                    continue
+
+                logger.info(f"All required branches completed: {wait_for_branches}")
+
+                # Execute the database update
+                execute_postgresql_update_combined(node, action, detection_result, branch_results)
+            else:
+                logger.warning(f"Unknown parallel action type: {action_type}")
+
+        except Exception as e:
+            logger.error(f"Error executing parallel action {action.get('type', 'unknown')}: {e}")
+            import traceback
+            logger.debug(f"Full traceback: {traceback.format_exc()}")
+
+def execute_postgresql_update_combined(node, action, detection_result, branch_results):
+    """Execute a PostgreSQL update with combined branch results."""
+    if not node.get("db_manager"):
+        logger.error("No database manager available for postgresql_update_combined action")
+        return
+
+    try:
+        table = action["table"]
+        key_field = action["key_field"]
+        key_value_template = action["key_value"]
+        fields = action["fields"]
+
+        # Create context for key value formatting
+        action_context = {**detection_result}
+        key_value = key_value_template.format(**action_context)
+
+        logger.info(f"Executing database update: table={table}, {key_field}={key_value}")
+
+        # Process field mappings
+        mapped_fields = {}
+        for db_field, value_template in fields.items():
+            try:
+                mapped_value = resolve_field_mapping(value_template, branch_results, action_context)
+                if mapped_value is not None:
+                    mapped_fields[db_field] = mapped_value
+                    logger.debug(f"Mapped field: {db_field} = {mapped_value}")
+                else:
+                    logger.warning(f"Could not resolve field mapping for {db_field}: {value_template}")
+            except Exception as e:
+                logger.error(f"Error mapping field {db_field} with template '{value_template}': {e}")
+
+        if not mapped_fields:
+            logger.warning("No fields mapped successfully, skipping database update")
+            return
+
+        # Execute the database update
+        success = node["db_manager"].execute_update(table, key_field, key_value, mapped_fields)
+
+        if success:
+            logger.info(f"Successfully updated database: {table} with {len(mapped_fields)} fields")
+        else:
+            logger.error(f"Failed to update database: {table}")
+
+    except KeyError as e:
+        logger.error(f"Missing required field in postgresql_update_combined action: {e}")
+    except Exception as e:
+        logger.error(f"Error in postgresql_update_combined action: {e}")
+        import traceback
+        logger.debug(f"Full traceback: {traceback.format_exc()}")
+
+def resolve_field_mapping(value_template, branch_results, action_context):
+    """Resolve field mapping templates like {car_brand_cls_v1.brand}."""
+    try:
+        # Handle simple context variables first (non-branch references)
+        if not '.' in value_template:
+            return value_template.format(**action_context)
+
+        # Handle branch result references like {model_id.field}
+        import re
+        branch_refs = re.findall(r'\{([^}]+\.[^}]+)\}', value_template)
+
+        resolved_template = value_template
+        for ref in branch_refs:
+            try:
+                model_id, field_name = ref.split('.', 1)
+
+                if model_id in branch_results:
+                    branch_data = branch_results[model_id]
+                    if field_name in branch_data:
+                        field_value = branch_data[field_name]
+                        resolved_template = resolved_template.replace(f'{{{ref}}}', str(field_value))
+                        logger.debug(f"Resolved {ref} to {field_value}")
+                    else:
+                        logger.warning(f"Field '{field_name}' not found in branch '{model_id}' results. Available fields: {list(branch_data.keys())}")
+                        return None
+                else:
+                    logger.warning(f"Branch '{model_id}' not found in results. Available branches: {list(branch_results.keys())}")
+                    return None
+            except ValueError as e:
+                logger.error(f"Invalid branch reference format: {ref}")
+                return None
+
+        # Format any remaining simple variables
+        try:
+            final_value = resolved_template.format(**action_context)
+            return final_value
+        except KeyError as e:
+            logger.warning(f"Could not resolve context variable in template: {e}")
+            return resolved_template
+
+    except Exception as e:
+        logger.error(f"Error resolving field mapping '{value_template}': {e}")
+        return None
+
 def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
     """
     Enhanced pipeline that supports:
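
For reference, a hedged sketch of the pipeline-config fragment that would drive execute_parallel_actions. The key names (type, waitForBranches, table, key_field, key_value, fields) are exactly the ones read above, and the {car_brand_cls_v1.brand} template form comes from the resolve_field_mapping docstring; the model ids themselves are illustrative:

    # illustrative entry; in the real config this list lives under node["parallelActions"]
    parallel_actions = [
        {
            "type": "postgresql_update_combined",
            "waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],  # hypothetical branch ids
            "table": "car_frontal_info",      # execute_update adds the gas_station_1. prefix
            "key_field": "session_id",
            "key_value": "{session_id}",      # formatted from detection_result
            "fields": {
                "car_brand": "{car_brand_cls_v1.brand}",
                "car_body_type": "{car_bodytype_cls_v1.body_type}",
            },
        }
    ]
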
@@ -410,13 +538,24 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
         top1_idx = int(probs.top1)
         top1_conf = float(probs.top1conf)
+        class_name = node["model"].names[top1_idx]

         det = {
-            "class": node["model"].names[top1_idx],
+            "class": class_name,
             "confidence": top1_conf,
             "id": None,
-            node["model"].names[top1_idx]: node["model"].names[top1_idx] # Add class name as key
+            class_name: class_name # Add class name as key for backward compatibility
         }
+
+        # Add specific field mappings for database operations based on model type
+        model_id = node.get("modelId", "").lower()
+        if "brand" in model_id or "brand_cls" in model_id:
+            det["brand"] = class_name
+        elif "bodytype" in model_id or "body" in model_id:
+            det["body_type"] = class_name
+        elif "color" in model_id:
+            det["color"] = class_name
+
         execute_actions(node, frame, det)
         return (det, None) if return_bbox else det

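
The modelId heuristic is what links a classification branch to a database column. For a hypothetical brand classifier whose top-1 class is "Toyota", the det dict built by this block would look like:

    # sketch with assumed values; "Toyota" and 0.92 are illustrative
    det = {
        "class": "Toyota",
        "confidence": 0.92,
        "id": None,
        "Toyota": "Toyota",   # class name kept as a key for backward compatibility
        "brand": "Toyota",    # added because "brand" appears in the modelId
    }
    # "brand" is the field later referenced as {<modelId>.brand} by parallelActions
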
@@ -486,6 +625,30 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
             "regions": regions_dict,
             **(context or {})
         }

+        # ─── Create initial database record when Car+Frontal detected ────
+        if node.get("db_manager") and node.get("multiClass", False):
+            # Generate UUID session_id since client session is None for now
+            import uuid as uuid_lib
+            from datetime import datetime
+            generated_session_id = str(uuid_lib.uuid4())
+
+            # Insert initial detection record
+            camera_id = detection_result.get("camera_id", "unknown")
+            timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+
+            inserted_session_id = node["db_manager"].insert_initial_detection(
+                camera_id=camera_id,
+                captured_timestamp=timestamp,
+                session_id=generated_session_id
+            )
+
+            if inserted_session_id:
+                # Update detection_result with the generated session_id for actions and branches
+                detection_result["session_id"] = inserted_session_id
+                detection_result["timestamp"] = timestamp  # Update with proper timestamp
+                logger.info(f"Created initial database record with session_id: {inserted_session_id}")
+
         execute_actions(node, frame, detection_result, regions_dict)

         # ─── Parallel branch processing ─────────────────────────────
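
Note that captured_timestamp is stored as a string in a filename-safe format (hyphens rather than colons in the time part). A quick sketch of the two identifiers generated above:

    import uuid
    from datetime import datetime

    generated_session_id = str(uuid.uuid4())                  # e.g. '550e8400-e29b-41d4-a716-446655440000'
    timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")  # e.g. '2025-01-01T12-00-00'
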
@@ -498,12 +661,22 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
                 trigger_classes = br.get("triggerClasses", [])
                 min_conf = br.get("minConfidence", 0)

+                logger.debug(f"Evaluating branch {br['modelId']}: trigger_classes={trigger_classes}, min_conf={min_conf}")
+
                 # Check if any detected class matches branch trigger
+                branch_triggered = False
                 for det_class in regions_dict:
-                    if (det_class in trigger_classes and
-                        regions_dict[det_class]["confidence"] >= min_conf):
+                    det_confidence = regions_dict[det_class]["confidence"]
+                    logger.debug(f"  Checking detected class '{det_class}' (confidence={det_confidence:.3f}) against triggers {trigger_classes}")
+
+                    if (det_class in trigger_classes and det_confidence >= min_conf):
                         active_branches.append(br)
+                        branch_triggered = True
+                        logger.info(f"Branch {br['modelId']} activated by class '{det_class}' (conf={det_confidence:.3f} >= {min_conf})")
                         break

+                if not branch_triggered:
+                    logger.debug(f"Branch {br['modelId']} not triggered - no matching classes or insufficient confidence")
+
             if active_branches:
                 if node.get("parallel", False) or any(br.get("parallel", False) for br in active_branches):
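
To make the shapes concrete: regions_dict maps each detected class to its best detection, so a branch like the one below would be activated by the "frontal" region. The confidence key is read by the loop above; the bbox layout and all values are assumptions for illustration:

    # illustrative data; real values come from the detection stage
    regions_dict = {
        "car":     {"bbox": [10, 20, 300, 240], "confidence": 0.91},
        "frontal": {"bbox": [40, 60, 260, 200], "confidence": 0.84},
    }
    br = {"modelId": "car_brand_cls_v1", "triggerClasses": ["frontal"], "minConfidence": 0.5}

    triggered = any(
        cls in br["triggerClasses"] and det["confidence"] >= br["minConfidence"]
        for cls, det in regions_dict.items()
    )
    # triggered == True
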
@@ -515,11 +688,15 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
                             crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None)
                             sub_frame = frame

+                            logger.info(f"Starting parallel branch: {br['modelId']}, crop_class: {crop_class}")
+
                             if br.get("crop", False) and crop_class:
                                 cropped = crop_region_by_class(frame, regions_dict, crop_class)
                                 if cropped is not None:
                                     sub_frame = cv2.resize(cropped, (224, 224))
+                                    logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}")
+                                else:
+                                    logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
+                                    continue

                             future = executor.submit(run_pipeline, sub_frame, br, True, context)
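
crop_region_by_class itself is not part of this diff. Assuming regions_dict entries carry a pixel box as [x1, y1, x2, y2], it plausibly reduces to the sketch below; the real implementation may differ:

    # hedged sketch of the crop helper used above
    def crop_region_by_class(frame, regions_dict, cls):
        """Return the sub-image for cls, or None if it cannot be cropped."""
        region = regions_dict.get(cls)
        if region is None:
            return None
        x1, y1, x2, y2 = map(int, region["bbox"])  # assumed [x1, y1, x2, y2] layout
        cropped = frame[y1:y2, x1:x2]
        return cropped if cropped.size > 0 else None
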
@@ -541,21 +718,36 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
                         crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None)
                         sub_frame = frame

+                        logger.info(f"Starting sequential branch: {br['modelId']}, crop_class: {crop_class}")
+
                         if br.get("crop", False) and crop_class:
                             cropped = crop_region_by_class(frame, regions_dict, crop_class)
                             if cropped is not None:
                                 sub_frame = cv2.resize(cropped, (224, 224))
+                                logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}")
+                            else:
+                                logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
+                                continue

-                        result, _ = run_pipeline(sub_frame, br, True, context)
-                        if result:
-                            branch_results[br["modelId"]] = result
-                            logger.info(f"Branch {br['modelId']} completed: {result}")
+                        try:
+                            result, _ = run_pipeline(sub_frame, br, True, context)
+                            if result:
+                                branch_results[br["modelId"]] = result
+                                logger.info(f"Branch {br['modelId']} completed: {result}")
+                            else:
+                                logger.warning(f"Branch {br['modelId']} returned no result")
+                        except Exception as e:
+                            logger.error(f"Error in sequential branch {br['modelId']}: {e}")
+                            import traceback
+                            logger.debug(f"Branch error traceback: {traceback.format_exc()}")

+        # Store branch results in detection_result for parallel actions
+        detection_result["branch_results"] = branch_results
+
+        # ─── Execute Parallel Actions ───────────────────────────────
+        if node.get("parallelActions") and "branch_results" in detection_result:
+            execute_parallel_actions(node, frame, detection_result, regions_dict)
+
         # ─── Return detection result ────────────────────────────────
         primary_detection = max(all_detections, key=lambda x: x["confidence"])
         primary_bbox = primary_detection["bbox"]
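
End to end: once both classification branches report, the combined action resolves its templates into a field map and calls execute_update, which (with the schema-prefix change from the first file) assembles an upsert of this shape. Values shown are illustrative:

    # sketch of the statement execute_update builds for the combined action
    table, key_field = "car_frontal_info", "session_id"
    fields = {"car_brand": "Toyota", "car_body_type": "Sedan"}  # resolved from branch_results
    set_clauses = [f"{f} = %s" for f in fields]

    full_table_name = table if '.' in table else f"gas_station_1.{table}"
    query = f"""
        INSERT INTO {full_table_name} ({key_field}, {', '.join(fields.keys())})
        VALUES (%s, {', '.join(['%s'] * len(fields))})
        ON CONFLICT ({key_field})
        DO UPDATE SET {', '.join(set_clauses)}
    """
    # presumably executed with psycopg2 parameters: key value first, then the field values
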