python-detector-worker/core/models/pipeline.py
2025-09-23 16:13:11 +07:00

357 lines
No EOL
12 KiB
Python

"""
Pipeline Configuration Parser - Handles pipeline.json parsing and validation
"""
import json
import logging
from pathlib import Path
from typing import Dict, List, Any, Optional, Set
from dataclasses import dataclass, field
from enum import Enum
logger = logging.getLogger(__name__)
class ActionType(Enum):
    """Closed set of action types a pipeline step may execute."""

    # Redis-backed actions
    REDIS_SAVE_IMAGE = "redis_save_image"
    REDIS_PUBLISH = "redis_publish"
    # PostgreSQL-backed actions
    POSTGRESQL_UPDATE = "postgresql_update"
    POSTGRESQL_UPDATE_COMBINED = "postgresql_update_combined"
    POSTGRESQL_INSERT = "postgresql_insert"
@dataclass
class RedisConfig:
    """Connection settings for a Redis server."""

    host: str  # required; no sensible default
    port: int = 6379
    password: Optional[str] = None  # None -> unauthenticated connection
    db: int = 0

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'RedisConfig':
        """Build a RedisConfig from a parsed JSON mapping.

        Raises KeyError if the mandatory 'host' entry is absent; every
        other key falls back to the dataclass defaults.
        """
        kwargs = {
            'host': data['host'],
            'port': data.get('port', 6379),
            'password': data.get('password'),
            'db': data.get('db', 0),
        }
        return cls(**kwargs)
@dataclass
class PostgreSQLConfig:
    """Connection settings for a PostgreSQL database."""

    host: str
    # NOTE(review): required when constructing directly, but from_dict
    # defaults it to 5432 — confirm which contract callers rely on.
    port: int
    database: str
    username: str
    password: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'PostgreSQLConfig':
        """Build a PostgreSQLConfig from a parsed JSON mapping.

        All keys except 'port' (defaults to 5432) are mandatory; a
        missing mandatory key raises KeyError.
        """
        port = data.get('port', 5432)
        return cls(host=data['host'], port=port, database=data['database'],
                   username=data['username'], password=data['password'])
@dataclass
class Action:
    """A single pipeline action: its type plus all remaining keys as params."""

    type: ActionType
    params: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Action':
        """Build an Action from a mapping; every key except 'type' becomes a param.

        Raises KeyError if 'type' is missing and ValueError if its value
        is not a known ActionType.
        """
        params = dict(data)  # copy so the caller's mapping is untouched
        kind = ActionType(params.pop('type'))
        return cls(type=kind, params=params)
@dataclass
class ModelBranch:
    """A nested pipeline branch that runs its own model and children."""

    model_id: str
    model_file: str
    trigger_classes: List[str]
    min_confidence: float = 0.5
    crop: bool = False
    crop_class: Optional[Any] = None  # a single class name or a list of them
    parallel: bool = False
    actions: List[Action] = field(default_factory=list)
    branches: List['ModelBranch'] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelBranch':
        """Recursively build a branch (and its sub-branches) from JSON data.

        Raises KeyError when 'modelId' or 'modelFile' is absent.
        """
        return cls(
            model_id=data['modelId'],
            model_file=data['modelFile'],
            trigger_classes=data.get('triggerClasses', []),
            min_confidence=data.get('minConfidence', 0.5),
            crop=data.get('crop', False),
            crop_class=data.get('cropClass'),
            parallel=data.get('parallel', False),
            actions=[Action.from_dict(item) for item in data.get('actions', [])],
            branches=[cls.from_dict(item) for item in data.get('branches', [])],
        )
@dataclass
class TrackingConfig:
    """Settings for the tracking phase that precedes the main pipeline."""

    model_id: str
    model_file: str
    trigger_classes: List[str]
    min_confidence: float = 0.6

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TrackingConfig':
        """Build a TrackingConfig from JSON data.

        Raises KeyError when 'modelId' or 'modelFile' is absent.
        """
        kwargs = {
            'model_id': data['modelId'],
            'model_file': data['modelFile'],
            'trigger_classes': data.get('triggerClasses', []),
            'min_confidence': data.get('minConfidence', 0.6),
        }
        return cls(**kwargs)
@dataclass
class PipelineConfig:
    """Top-level pipeline configuration: the root model plus its branches."""

    model_id: str
    model_file: str
    trigger_classes: List[str]
    min_confidence: float = 0.5
    crop: bool = False
    branches: List[ModelBranch] = field(default_factory=list)
    parallel_actions: List[Action] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'PipelineConfig':
        """Build the root pipeline config from JSON data.

        Raises KeyError when 'modelId' or 'modelFile' is absent.
        """
        return cls(
            model_id=data['modelId'],
            model_file=data['modelFile'],
            trigger_classes=data.get('triggerClasses', []),
            min_confidence=data.get('minConfidence', 0.5),
            crop=data.get('crop', False),
            branches=[ModelBranch.from_dict(b) for b in data.get('branches', [])],
            parallel_actions=[Action.from_dict(a) for a in data.get('parallelActions', [])],
        )
