"""
|
|
Pipeline Configuration Parser - Handles pipeline.json parsing and validation
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Dict, List, Any, Optional, Set
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
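
# A minimal sketch of the pipeline.json layout this parser accepts, inferred from
# the keys read by the classes below; all values are illustrative placeholders,
# not a definitive schema:
#
# {
#   "redis":      {"host": "...", "port": 6379, "password": "...", "db": 0},
#   "postgresql": {"host": "...", "port": 5432, "database": "...",
#                  "username": "...", "password": "..."},
#   "tracking":   {"modelId": "...", "modelFile": "...", "triggerClasses": ["..."]},
#   "pipeline": {
#     "modelId": "...", "modelFile": "...", "triggerClasses": ["..."],
#     "minConfidence": 0.5, "crop": false,
#     "branches": [{"modelId": "...", "modelFile": "...",
#                   "actions": [{"type": "redis_save_image"}]}],
#     "parallelActions": [{"type": "postgresql_update_combined",
#                          "waitForBranches": ["..."]}]
#   }
# }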


class ActionType(Enum):
    """Supported action types in pipeline"""
    REDIS_SAVE_IMAGE = "redis_save_image"
    REDIS_PUBLISH = "redis_publish"
    POSTGRESQL_UPDATE = "postgresql_update"
    POSTGRESQL_UPDATE_COMBINED = "postgresql_update_combined"
    POSTGRESQL_INSERT = "postgresql_insert"


@dataclass
class RedisConfig:
    """Redis connection configuration"""
    host: str
    port: int = 6379
    password: Optional[str] = None
    db: int = 0

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'RedisConfig':
        return cls(
            host=data['host'],
            port=data.get('port', 6379),
            password=data.get('password'),
            db=data.get('db', 0)
        )


@dataclass
class PostgreSQLConfig:
    """PostgreSQL connection configuration"""
    host: str
    port: int
    database: str
    username: str
    password: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'PostgreSQLConfig':
        return cls(
            host=data['host'],
            port=data.get('port', 5432),
            database=data['database'],
            username=data['username'],
            password=data['password']
        )


@dataclass
class Action:
    """Represents an action in the pipeline"""
    type: ActionType
    params: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Action':
        action_type = ActionType(data['type'])
        params = {k: v for k, v in data.items() if k != 'type'}
        return cls(type=action_type, params=params)
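
# Illustrative example (the "channel" key is hypothetical): every key except
# "type" is carried over into params, so
#     Action.from_dict({"type": "redis_publish", "channel": "detections"})
# yields Action(type=ActionType.REDIS_PUBLISH, params={"channel": "detections"}).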


@dataclass
class ModelBranch:
    """Represents a branch in the pipeline with its own model"""
    model_id: str
    model_file: str
    trigger_classes: List[str]
    min_confidence: float = 0.5
    crop: bool = False
    crop_class: Optional[Any] = None  # Can be a string or a list
    parallel: bool = False
    actions: List[Action] = field(default_factory=list)
    branches: List['ModelBranch'] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelBranch':
        actions = [Action.from_dict(a) for a in data.get('actions', [])]
        branches = [cls.from_dict(b) for b in data.get('branches', [])]

        return cls(
            model_id=data['modelId'],
            model_file=data['modelFile'],
            trigger_classes=data.get('triggerClasses', []),
            min_confidence=data.get('minConfidence', 0.5),
            crop=data.get('crop', False),
            crop_class=data.get('cropClass'),
            parallel=data.get('parallel', False),
            actions=actions,
            branches=branches
        )


@dataclass
class TrackingConfig:
    """Configuration for the tracking phase"""
    model_id: str
    model_file: str
    trigger_classes: List[str]
    min_confidence: float = 0.6

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TrackingConfig':
        return cls(
            model_id=data['modelId'],
            model_file=data['modelFile'],
            trigger_classes=data.get('triggerClasses', []),
            min_confidence=data.get('minConfidence', 0.6)
        )


@dataclass
class PipelineConfig:
    """Main pipeline configuration"""
    model_id: str
    model_file: str
    trigger_classes: List[str]
    min_confidence: float = 0.5
    crop: bool = False
    branches: List[ModelBranch] = field(default_factory=list)
    parallel_actions: List[Action] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'PipelineConfig':
        branches = [ModelBranch.from_dict(b) for b in data.get('branches', [])]
        parallel_actions = [Action.from_dict(a) for a in data.get('parallelActions', [])]

        return cls(
            model_id=data['modelId'],
            model_file=data['modelFile'],
            trigger_classes=data.get('triggerClasses', []),
            min_confidence=data.get('minConfidence', 0.5),
            crop=data.get('crop', False),
            branches=branches,
            parallel_actions=parallel_actions
        )


class PipelineParser:
    """Parser for pipeline.json configuration files"""

    def __init__(self):
        self.redis_config: Optional[RedisConfig] = None
        self.postgresql_config: Optional[PostgreSQLConfig] = None
        self.tracking_config: Optional[TrackingConfig] = None
        self.pipeline_config: Optional[PipelineConfig] = None
        self._model_dependencies: Set[str] = set()

    def parse(self, config_path: Path) -> bool:
        """
        Parse a pipeline.json configuration file

        Args:
            config_path: Path to the pipeline.json file

        Returns:
            True if parsing was successful, False otherwise
        """
        try:
            if not config_path.exists():
                logger.error(f"Pipeline config not found: {config_path}")
                return False

            with open(config_path, 'r') as f:
                data = json.load(f)

            return self.parse_dict(data)

        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in pipeline config: {str(e)}")
            return False
        except Exception as e:
            logger.error(f"Failed to parse pipeline config: {str(e)}", exc_info=True)
            return False

    def parse_dict(self, data: Dict[str, Any]) -> bool:
        """
        Parse a pipeline configuration from a dictionary

        Args:
            data: The configuration dictionary

        Returns:
            True if parsing was successful, False otherwise
        """
        try:
            # Parse Redis configuration
            if 'redis' in data:
                self.redis_config = RedisConfig.from_dict(data['redis'])
                logger.debug(f"Parsed Redis config: {self.redis_config.host}:{self.redis_config.port}")

            # Parse PostgreSQL configuration
            if 'postgresql' in data:
                self.postgresql_config = PostgreSQLConfig.from_dict(data['postgresql'])
                logger.debug(
                    f"Parsed PostgreSQL config: {self.postgresql_config.host}:"
                    f"{self.postgresql_config.port}/{self.postgresql_config.database}"
                )

            # Parse tracking configuration
            if 'tracking' in data:
                self.tracking_config = TrackingConfig.from_dict(data['tracking'])
                self._model_dependencies.add(self.tracking_config.model_file)
                logger.debug(f"Parsed tracking config: {self.tracking_config.model_id}")

            # Parse main pipeline configuration
            if 'pipeline' in data:
                self.pipeline_config = PipelineConfig.from_dict(data['pipeline'])
                self._collect_model_dependencies(self.pipeline_config)
                logger.debug(f"Parsed pipeline config: {self.pipeline_config.model_id}")

            logger.info("Successfully parsed pipeline configuration")
            logger.debug(f"Model dependencies: {self._model_dependencies}")
            return True

        except KeyError as e:
            logger.error(f"Missing required field in pipeline config: {str(e)}")
            return False
        except Exception as e:
            logger.error(f"Failed to parse pipeline config: {str(e)}", exc_info=True)
            return False

    def _collect_model_dependencies(self, config: Any) -> None:
        """
        Recursively collect all model file dependencies

        Args:
            config: Pipeline or branch configuration
        """
        if hasattr(config, 'model_file'):
            self._model_dependencies.add(config.model_file)

        if hasattr(config, 'branches'):
            for branch in config.branches:
                self._collect_model_dependencies(branch)

    def get_model_dependencies(self) -> Set[str]:
        """
        Get all model file dependencies from the pipeline

        Returns:
            Set of model filenames required by the pipeline
        """
        return self._model_dependencies.copy()

    def validate(self) -> bool:
        """
        Validate the parsed configuration

        Returns:
            True if configuration is valid, False otherwise
        """
        if not self.pipeline_config:
            logger.error("No pipeline configuration found")
            return False

        # Check that all required model files are specified
        if not self.pipeline_config.model_file:
            logger.error("Main pipeline model file not specified")
            return False

        # Validate action configurations
        if not self._validate_actions(self.pipeline_config):
            return False

        # Validate parallel actions
        for action in self.pipeline_config.parallel_actions:
            if action.type == ActionType.POSTGRESQL_UPDATE_COMBINED:
                wait_for = action.params.get('waitForBranches', [])
                if wait_for:
                    # Check that referenced branches exist
                    branch_ids = self._get_all_branch_ids(self.pipeline_config)
                    for branch_id in wait_for:
                        if branch_id not in branch_ids:
                            logger.error(f"Referenced branch '{branch_id}' in waitForBranches not found")
                            return False

        logger.info("Pipeline configuration validated successfully")
        return True

    def _validate_actions(self, config: Any) -> bool:
        """
        Validate actions in a pipeline or branch configuration

        Args:
            config: Pipeline or branch configuration

        Returns:
            True if valid, False otherwise
        """
        if hasattr(config, 'actions'):
            for action in config.actions:
                # Redis actions require a Redis configuration
                if action.type in [ActionType.REDIS_SAVE_IMAGE, ActionType.REDIS_PUBLISH]:
                    if not self.redis_config:
                        logger.error(f"Action {action.type} requires Redis configuration")
                        return False

                # PostgreSQL actions require a PostgreSQL configuration
                if action.type in [ActionType.POSTGRESQL_UPDATE, ActionType.POSTGRESQL_UPDATE_COMBINED, ActionType.POSTGRESQL_INSERT]:
                    if not self.postgresql_config:
                        logger.error(f"Action {action.type} requires PostgreSQL configuration")
                        return False

        # Recursively validate branches
        if hasattr(config, 'branches'):
            for branch in config.branches:
                if not self._validate_actions(branch):
                    return False

        return True

    def _get_all_branch_ids(self, config: Any, branch_ids: Optional[Set[str]] = None) -> Set[str]:
        """
        Recursively collect all branch model IDs

        Args:
            config: Pipeline or branch configuration
            branch_ids: Set to collect IDs into

        Returns:
            Set of all branch model IDs
        """
        if branch_ids is None:
            branch_ids = set()

        if hasattr(config, 'branches'):
            for branch in config.branches:
                branch_ids.add(branch.model_id)
                self._get_all_branch_ids(branch, branch_ids)

        return branch_ids

    def get_redis_config(self) -> Optional[RedisConfig]:
        """Get the Redis configuration"""
        return self.redis_config

    def get_postgresql_config(self) -> Optional[PostgreSQLConfig]:
        """Get the PostgreSQL configuration"""
        return self.postgresql_config

    def get_tracking_config(self) -> Optional[TrackingConfig]:
        """Get the tracking configuration"""
        return self.tracking_config

    def get_pipeline_config(self) -> Optional[PipelineConfig]:
        """Get the main pipeline configuration"""
        return self.pipeline_config
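

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's API: the "pipeline.json"
    # path and the logging setup are illustrative assumptions.
    logging.basicConfig(level=logging.INFO)
    pipeline_parser = PipelineParser()
    if pipeline_parser.parse(Path("pipeline.json")) and pipeline_parser.validate():
        print(f"Required model files: {pipeline_parser.get_model_dependencies()}")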