From dc47eb858001b0aa3f98aa4798f50e2045a0666b Mon Sep 17 00:00:00 2001 From: ziesorx Date: Thu, 25 Sep 2025 00:18:02 +0700 Subject: [PATCH] refactor: remove hardcoded modelid --- core/communication/websocket.py | 2 +- core/detection/branches.py | 26 +++++++++++--------------- core/detection/pipeline.py | 33 +++++++++++++-------------------- core/tracking/integration.py | 22 +++++++++------------- 4 files changed, 34 insertions(+), 49 deletions(-) diff --git a/core/communication/websocket.py b/core/communication/websocket.py index a2da785..da3b7ee 100644 --- a/core/communication/websocket.py +++ b/core/communication/websocket.py @@ -306,7 +306,7 @@ class WebSocketHandler: if pipeline_parser: # Create tracking integration with message sender tracking_integration = TrackingPipelineIntegration( - pipeline_parser, model_manager, self._send_message + pipeline_parser, model_manager, model_id, self._send_message ) # Initialize tracking model diff --git a/core/detection/branches.py b/core/detection/branches.py index e0ca1df..247c5f8 100644 --- a/core/detection/branches.py +++ b/core/detection/branches.py @@ -21,14 +21,16 @@ class BranchProcessor: Manages branch synchronization and result collection. """ - def __init__(self, model_manager: Any): + def __init__(self, model_manager: Any, model_id: int): """ Initialize branch processor. Args: model_manager: Model manager for loading models + model_id: The model ID to use for loading models """ self.model_manager = model_manager + self.model_id = model_id # Branch models cache self.branch_models: Dict[str, YOLOWrapper] = {} @@ -123,22 +125,16 @@ class BranchProcessor: # Load model logger.info(f"Loading branch model: {model_id} ({model_file})") - # Get the first available model ID from ModelManager - pipeline_models = list(self.model_manager.get_all_downloaded_models()) - if pipeline_models: - actual_model_id = pipeline_models[0] # Use the first available model - model = self.model_manager.get_yolo_model(actual_model_id, model_file) + # Load model using the proper model ID + model = self.model_manager.get_yolo_model(self.model_id, model_file) - if model: - self.branch_models[model_id] = model - self.stats['models_loaded'] += 1 - logger.info(f"Branch model {model_id} loaded successfully") - return model - else: - logger.error(f"Failed to load branch model {model_id}") - return None + if model: + self.branch_models[model_id] = model + self.stats['models_loaded'] += 1 + logger.info(f"Branch model {model_id} loaded successfully") + return model else: - logger.error("No models available in ModelManager for branch loading") + logger.error(f"Failed to load branch model {model_id}") return None except Exception as e: diff --git a/core/detection/pipeline.py b/core/detection/pipeline.py index cfab8dd..df1106f 100644 --- a/core/detection/pipeline.py +++ b/core/detection/pipeline.py @@ -27,21 +27,23 @@ class DetectionPipeline: Handles detection execution, branch coordination, and result aggregation. """ - def __init__(self, pipeline_parser: PipelineParser, model_manager: Any, message_sender=None): + def __init__(self, pipeline_parser: PipelineParser, model_manager: Any, model_id: int, message_sender=None): """ Initialize detection pipeline. Args: pipeline_parser: Pipeline parser with loaded configuration model_manager: Model manager for loading models + model_id: The model ID to use for loading models message_sender: Optional callback function for sending WebSocket messages """ self.pipeline_parser = pipeline_parser self.model_manager = model_manager + self.model_id = model_id self.message_sender = message_sender # Initialize components - self.branch_processor = BranchProcessor(model_manager) + self.branch_processor = BranchProcessor(model_manager, model_id) self.redis_manager = None self.db_manager = None self.license_plate_manager = None @@ -150,23 +152,14 @@ class DetectionPipeline: # Load detection model logger.info(f"Loading detection model: {model_id} ({model_file})") - # Get the model ID from the ModelManager context - pipeline_models = list(self.model_manager.get_all_downloaded_models()) - if pipeline_models: - actual_model_id = pipeline_models[0] # Use the first available model - self.detection_model = self.model_manager.get_yolo_model(actual_model_id, model_file) - else: - logger.error("No models available in ModelManager") + self.detection_model = self.model_manager.get_yolo_model(self.model_id, model_file) + if not self.detection_model: + logger.error(f"Failed to load detection model {model_file} from model {self.model_id}") return False self.detection_model_id = model_id - - if self.detection_model: - logger.info(f"Detection model {model_id} loaded successfully") - return True - else: - logger.error(f"Failed to load detection model {model_id}") - return False + logger.info(f"Detection model {model_id} loaded successfully") + return True except Exception as e: logger.error(f"Error initializing detection model: {e}", exc_info=True) @@ -301,8 +294,8 @@ class DetectionPipeline: "licensePlateText": license_text, "licensePlateConfidence": confidence }, - modelId=52, # Default model ID - modelName="yolo11m" # Default model name + modelId=self.model_id, + modelName=self.pipeline_parser.pipeline_config.model_id if self.pipeline_parser.pipeline_config else "detection_model" ) # Create imageDetection message @@ -342,8 +335,8 @@ class DetectionPipeline: "licensePlateText": None, "licensePlateConfidence": None }, - modelId=52, # Default model ID - modelName="yolo11m" # Default model name + modelId=self.model_id, + modelName=self.pipeline_parser.pipeline_config.model_id if self.pipeline_parser.pipeline_config else "detection_model" ) # Create imageDetection message diff --git a/core/tracking/integration.py b/core/tracking/integration.py index 74e636d..a10acf8 100644 --- a/core/tracking/integration.py +++ b/core/tracking/integration.py @@ -25,17 +25,19 @@ class TrackingPipelineIntegration: Manages tracking state transitions and pipeline execution triggers. """ - def __init__(self, pipeline_parser: PipelineParser, model_manager: Any, message_sender=None): + def __init__(self, pipeline_parser: PipelineParser, model_manager: Any, model_id: int, message_sender=None): """ Initialize tracking-pipeline integration. Args: pipeline_parser: Pipeline parser with loaded configuration model_manager: Model manager for loading models + model_id: The model ID to use for loading models message_sender: Optional callback function for sending WebSocket messages """ self.pipeline_parser = pipeline_parser self.model_manager = model_manager + self.model_id = model_id self.message_sender = message_sender # Store subscription info for snapshot access @@ -101,15 +103,9 @@ class TrackingPipelineIntegration: # Load tracking model logger.info(f"Loading tracking model: {model_id} ({model_file})") - # Get the model ID from the ModelManager context - # We need the actual model ID, not the model string identifier - # For now, let's extract it from the model manager - pipeline_models = list(self.model_manager.get_all_downloaded_models()) - if pipeline_models: - actual_model_id = pipeline_models[0] # Use the first available model - self.tracking_model = self.model_manager.get_yolo_model(actual_model_id, model_file) - else: - logger.error("No models available in ModelManager") + self.tracking_model = self.model_manager.get_yolo_model(self.model_id, model_file) + if not self.tracking_model: + logger.error(f"Failed to load tracking model {model_file} from model {self.model_id}") return False self.tracking_model_id = model_id @@ -141,7 +137,7 @@ class TrackingPipelineIntegration: return False # Create detection pipeline with message sender capability - self.detection_pipeline = DetectionPipeline(self.pipeline_parser, self.model_manager, self.message_sender) + self.detection_pipeline = DetectionPipeline(self.pipeline_parser, self.model_manager, self.model_id, self.message_sender) # Initialize detection pipeline if await self.detection_pipeline.initialize(): @@ -637,8 +633,8 @@ class TrackingPipelineIntegration: detection_message = create_image_detection( subscription_identifier=subscription_id, detection_data=None, # Null detection indicates abandonment - model_id=52, - model_name="front_rear_detection_v1" + model_id=self.model_id, + model_name=self.pipeline_parser.tracking_config.model_id if self.pipeline_parser.tracking_config else "tracking_model" ) # Send to backend via WebSocket if sender is available