class PipelineParser:
"""Parser for pipeline.json configuration files"""
def __init__(self):
self.redis_config: Optional[RedisConfig] = None
self.postgresql_config: Optional[PostgreSQLConfig] = None
self.tracking_config: Optional[TrackingConfig] = None
self.pipeline_config: Optional[PipelineConfig] = None
self._model_dependencies: Set[str] = set()
def parse(self, config_path: Path) -> bool:
"""
Parse a pipeline.json configuration file
Args:
config_path: Path to the pipeline.json file
Returns:
True if parsing was successful, False otherwise
"""
try:
if not config_path.exists():
logger.error(f"Pipeline config not found: {config_path}")
return False
with open(config_path, 'r') as f:
data = json.load(f)
return self.parse_dict(data)
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON in pipeline config: {str(e)}")
return False
except Exception as e:
logger.error(f"Failed to parse pipeline config: {str(e)}", exc_info=True)
return False
def parse_dict(self, data: Dict[str, Any]) -> bool:
"""
Parse a pipeline configuration from a dictionary
Args:
data: The configuration dictionary
Returns:
True if parsing was successful, False otherwise
"""
try:
# Parse Redis configuration
if 'redis' in data:
self.redis_config = RedisConfig.from_dict(data['redis'])
logger.debug(f"Parsed Redis config: {self.redis_config.host}:{self.redis_config.port}")
# Parse PostgreSQL configuration
if 'postgresql' in data:
self.postgresql_config = PostgreSQLConfig.from_dict(data['postgresql'])
logger.debug(f"Parsed PostgreSQL config: {self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}")
# Parse tracking configuration
if 'tracking' in data:
self.tracking_config = TrackingConfig.from_dict(data['tracking'])
self._model_dependencies.add(self.tracking_config.model_file)
logger.debug(f"Parsed tracking config: {self.tracking_config.model_id}")
# Parse main pipeline configuration
if 'pipeline' in data:
self.pipeline_config = PipelineConfig.from_dict(data['pipeline'])
self._collect_model_dependencies(self.pipeline_config)
logger.debug(f"Parsed pipeline config: {self.pipeline_config.model_id}")
logger.info(f"Successfully parsed pipeline configuration")
logger.debug(f"Model dependencies: {self._model_dependencies}")
return True
except KeyError as e:
logger.error(f"Missing required field in pipeline config: {str(e)}")
return False
except Exception as e:
logger.error(f"Failed to parse pipeline config: {str(e)}", exc_info=True)
return False
def _collect_model_dependencies(self, config: Any) -> None:
"""
Recursively collect all model file dependencies
Args:
config: Pipeline or branch configuration
"""
if hasattr(config, 'model_file'):
self._model_dependencies.add(config.model_file)
if hasattr(config, 'branches'):
for branch in config.branches:
self._collect_model_dependencies(branch)
def get_model_dependencies(self) -> Set[str]:
"""
Get all model file dependencies from the pipeline
Returns:
Set of model filenames required by the pipeline
"""
return self._model_dependencies.copy()
def validate(self) -> bool:
"""
Validate the parsed configuration
Returns:
True if configuration is valid, False otherwise
"""
if not self.pipeline_config:
logger.error("No pipeline configuration found")
return False
# Check that all required model files are specified
if not self.pipeline_config.model_file:
logger.error("Main pipeline model file not specified")
return False
# Validate action configurations
if not self._validate_actions(self.pipeline_config):
return False
# Validate parallel actions
for action in self.pipeline_config.parallel_actions:
if action.type == ActionType.POSTGRESQL_UPDATE_COMBINED:
wait_for = action.params.get('waitForBranches', [])
if wait_for:
# Check that referenced branches exist
branch_ids = self._get_all_branch_ids(self.pipeline_config)
for branch_id in wait_for:
if branch_id not in branch_ids:
logger.error(f"Referenced branch '{branch_id}' in waitForBranches not found")
return False
logger.info("Pipeline configuration validated successfully")
return True
def _validate_actions(self, config: Any) -> bool:
"""
Validate actions in a pipeline or branch configuration
Args:
config: Pipeline or branch configuration
Returns:
True if valid, False otherwise
"""
if hasattr(config, 'actions'):
for action in config.actions:
# Validate Redis actions need Redis config
if action.type in [ActionType.REDIS_SAVE_IMAGE, ActionType.REDIS_PUBLISH]:
if not self.redis_config:
logger.error(f"Action {action.type} requires Redis configuration")
return False
# Validate PostgreSQL actions need PostgreSQL config
if action.type in [ActionType.POSTGRESQL_UPDATE, ActionType.POSTGRESQL_UPDATE_COMBINED, ActionType.POSTGRESQL_INSERT]:
if not self.postgresql_config:
logger.error(f"Action {action.type} requires PostgreSQL configuration")
return False
# Recursively validate branches
if hasattr(config, 'branches'):
for branch in config.branches:
if not self._validate_actions(branch):
return False
return True
def _get_all_branch_ids(self, config: Any, branch_ids: Set[str] = None) -> Set[str]:
"""
Recursively collect all branch model IDs
Args:
config: Pipeline or branch configuration
branch_ids: Set to collect IDs into
Returns:
Set of all branch model IDs
"""
if branch_ids is None:
branch_ids = set()
if hasattr(config, 'branches'):
for branch in config.branches:
branch_ids.add(branch.model_id)
self._get_all_branch_ids(branch, branch_ids)
return branch_ids
def get_redis_config(self) -> Optional[RedisConfig]:
"""Get the Redis configuration"""
return self.redis_config
def get_postgresql_config(self) -> Optional[PostgreSQLConfig]:
"""Get the PostgreSQL configuration"""
return self.postgresql_config
def get_tracking_config(self) -> Optional[TrackingConfig]:
"""Get the tracking configuration"""
return self.tracking_config
def get_pipeline_config(self) -> Optional[PipelineConfig]:
"""Get the main pipeline configuration"""
return self.pipeline_config