Refactor: done phase 2
This commit is contained in:
		
							parent
							
								
									8222e82dd7
								
							
						
					
					
						commit
						aa10d5a55c
					
				
					 6 changed files with 1337 additions and 23 deletions
				
			
		| 
						 | 
				
			
			@ -166,33 +166,40 @@ core/
 | 
			
		|||
- ✅ **Backward Compatibility**: All existing endpoints preserved
 | 
			
		||||
- ✅ **Modern FastAPI**: Lifespan events, Pydantic v2 compatibility
 | 
			
		||||
 | 
			
		||||
## 📋 Phase 2: Pipeline Configuration & Model Management
 | 
			
		||||
## ✅ Phase 2: Pipeline Configuration & Model Management - COMPLETED
 | 
			
		||||
 | 
			
		||||
### 2.1 Models Module (`core/models/`)
 | 
			
		||||
- [ ] **Create `pipeline.py`** - Pipeline.json parser
 | 
			
		||||
  - [ ] Extract pipeline configuration parsing from `pympta.py`
 | 
			
		||||
  - [ ] Implement pipeline validation
 | 
			
		||||
  - [ ] Add configuration schema validation
 | 
			
		||||
  - [ ] Handle Redis and PostgreSQL configuration parsing
 | 
			
		||||
- ✅ **Create `pipeline.py`** - Pipeline.json parser
 | 
			
		||||
  - ✅ Extract pipeline configuration parsing from `pympta.py`
 | 
			
		||||
  - ✅ Implement pipeline validation
 | 
			
		||||
  - ✅ Add configuration schema validation
 | 
			
		||||
  - ✅ Handle Redis and PostgreSQL configuration parsing
 | 
			
		||||
 | 
			
		||||
- [ ] **Create `manager.py`** - MPTA download and model loading
 | 
			
		||||
  - [ ] Extract MPTA download logic from `pympta.py`
 | 
			
		||||
  - [ ] Implement ZIP extraction and validation
 | 
			
		||||
  - [ ] Add model file management and caching
 | 
			
		||||
  - [ ] Handle model loading with GPU optimization
 | 
			
		||||
  - [ ] Implement model dependency resolution
 | 
			
		||||
- ✅ **Create `manager.py`** - MPTA download and model loading
 | 
			
		||||
  - ✅ Extract MPTA download logic from `pympta.py`
 | 
			
		||||
  - ✅ Implement ZIP extraction and validation
 | 
			
		||||
  - ✅ Add model file management and caching
 | 
			
		||||
  - ✅ Handle model loading with GPU optimization
 | 
			
		||||
  - ✅ Implement model dependency resolution
 | 
			
		||||
 | 
			
		||||
- [ ] **Create `inference.py`** - YOLO model wrapper
 | 
			
		||||
  - [ ] Create unified YOLO model interface
 | 
			
		||||
  - [ ] Add inference optimization and caching
 | 
			
		||||
  - [ ] Implement batch processing capabilities
 | 
			
		||||
  - [ ] Handle model switching and memory management
 | 
			
		||||
- ✅ **Create `inference.py`** - YOLO model wrapper
 | 
			
		||||
  - ✅ Create unified YOLO model interface
 | 
			
		||||
  - ✅ Add inference optimization and caching
 | 
			
		||||
  - ✅ Implement batch processing capabilities
 | 
			
		||||
  - ✅ Handle model switching and memory management
 | 
			
		||||
 | 
			
		||||
### 2.2 Testing Phase 2
 | 
			
		||||
- [ ] Test MPTA file download and extraction
 | 
			
		||||
- [ ] Test pipeline.json parsing and validation
 | 
			
		||||
- [ ] Test model loading with different configurations
 | 
			
		||||
- [ ] Verify GPU optimization works correctly
 | 
			
		||||
- ✅ Test MPTA file download and extraction
 | 
			
		||||
- ✅ Test pipeline.json parsing and validation
 | 
			
		||||
- ✅ Test model loading with different configurations
 | 
			
		||||
- ✅ Verify GPU optimization works correctly
 | 
			
		||||
 | 
			
		||||
### 2.3 Phase 2 Results
 | 
			
		||||
- ✅ **ModelManager**: Downloads, extracts, and manages MPTA files with model ID-based directory structure
 | 
			
		||||
- ✅ **PipelineParser**: Parses and validates pipeline.json with full support for Redis, PostgreSQL, tracking, and branches
 | 
			
		||||
- ✅ **YOLOWrapper**: Unified interface for YOLO models with caching, tracking, and classification support
 | 
			
		||||
- ✅ **Model Caching**: Shared model cache across instances to optimize memory usage
 | 
			
		||||
- ✅ **Dependency Resolution**: Automatically identifies and tracks all model file dependencies
 | 
			
		||||
 | 
			
		||||
## 📋 Phase 3: Streaming System
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -17,6 +17,7 @@ from .models import (
 | 
			
		|||
    RequestStateMessage, PatchSessionResultMessage
 | 
			
		||||
)
 | 
			
		||||
from .state import worker_state, SystemMetrics
 | 
			
		||||
from ..models import ModelManager
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -24,6 +25,9 @@ logger = logging.getLogger(__name__)
 | 
			
		|||
HEARTBEAT_INTERVAL = 2.0  # seconds
 | 
			
		||||
WORKER_TIMEOUT_MS = 10000
 | 
			
		||||
 | 
			
		||||
# Global model manager instance
 | 
			
		||||
model_manager = ModelManager()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WebSocketHandler:
 | 
			
		||||
    """
 | 
			
		||||
| 
						 | 
				
			
			@ -184,7 +188,10 @@ class WebSocketHandler:
 | 
			
		|||
        # Update worker state with new subscriptions
 | 
			
		||||
        worker_state.set_subscriptions(message.subscriptions)
 | 
			
		||||
 | 
			
		||||
        # TODO: Phase 2 - Integrate with model management and streaming
 | 
			
		||||
        # Phase 2: Download and manage models
 | 
			
		||||
        await self._ensure_models(message.subscriptions)
 | 
			
		||||
 | 
			
		||||
        # TODO: Phase 3 - Integrate with streaming management
 | 
			
		||||
        # For now, just log the subscription changes
 | 
			
		||||
        for subscription in message.subscriptions:
 | 
			
		||||
            logger.info(f"  Subscription: {subscription.subscriptionIdentifier} -> "
 | 
			
		||||
| 
						 | 
				
			
			@ -198,6 +205,79 @@ class WebSocketHandler:
 | 
			
		|||
 | 
			
		||||
        logger.info("Subscription list updated successfully")
 | 
			
		||||
 | 
			
		||||
    async def _ensure_models(self, subscriptions) -> None:
 | 
			
		||||
        """Ensure all required models are downloaded and available."""
 | 
			
		||||
        # Extract unique model requirements
 | 
			
		||||
        unique_models = {}
 | 
			
		||||
        for subscription in subscriptions:
 | 
			
		||||
            model_id = subscription.modelId
 | 
			
		||||
            if model_id not in unique_models:
 | 
			
		||||
                unique_models[model_id] = {
 | 
			
		||||
                    'model_url': subscription.modelUrl,
 | 
			
		||||
                    'model_name': subscription.modelName
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
        logger.info(f"[Model Management] Processing {len(unique_models)} unique models: {list(unique_models.keys())}")
 | 
			
		||||
 | 
			
		||||
        # Check and download models concurrently
 | 
			
		||||
        download_tasks = []
 | 
			
		||||
        for model_id, model_info in unique_models.items():
 | 
			
		||||
            task = asyncio.create_task(
 | 
			
		||||
                self._ensure_single_model(model_id, model_info['model_url'], model_info['model_name'])
 | 
			
		||||
            )
 | 
			
		||||
            download_tasks.append(task)
 | 
			
		||||
 | 
			
		||||
        # Wait for all downloads to complete
 | 
			
		||||
        if download_tasks:
 | 
			
		||||
            results = await asyncio.gather(*download_tasks, return_exceptions=True)
 | 
			
		||||
 | 
			
		||||
            # Log results
 | 
			
		||||
            success_count = 0
 | 
			
		||||
            for i, result in enumerate(results):
 | 
			
		||||
                model_id = list(unique_models.keys())[i]
 | 
			
		||||
                if isinstance(result, Exception):
 | 
			
		||||
                    logger.error(f"[Model Management] Failed to ensure model {model_id}: {result}")
 | 
			
		||||
                elif result:
 | 
			
		||||
                    success_count += 1
 | 
			
		||||
                    logger.info(f"[Model Management] Model {model_id} ready for use")
 | 
			
		||||
                else:
 | 
			
		||||
                    logger.error(f"[Model Management] Failed to ensure model {model_id}")
 | 
			
		||||
 | 
			
		||||
            logger.info(f"[Model Management] Successfully ensured {success_count}/{len(unique_models)} models")
 | 
			
		||||
 | 
			
		||||
    async def _ensure_single_model(self, model_id: int, model_url: str, model_name: str) -> bool:
 | 
			
		||||
        """Ensure a single model is downloaded and available."""
 | 
			
		||||
        try:
 | 
			
		||||
            # Check if model is already available
 | 
			
		||||
            if model_manager.is_model_downloaded(model_id):
 | 
			
		||||
                logger.info(f"[Model Management] Model {model_id} ({model_name}) already available")
 | 
			
		||||
                return True
 | 
			
		||||
 | 
			
		||||
            # Download and extract model in a thread pool to avoid blocking the event loop
 | 
			
		||||
            logger.info(f"[Model Management] Downloading model {model_id} ({model_name}) from {model_url}")
 | 
			
		||||
 | 
			
		||||
            # Use asyncio.to_thread for CPU-bound operations (Python 3.9+)
 | 
			
		||||
            # For compatibility, we'll use run_in_executor
 | 
			
		||||
            loop = asyncio.get_event_loop()
 | 
			
		||||
            model_path = await loop.run_in_executor(
 | 
			
		||||
                None,
 | 
			
		||||
                model_manager.ensure_model,
 | 
			
		||||
                model_id,
 | 
			
		||||
                model_url,
 | 
			
		||||
                model_name
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if model_path:
 | 
			
		||||
                logger.info(f"[Model Management] Successfully prepared model {model_id} at {model_path}")
 | 
			
		||||
                return True
 | 
			
		||||
            else:
 | 
			
		||||
                logger.error(f"[Model Management] Failed to prepare model {model_id}")
 | 
			
		||||
                return False
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"[Model Management] Exception ensuring model {model_id}: {str(e)}", exc_info=True)
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
    async def _handle_set_session_id(self, message: SetSessionIdMessage) -> None:
 | 
			
		||||
        """Handle setSessionId message."""
 | 
			
		||||
        display_identifier = message.payload.displayIdentifier
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1 +1,42 @@
 | 
			
		|||
# Models module for MPTA management and pipeline configuration
 | 
			
		||||
"""
 | 
			
		||||
Models Module - MPTA management, pipeline configuration, and YOLO inference
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
from .manager import ModelManager
 | 
			
		||||
from .pipeline import (
 | 
			
		||||
    PipelineParser,
 | 
			
		||||
    PipelineConfig,
 | 
			
		||||
    TrackingConfig,
 | 
			
		||||
    ModelBranch,
 | 
			
		||||
    Action,
 | 
			
		||||
    ActionType,
 | 
			
		||||
    RedisConfig,
 | 
			
		||||
    PostgreSQLConfig
 | 
			
		||||
)
 | 
			
		||||
from .inference import (
 | 
			
		||||
    YOLOWrapper,
 | 
			
		||||
    ModelInferenceManager,
 | 
			
		||||
    Detection,
 | 
			
		||||
    InferenceResult
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    # Manager
 | 
			
		||||
    'ModelManager',
 | 
			
		||||
 | 
			
		||||
    # Pipeline
 | 
			
		||||
    'PipelineParser',
 | 
			
		||||
    'PipelineConfig',
 | 
			
		||||
    'TrackingConfig',
 | 
			
		||||
    'ModelBranch',
 | 
			
		||||
    'Action',
 | 
			
		||||
    'ActionType',
 | 
			
		||||
    'RedisConfig',
 | 
			
		||||
    'PostgreSQLConfig',
 | 
			
		||||
 | 
			
		||||
    # Inference
 | 
			
		||||
    'YOLOWrapper',
 | 
			
		||||
    'ModelInferenceManager',
 | 
			
		||||
    'Detection',
 | 
			
		||||
    'InferenceResult',
 | 
			
		||||
]
 | 
			
		||||
							
								
								
									
										468
									
								
								core/models/inference.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										468
									
								
								core/models/inference.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,468 @@
 | 
			
		|||
"""
 | 
			
		||||
YOLO Model Inference Wrapper - Handles model loading and inference optimization
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
import torch
 | 
			
		||||
import numpy as np
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Dict, List, Optional, Any, Tuple, Union
 | 
			
		||||
from threading import Lock
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
import cv2
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class Detection:
 | 
			
		||||
    """Represents a single detection result"""
 | 
			
		||||
    bbox: List[float]  # [x1, y1, x2, y2]
 | 
			
		||||
    confidence: float
 | 
			
		||||
    class_id: int
 | 
			
		||||
    class_name: str
 | 
			
		||||
    track_id: Optional[int] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class InferenceResult:
 | 
			
		||||
    """Result from model inference"""
 | 
			
		||||
    detections: List[Detection]
 | 
			
		||||
    image_shape: Tuple[int, int]  # (height, width)
 | 
			
		||||
    inference_time: float
 | 
			
		||||
    model_id: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class YOLOWrapper:
 | 
			
		||||
    """Wrapper for YOLO models with caching and optimization"""
 | 
			
		||||
 | 
			
		||||
    # Class-level model cache shared across all instances
 | 
			
		||||
    _model_cache: Dict[str, Any] = {}
 | 
			
		||||
    _cache_lock = Lock()
 | 
			
		||||
 | 
			
		||||
    def __init__(self, model_path: Path, model_id: str, device: Optional[str] = None):
 | 
			
		||||
        """
 | 
			
		||||
        Initialize YOLO wrapper
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            model_path: Path to the .pt model file
 | 
			
		||||
            model_id: Unique identifier for the model
 | 
			
		||||
            device: Device to run inference on ('cuda', 'cpu', or None for auto)
 | 
			
		||||
        """
 | 
			
		||||
        self.model_path = model_path
 | 
			
		||||
        self.model_id = model_id
 | 
			
		||||
 | 
			
		||||
        # Auto-detect device if not specified
 | 
			
		||||
        if device is None:
 | 
			
		||||
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
 | 
			
		||||
        else:
 | 
			
		||||
            self.device = device
 | 
			
		||||
 | 
			
		||||
        self.model = None
 | 
			
		||||
        self._class_names = []
 | 
			
		||||
        self._load_model()
 | 
			
		||||
 | 
			
		||||
        logger.info(f"Initialized YOLO wrapper for {model_id} on {self.device}")
 | 
			
		||||
 | 
			
		||||
    def _load_model(self) -> None:
 | 
			
		||||
        """Load the YOLO model with caching"""
 | 
			
		||||
        cache_key = str(self.model_path)
 | 
			
		||||
 | 
			
		||||
        with self._cache_lock:
 | 
			
		||||
            # Check if model is already cached
 | 
			
		||||
            if cache_key in self._model_cache:
 | 
			
		||||
                logger.info(f"Loading model {self.model_id} from cache")
 | 
			
		||||
                self.model = self._model_cache[cache_key]
 | 
			
		||||
                self._extract_class_names()
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
            # Load model
 | 
			
		||||
            try:
 | 
			
		||||
                from ultralytics import YOLO
 | 
			
		||||
 | 
			
		||||
                logger.info(f"Loading YOLO model from {self.model_path}")
 | 
			
		||||
                self.model = YOLO(str(self.model_path))
 | 
			
		||||
 | 
			
		||||
                # Move model to device
 | 
			
		||||
                if self.device == 'cuda' and torch.cuda.is_available():
 | 
			
		||||
                    self.model.to('cuda')
 | 
			
		||||
                    logger.info(f"Model {self.model_id} moved to GPU")
 | 
			
		||||
 | 
			
		||||
                # Cache the model
 | 
			
		||||
                self._model_cache[cache_key] = self.model
 | 
			
		||||
                self._extract_class_names()
 | 
			
		||||
 | 
			
		||||
                logger.info(f"Successfully loaded model {self.model_id}")
 | 
			
		||||
 | 
			
		||||
            except ImportError:
 | 
			
		||||
                logger.error("Ultralytics YOLO not installed. Install with: pip install ultralytics")
 | 
			
		||||
                raise
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                logger.error(f"Failed to load YOLO model {self.model_id}: {str(e)}", exc_info=True)
 | 
			
		||||
                raise
 | 
			
		||||
 | 
			
		||||
    def _extract_class_names(self) -> None:
 | 
			
		||||
        """Extract class names from the model"""
 | 
			
		||||
        try:
 | 
			
		||||
            if hasattr(self.model, 'names'):
 | 
			
		||||
                self._class_names = self.model.names
 | 
			
		||||
            elif hasattr(self.model, 'model') and hasattr(self.model.model, 'names'):
 | 
			
		||||
                self._class_names = self.model.model.names
 | 
			
		||||
            else:
 | 
			
		||||
                logger.warning(f"Could not extract class names from model {self.model_id}")
 | 
			
		||||
                self._class_names = {}
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Failed to extract class names: {str(e)}")
 | 
			
		||||
            self._class_names = {}
 | 
			
		||||
 | 
			
		||||
    def infer(
 | 
			
		||||
        self,
 | 
			
		||||
        image: np.ndarray,
 | 
			
		||||
        confidence_threshold: float = 0.5,
 | 
			
		||||
        trigger_classes: Optional[List[str]] = None,
 | 
			
		||||
        iou_threshold: float = 0.45
 | 
			
		||||
    ) -> InferenceResult:
 | 
			
		||||
        """
 | 
			
		||||
        Run inference on an image
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            image: Input image as numpy array (BGR format)
 | 
			
		||||
            confidence_threshold: Minimum confidence for detections
 | 
			
		||||
            trigger_classes: List of class names to filter (None = all classes)
 | 
			
		||||
            iou_threshold: IoU threshold for NMS
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            InferenceResult containing detections
 | 
			
		||||
        """
 | 
			
		||||
        if self.model is None:
 | 
			
		||||
            raise RuntimeError(f"Model {self.model_id} not loaded")
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            import time
 | 
			
		||||
            start_time = time.time()
 | 
			
		||||
 | 
			
		||||
            # Run inference
 | 
			
		||||
            results = self.model(
 | 
			
		||||
                image,
 | 
			
		||||
                conf=confidence_threshold,
 | 
			
		||||
                iou=iou_threshold,
 | 
			
		||||
                verbose=False
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            inference_time = time.time() - start_time
 | 
			
		||||
 | 
			
		||||
            # Parse results
 | 
			
		||||
            detections = self._parse_results(results[0], trigger_classes)
 | 
			
		||||
 | 
			
		||||
            return InferenceResult(
 | 
			
		||||
                detections=detections,
 | 
			
		||||
                image_shape=(image.shape[0], image.shape[1]),
 | 
			
		||||
                inference_time=inference_time,
 | 
			
		||||
                model_id=self.model_id
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Inference failed for model {self.model_id}: {str(e)}", exc_info=True)
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    def _parse_results(
 | 
			
		||||
        self,
 | 
			
		||||
        result: Any,
 | 
			
		||||
        trigger_classes: Optional[List[str]] = None
 | 
			
		||||
    ) -> List[Detection]:
 | 
			
		||||
        """
 | 
			
		||||
        Parse YOLO results into Detection objects
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            result: YOLO result object
 | 
			
		||||
            trigger_classes: Optional list of class names to filter
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            List of Detection objects
 | 
			
		||||
        """
 | 
			
		||||
        detections = []
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            if result.boxes is None:
 | 
			
		||||
                return detections
 | 
			
		||||
 | 
			
		||||
            boxes = result.boxes
 | 
			
		||||
            for i in range(len(boxes)):
 | 
			
		||||
                # Get box coordinates
 | 
			
		||||
                box = boxes.xyxy[i].cpu().numpy()
 | 
			
		||||
                x1, y1, x2, y2 = box
 | 
			
		||||
 | 
			
		||||
                # Get confidence and class
 | 
			
		||||
                conf = float(boxes.conf[i])
 | 
			
		||||
                cls_id = int(boxes.cls[i])
 | 
			
		||||
 | 
			
		||||
                # Get class name
 | 
			
		||||
                class_name = self._class_names.get(cls_id, f"class_{cls_id}")
 | 
			
		||||
 | 
			
		||||
                # Filter by trigger classes if specified
 | 
			
		||||
                if trigger_classes and class_name not in trigger_classes:
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                # Get track ID if available
 | 
			
		||||
                track_id = None
 | 
			
		||||
                if hasattr(boxes, 'id') and boxes.id is not None:
 | 
			
		||||
                    track_id = int(boxes.id[i])
 | 
			
		||||
 | 
			
		||||
                detection = Detection(
 | 
			
		||||
                    bbox=[float(x1), float(y1), float(x2), float(y2)],
 | 
			
		||||
                    confidence=conf,
 | 
			
		||||
                    class_id=cls_id,
 | 
			
		||||
                    class_name=class_name,
 | 
			
		||||
                    track_id=track_id
 | 
			
		||||
                )
 | 
			
		||||
                detections.append(detection)
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Failed to parse results: {str(e)}", exc_info=True)
 | 
			
		||||
 | 
			
		||||
        return detections
 | 
			
		||||
 | 
			
		||||
    def track(
 | 
			
		||||
        self,
 | 
			
		||||
        image: np.ndarray,
 | 
			
		||||
        confidence_threshold: float = 0.5,
 | 
			
		||||
        trigger_classes: Optional[List[str]] = None,
 | 
			
		||||
        persist: bool = True
 | 
			
		||||
    ) -> InferenceResult:
 | 
			
		||||
        """
 | 
			
		||||
        Run tracking on an image
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            image: Input image as numpy array (BGR format)
 | 
			
		||||
            confidence_threshold: Minimum confidence for detections
 | 
			
		||||
            trigger_classes: List of class names to filter
 | 
			
		||||
            persist: Whether to persist tracks across frames
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            InferenceResult containing detections with track IDs
 | 
			
		||||
        """
 | 
			
		||||
        if self.model is None:
 | 
			
		||||
            raise RuntimeError(f"Model {self.model_id} not loaded")
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            import time
 | 
			
		||||
            start_time = time.time()
 | 
			
		||||
 | 
			
		||||
            # Run tracking
 | 
			
		||||
            results = self.model.track(
 | 
			
		||||
                image,
 | 
			
		||||
                conf=confidence_threshold,
 | 
			
		||||
                persist=persist,
 | 
			
		||||
                verbose=False
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            inference_time = time.time() - start_time
 | 
			
		||||
 | 
			
		||||
            # Parse results
 | 
			
		||||
            detections = self._parse_results(results[0], trigger_classes)
 | 
			
		||||
 | 
			
		||||
            return InferenceResult(
 | 
			
		||||
                detections=detections,
 | 
			
		||||
                image_shape=(image.shape[0], image.shape[1]),
 | 
			
		||||
                inference_time=inference_time,
 | 
			
		||||
                model_id=self.model_id
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Tracking failed for model {self.model_id}: {str(e)}", exc_info=True)
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    def predict_classification(
 | 
			
		||||
        self,
 | 
			
		||||
        image: np.ndarray,
 | 
			
		||||
        top_k: int = 1
 | 
			
		||||
    ) -> Dict[str, float]:
 | 
			
		||||
        """
 | 
			
		||||
        Run classification on an image
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            image: Input image as numpy array (BGR format)
 | 
			
		||||
            top_k: Number of top predictions to return
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Dictionary of class_name -> confidence scores
 | 
			
		||||
        """
 | 
			
		||||
        if self.model is None:
 | 
			
		||||
            raise RuntimeError(f"Model {self.model_id} not loaded")
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            # Run inference
 | 
			
		||||
            results = self.model(image, verbose=False)
 | 
			
		||||
 | 
			
		||||
            # For classification models, extract probabilities
 | 
			
		||||
            if hasattr(results[0], 'probs'):
 | 
			
		||||
                probs = results[0].probs
 | 
			
		||||
                top_indices = probs.top5[:top_k]
 | 
			
		||||
                top_conf = probs.top5conf[:top_k].cpu().numpy()
 | 
			
		||||
 | 
			
		||||
                predictions = {}
 | 
			
		||||
                for idx, conf in zip(top_indices, top_conf):
 | 
			
		||||
                    class_name = self._class_names.get(int(idx), f"class_{idx}")
 | 
			
		||||
                    predictions[class_name] = float(conf)
 | 
			
		||||
 | 
			
		||||
                return predictions
 | 
			
		||||
            else:
 | 
			
		||||
                logger.warning(f"Model {self.model_id} does not support classification")
 | 
			
		||||
                return {}
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Classification failed for model {self.model_id}: {str(e)}", exc_info=True)
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    def crop_detection(
 | 
			
		||||
        self,
 | 
			
		||||
        image: np.ndarray,
 | 
			
		||||
        detection: Detection,
 | 
			
		||||
        padding: int = 0
 | 
			
		||||
    ) -> np.ndarray:
 | 
			
		||||
        """
 | 
			
		||||
        Crop image to detection bounding box
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            image: Original image
 | 
			
		||||
            detection: Detection to crop
 | 
			
		||||
            padding: Additional padding around the box
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Cropped image region
 | 
			
		||||
        """
 | 
			
		||||
        h, w = image.shape[:2]
 | 
			
		||||
        x1, y1, x2, y2 = detection.bbox
 | 
			
		||||
 | 
			
		||||
        # Add padding and clip to image boundaries
 | 
			
		||||
        x1 = max(0, int(x1) - padding)
 | 
			
		||||
        y1 = max(0, int(y1) - padding)
 | 
			
		||||
        x2 = min(w, int(x2) + padding)
 | 
			
		||||
        y2 = min(h, int(y2) + padding)
 | 
			
		||||
 | 
			
		||||
        return image[y1:y2, x1:x2]
 | 
			
		||||
 | 
			
		||||
    def get_class_names(self) -> Dict[int, str]:
 | 
			
		||||
        """Get the class names dictionary"""
 | 
			
		||||
        return self._class_names.copy()
 | 
			
		||||
 | 
			
		||||
    def get_num_classes(self) -> int:
 | 
			
		||||
        """Get the number of classes the model can detect"""
 | 
			
		||||
        return len(self._class_names)
 | 
			
		||||
 | 
			
		||||
    def clear_cache(self) -> None:
 | 
			
		||||
        """Clear the model cache"""
 | 
			
		||||
        with self._cache_lock:
 | 
			
		||||
            cache_key = str(self.model_path)
 | 
			
		||||
            if cache_key in self._model_cache:
 | 
			
		||||
                del self._model_cache[cache_key]
 | 
			
		||||
                logger.info(f"Cleared cache for model {self.model_id}")
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def clear_all_cache(cls) -> None:
 | 
			
		||||
        """Clear all cached models"""
 | 
			
		||||
        with cls._cache_lock:
 | 
			
		||||
            cls._model_cache.clear()
 | 
			
		||||
            logger.info("Cleared all model cache")
 | 
			
		||||
 | 
			
		||||
    def warmup(self, image_size: Tuple[int, int] = (640, 640)) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Warmup the model with a dummy inference
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            image_size: Size of dummy image (height, width)
 | 
			
		||||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            dummy_image = np.zeros((image_size[0], image_size[1], 3), dtype=np.uint8)
 | 
			
		||||
            self.infer(dummy_image, confidence_threshold=0.5)
 | 
			
		||||
            logger.info(f"Model {self.model_id} warmed up")
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.warning(f"Failed to warmup model {self.model_id}: {str(e)}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelInferenceManager:
 | 
			
		||||
    """Manages multiple YOLO models for a pipeline"""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, model_dir: Path):
 | 
			
		||||
        """
 | 
			
		||||
        Initialize the inference manager
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            model_dir: Directory containing model files
 | 
			
		||||
        """
 | 
			
		||||
        self.model_dir = model_dir
 | 
			
		||||
        self.models: Dict[str, YOLOWrapper] = {}
 | 
			
		||||
        self._lock = Lock()
 | 
			
		||||
 | 
			
		||||
        logger.info(f"Initialized ModelInferenceManager with model directory: {model_dir}")
 | 
			
		||||
 | 
			
		||||
    def load_model(
 | 
			
		||||
        self,
 | 
			
		||||
        model_id: str,
 | 
			
		||||
        model_file: str,
 | 
			
		||||
        device: Optional[str] = None
 | 
			
		||||
    ) -> YOLOWrapper:
 | 
			
		||||
        """
 | 
			
		||||
        Load a model for inference
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            model_id: Unique identifier for the model
 | 
			
		||||
            model_file: Filename of the model
 | 
			
		||||
            device: Device to run on
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            YOLOWrapper instance
 | 
			
		||||
        """
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            # Check if already loaded
 | 
			
		||||
            if model_id in self.models:
 | 
			
		||||
                logger.debug(f"Model {model_id} already loaded")
 | 
			
		||||
                return self.models[model_id]
 | 
			
		||||
 | 
			
		||||
            # Load the model
 | 
			
		||||
            model_path = self.model_dir / model_file
 | 
			
		||||
            if not model_path.exists():
 | 
			
		||||
                raise FileNotFoundError(f"Model file not found: {model_path}")
 | 
			
		||||
 | 
			
		||||
            wrapper = YOLOWrapper(model_path, model_id, device)
 | 
			
		||||
            self.models[model_id] = wrapper
 | 
			
		||||
 | 
			
		||||
            return wrapper
 | 
			
		||||
 | 
			
		||||
    def get_model(self, model_id: str) -> Optional[YOLOWrapper]:
 | 
			
		||||
        """
 | 
			
		||||
        Get a loaded model
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            model_id: Model identifier
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            YOLOWrapper instance or None if not loaded
 | 
			
		||||
        """
 | 
			
		||||
        return self.models.get(model_id)
 | 
			
		||||
 | 
			
		||||
    def unload_model(self, model_id: str) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Unload a model to free memory
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            model_id: Model identifier
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            True if unloaded, False if not found
 | 
			
		||||
        """
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            if model_id in self.models:
 | 
			
		||||
                self.models[model_id].clear_cache()
 | 
			
		||||
                del self.models[model_id]
 | 
			
		||||
                logger.info(f"Unloaded model {model_id}")
 | 
			
		||||
                return True
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
    def unload_all(self) -> None:
 | 
			
		||||
        """Unload all models"""
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            for model_id in list(self.models.keys()):
 | 
			
		||||
                self.models[model_id].clear_cache()
 | 
			
		||||
            self.models.clear()
 | 
			
		||||
            logger.info("Unloaded all models")
 | 
			
		||||
							
								
								
									
										361
									
								
								core/models/manager.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										361
									
								
								core/models/manager.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,361 @@
 | 
			
		|||
"""
 | 
			
		||||
Model Manager Module - Handles MPTA download, extraction, and model loading
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import logging
 | 
			
		||||
import zipfile
 | 
			
		||||
import json
 | 
			
		||||
import hashlib
 | 
			
		||||
import requests
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Dict, Optional, Any, Set
 | 
			
		||||
from threading import Lock
 | 
			
		||||
from urllib.parse import urlparse, parse_qs
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelManager:
 | 
			
		||||
    """Manages MPTA model downloads, extraction, and caching"""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, models_dir: str = "models"):
 | 
			
		||||
        """
 | 
			
		||||
        Initialize the Model Manager
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            models_dir: Base directory for storing models
 | 
			
		||||
        """
 | 
			
		||||
        self.models_dir = Path(models_dir)
 | 
			
		||||
        self.models_dir.mkdir(parents=True, exist_ok=True)
 | 
			
		||||
 | 
			
		||||
        # Track downloaded models to avoid duplicates
 | 
			
		||||
        self._downloaded_models: Set[int] = set()
 | 
			
		||||
        self._model_paths: Dict[int, Path] = {}
 | 
			
		||||
        self._download_lock = Lock()
 | 
			
		||||
 | 
			
		||||
        # Scan existing models
 | 
			
		||||
        self._scan_existing_models()
 | 
			
		||||
 | 
			
		||||
        logger.info(f"ModelManager initialized with models directory: {self.models_dir}")
 | 
			
		||||
        logger.info(f"Found existing models: {list(self._downloaded_models)}")
 | 
			
		||||
 | 
			
		||||
    def _scan_existing_models(self) -> None:
 | 
			
		||||
        """Scan the models directory for existing downloaded models"""
 | 
			
		||||
        if not self.models_dir.exists():
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        for model_dir in self.models_dir.iterdir():
 | 
			
		||||
            if model_dir.is_dir() and model_dir.name.isdigit():
 | 
			
		||||
                model_id = int(model_dir.name)
 | 
			
		||||
                # Check if extraction was successful by looking for pipeline.json
 | 
			
		||||
                extracted_dirs = list(model_dir.glob("*/pipeline.json"))
 | 
			
		||||
                if extracted_dirs:
 | 
			
		||||
                    self._downloaded_models.add(model_id)
 | 
			
		||||
                    # Store path to the extracted model directory
 | 
			
		||||
                    self._model_paths[model_id] = extracted_dirs[0].parent
 | 
			
		||||
                    logger.debug(f"Found existing model {model_id} at {extracted_dirs[0].parent}")
 | 
			
		||||
 | 
			
		||||
    def get_model_path(self, model_id: int) -> Optional[Path]:
 | 
			
		||||
        """
 | 
			
		||||
        Get the path to an extracted model directory
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            model_id: The model ID
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Path to the extracted model directory or None if not found
 | 
			
		||||
        """
 | 
			
		||||
        return self._model_paths.get(model_id)
 | 
			
		||||
 | 
			
		||||
    def is_model_downloaded(self, model_id: int) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Check if a model has already been downloaded and extracted
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            model_id: The model ID to check
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            True if the model is already available
 | 
			
		||||
        """
 | 
			
		||||
        return model_id in self._downloaded_models
 | 
			
		||||
 | 
			
		||||
    def ensure_model(self, model_id: int, model_url: str, model_name: str = None) -> Optional[Path]:
 | 
			
		||||
        """
 | 
			
		||||
        Ensure a model is downloaded and extracted, downloading if necessary
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            model_id: The model ID
 | 
			
		||||
            model_url: URL to download the MPTA file from
 | 
			
		||||
            model_name: Optional model name for logging
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Path to the extracted model directory or None if failed
 | 
			
		||||
        """
 | 
			
		||||
        # Check if already downloaded
 | 
			
		||||
        if self.is_model_downloaded(model_id):
 | 
			
		||||
            logger.info(f"Model {model_id} already available at {self._model_paths[model_id]}")
 | 
			
		||||
            return self._model_paths[model_id]
 | 
			
		||||
 | 
			
		||||
        # Download and extract with lock to prevent concurrent downloads of same model
 | 
			
		||||
        with self._download_lock:
 | 
			
		||||
            # Double-check after acquiring lock
 | 
			
		||||
            if self.is_model_downloaded(model_id):
 | 
			
		||||
                return self._model_paths[model_id]
 | 
			
		||||
 | 
			
		||||
            logger.info(f"Model {model_id} not found locally, downloading from {model_url}")
 | 
			
		||||
 | 
			
		||||
            # Create model directory
 | 
			
		||||
            model_dir = self.models_dir / str(model_id)
 | 
			
		||||
            model_dir.mkdir(parents=True, exist_ok=True)
 | 
			
		||||
 | 
			
		||||
            # Extract filename from URL
 | 
			
		||||
            mpta_filename = self._extract_filename_from_url(model_url, model_name, model_id)
 | 
			
		||||
            mpta_path = model_dir / mpta_filename
 | 
			
		||||
 | 
			
		||||
            # Download MPTA file
 | 
			
		||||
            if not self._download_mpta(model_url, mpta_path):
 | 
			
		||||
                logger.error(f"Failed to download model {model_id}")
 | 
			
		||||
                return None
 | 
			
		||||
 | 
			
		||||
            # Extract MPTA file
 | 
			
		||||
            extracted_path = self._extract_mpta(mpta_path, model_dir)
 | 
			
		||||
            if not extracted_path:
 | 
			
		||||
                logger.error(f"Failed to extract model {model_id}")
 | 
			
		||||
                return None
 | 
			
		||||
 | 
			
		||||
            # Mark as downloaded and store path
 | 
			
		||||
            self._downloaded_models.add(model_id)
 | 
			
		||||
            self._model_paths[model_id] = extracted_path
 | 
			
		||||
 | 
			
		||||
            logger.info(f"Successfully prepared model {model_id} at {extracted_path}")
 | 
			
		||||
            return extracted_path
 | 
			
		||||
 | 
			
		||||
    def _extract_filename_from_url(self, url: str, model_name: str = None, model_id: int = None) -> str:
 | 
			
		||||
        """
 | 
			
		||||
        Extract a suitable filename from the URL
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            url: The URL to extract filename from
 | 
			
		||||
            model_name: Optional model name
 | 
			
		||||
            model_id: Optional model ID
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A suitable filename for the MPTA file
 | 
			
		||||
        """
 | 
			
		||||
        parsed = urlparse(url)
 | 
			
		||||
        path = parsed.path
 | 
			
		||||
 | 
			
		||||
        # Try to get filename from path
 | 
			
		||||
        if path:
 | 
			
		||||
            filename = os.path.basename(path)
 | 
			
		||||
            if filename and filename.endswith('.mpta'):
 | 
			
		||||
                return filename
 | 
			
		||||
 | 
			
		||||
        # Fallback to constructed name
 | 
			
		||||
        if model_name:
 | 
			
		||||
            return f"{model_name}-{model_id}.mpta"
 | 
			
		||||
        else:
 | 
			
		||||
            return f"model-{model_id}.mpta"
 | 
			
		||||
 | 
			
		||||
    def _download_mpta(self, url: str, dest_path: Path) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Download an MPTA file from a URL
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            url: URL to download from
 | 
			
		||||
            dest_path: Destination path for the file
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            True if successful, False otherwise
 | 
			
		||||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            logger.info(f"Starting download of model from {url}")
 | 
			
		||||
            logger.debug(f"Download destination: {dest_path}")
 | 
			
		||||
 | 
			
		||||
            response = requests.get(url, stream=True, timeout=300)
 | 
			
		||||
            if response.status_code != 200:
 | 
			
		||||
                logger.error(f"Failed to download MPTA file (status {response.status_code})")
 | 
			
		||||
                return False
 | 
			
		||||
 | 
			
		||||
            file_size = int(response.headers.get('content-length', 0))
 | 
			
		||||
            logger.info(f"Model file size: {file_size/1024/1024:.2f} MB")
 | 
			
		||||
 | 
			
		||||
            downloaded = 0
 | 
			
		||||
            last_log_percent = 0
 | 
			
		||||
 | 
			
		||||
            with open(dest_path, 'wb') as f:
 | 
			
		||||
                for chunk in response.iter_content(chunk_size=8192):
 | 
			
		||||
                    if chunk:
 | 
			
		||||
                        f.write(chunk)
 | 
			
		||||
                        downloaded += len(chunk)
 | 
			
		||||
 | 
			
		||||
                        # Log progress every 10%
 | 
			
		||||
                        if file_size > 0:
 | 
			
		||||
                            percent = int(downloaded * 100 / file_size)
 | 
			
		||||
                            if percent >= last_log_percent + 10:
 | 
			
		||||
                                logger.debug(f"Download progress: {percent}%")
 | 
			
		||||
                                last_log_percent = percent
 | 
			
		||||
 | 
			
		||||
            logger.info(f"Successfully downloaded MPTA file to {dest_path}")
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        except requests.RequestException as e:
 | 
			
		||||
            logger.error(f"Network error downloading MPTA: {str(e)}", exc_info=True)
 | 
			
		||||
            # Clean up partial download
 | 
			
		||||
            if dest_path.exists():
 | 
			
		||||
                dest_path.unlink()
 | 
			
		||||
            return False
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Unexpected error downloading MPTA: {str(e)}", exc_info=True)
 | 
			
		||||
            # Clean up partial download
 | 
			
		||||
            if dest_path.exists():
 | 
			
		||||
                dest_path.unlink()
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
    def _extract_mpta(self, mpta_path: Path, target_dir: Path) -> Optional[Path]:
 | 
			
		||||
        """
 | 
			
		||||
        Extract an MPTA (ZIP) file to the target directory
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            mpta_path: Path to the MPTA file
 | 
			
		||||
            target_dir: Directory to extract to
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Path to the extracted model directory containing pipeline.json, or None if failed
 | 
			
		||||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            if not mpta_path.exists():
 | 
			
		||||
                logger.error(f"MPTA file not found: {mpta_path}")
 | 
			
		||||
                return None
 | 
			
		||||
 | 
			
		||||
            logger.info(f"Extracting MPTA file from {mpta_path} to {target_dir}")
 | 
			
		||||
 | 
			
		||||
            with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
 | 
			
		||||
                # Get list of files
 | 
			
		||||
                file_list = zip_ref.namelist()
 | 
			
		||||
                logger.debug(f"Files in MPTA archive: {len(file_list)} files")
 | 
			
		||||
 | 
			
		||||
                # Extract all files
 | 
			
		||||
                zip_ref.extractall(target_dir)
 | 
			
		||||
 | 
			
		||||
            logger.info(f"Successfully extracted MPTA file to {target_dir}")
 | 
			
		||||
 | 
			
		||||
            # Find the directory containing pipeline.json
 | 
			
		||||
            pipeline_files = list(target_dir.glob("*/pipeline.json"))
 | 
			
		||||
            if not pipeline_files:
 | 
			
		||||
                # Check if pipeline.json is in root
 | 
			
		||||
                if (target_dir / "pipeline.json").exists():
 | 
			
		||||
                    logger.info(f"Found pipeline.json in root of {target_dir}")
 | 
			
		||||
                    return target_dir
 | 
			
		||||
                logger.error(f"No pipeline.json found after extraction in {target_dir}")
 | 
			
		||||
                return None
 | 
			
		||||
 | 
			
		||||
            # Return the directory containing pipeline.json
 | 
			
		||||
            extracted_dir = pipeline_files[0].parent
 | 
			
		||||
            logger.info(f"Extracted model to {extracted_dir}")
 | 
			
		||||
 | 
			
		||||
            # Keep the MPTA file for reference but could delete if space is a concern
 | 
			
		||||
            # mpta_path.unlink()
 | 
			
		||||
            # logger.debug(f"Removed MPTA file after extraction: {mpta_path}")
 | 
			
		||||
 | 
			
		||||
            return extracted_dir
 | 
			
		||||
 | 
			
		||||
        except zipfile.BadZipFile as e:
 | 
			
		||||
            logger.error(f"Invalid ZIP/MPTA file {mpta_path}: {str(e)}", exc_info=True)
 | 
			
		||||
            return None
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Failed to extract MPTA file {mpta_path}: {str(e)}", exc_info=True)
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    def load_pipeline_config(self, model_id: int) -> Optional[Dict[str, Any]]:
 | 
			
		||||
        """
 | 
			
		||||
        Load the pipeline.json configuration for a model
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            model_id: The model ID
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            The pipeline configuration dictionary or None if not found
 | 
			
		||||
        """
 | 
			
		||||
        model_path = self.get_model_path(model_id)
 | 
			
		||||
        if not model_path:
 | 
			
		||||
            logger.error(f"Model {model_id} not found")
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        pipeline_path = model_path / "pipeline.json"
 | 
			
		||||
        if not pipeline_path.exists():
 | 
			
		||||
            logger.error(f"pipeline.json not found for model {model_id}")
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            with open(pipeline_path, 'r') as f:
 | 
			
		||||
                config = json.load(f)
 | 
			
		||||
            logger.debug(f"Loaded pipeline config for model {model_id}")
 | 
			
		||||
            return config
 | 
			
		||||
        except json.JSONDecodeError as e:
 | 
			
		||||
            logger.error(f"Invalid JSON in pipeline.json for model {model_id}: {str(e)}")
 | 
			
		||||
            return None
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Failed to load pipeline.json for model {model_id}: {str(e)}")
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    def get_model_file_path(self, model_id: int, filename: str) -> Optional[Path]:
 | 
			
		||||
        """
 | 
			
		||||
        Get the full path to a model file (e.g., .pt file)
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            model_id: The model ID
 | 
			
		||||
            filename: The filename within the model directory
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Full path to the model file or None if not found
 | 
			
		||||
        """
 | 
			
		||||
        model_path = self.get_model_path(model_id)
 | 
			
		||||
        if not model_path:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        file_path = model_path / filename
 | 
			
		||||
        if not file_path.exists():
 | 
			
		||||
            logger.error(f"Model file {filename} not found in model {model_id}")
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        return file_path
 | 
			
		||||
 | 
			
		||||
    def cleanup_model(self, model_id: int) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Remove a downloaded model to free up space
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            model_id: The model ID to remove
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            True if successful, False otherwise
 | 
			
		||||
        """
 | 
			
		||||
        if model_id not in self._downloaded_models:
 | 
			
		||||
            logger.warning(f"Model {model_id} not in downloaded models")
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            model_dir = self.models_dir / str(model_id)
 | 
			
		||||
            if model_dir.exists():
 | 
			
		||||
                import shutil
 | 
			
		||||
                shutil.rmtree(model_dir)
 | 
			
		||||
                logger.info(f"Removed model directory: {model_dir}")
 | 
			
		||||
 | 
			
		||||
            self._downloaded_models.discard(model_id)
 | 
			
		||||
            self._model_paths.pop(model_id, None)
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Failed to cleanup model {model_id}: {str(e)}")
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
    def get_all_downloaded_models(self) -> Set[int]:
 | 
			
		||||
        """
 | 
			
		||||
        Get a set of all downloaded model IDs
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Set of model IDs that are currently downloaded
 | 
			
		||||
        """
 | 
			
		||||
        return self._downloaded_models.copy()
 | 
			
		||||
							
								
								
									
										357
									
								
								core/models/pipeline.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										357
									
								
								core/models/pipeline.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,357 @@
 | 
			
		|||
"""
 | 
			
		||||
Pipeline Configuration Parser - Handles pipeline.json parsing and validation
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
import logging
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Dict, List, Any, Optional, Set
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
from enum import Enum
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ActionType(Enum):
 | 
			
		||||
    """Supported action types in pipeline"""
 | 
			
		||||
    REDIS_SAVE_IMAGE = "redis_save_image"
 | 
			
		||||
    REDIS_PUBLISH = "redis_publish"
 | 
			
		||||
    POSTGRESQL_UPDATE = "postgresql_update"
 | 
			
		||||
    POSTGRESQL_UPDATE_COMBINED = "postgresql_update_combined"
 | 
			
		||||
    POSTGRESQL_INSERT = "postgresql_insert"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class RedisConfig:
 | 
			
		||||
    """Redis connection configuration"""
 | 
			
		||||
    host: str
 | 
			
		||||
    port: int = 6379
 | 
			
		||||
    password: Optional[str] = None
 | 
			
		||||
    db: int = 0
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_dict(cls, data: Dict[str, Any]) -> 'RedisConfig':
 | 
			
		||||
        return cls(
 | 
			
		||||
            host=data['host'],
 | 
			
		||||
            port=data.get('port', 6379),
 | 
			
		||||
            password=data.get('password'),
 | 
			
		||||
            db=data.get('db', 0)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class PostgreSQLConfig:
 | 
			
		||||
    """PostgreSQL connection configuration"""
 | 
			
		||||
    host: str
 | 
			
		||||
    port: int
 | 
			
		||||
    database: str
 | 
			
		||||
    username: str
 | 
			
		||||
    password: str
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_dict(cls, data: Dict[str, Any]) -> 'PostgreSQLConfig':
 | 
			
		||||
        return cls(
 | 
			
		||||
            host=data['host'],
 | 
			
		||||
            port=data.get('port', 5432),
 | 
			
		||||
            database=data['database'],
 | 
			
		||||
            username=data['username'],
 | 
			
		||||
            password=data['password']
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class Action:
 | 
			
		||||
    """Represents an action in the pipeline"""
 | 
			
		||||
    type: ActionType
 | 
			
		||||
    params: Dict[str, Any] = field(default_factory=dict)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_dict(cls, data: Dict[str, Any]) -> 'Action':
 | 
			
		||||
        action_type = ActionType(data['type'])
 | 
			
		||||
        params = {k: v for k, v in data.items() if k != 'type'}
 | 
			
		||||
        return cls(type=action_type, params=params)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class ModelBranch:
 | 
			
		||||
    """Represents a branch in the pipeline with its own model"""
 | 
			
		||||
    model_id: str
 | 
			
		||||
    model_file: str
 | 
			
		||||
    trigger_classes: List[str]
 | 
			
		||||
    min_confidence: float = 0.5
 | 
			
		||||
    crop: bool = False
 | 
			
		||||
    crop_class: Optional[Any] = None  # Can be string or list
 | 
			
		||||
    parallel: bool = False
 | 
			
		||||
    actions: List[Action] = field(default_factory=list)
 | 
			
		||||
    branches: List['ModelBranch'] = field(default_factory=list)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelBranch':
 | 
			
		||||
        actions = [Action.from_dict(a) for a in data.get('actions', [])]
 | 
			
		||||
        branches = [cls.from_dict(b) for b in data.get('branches', [])]
 | 
			
		||||
 | 
			
		||||
        return cls(
 | 
			
		||||
            model_id=data['modelId'],
 | 
			
		||||
            model_file=data['modelFile'],
 | 
			
		||||
            trigger_classes=data.get('triggerClasses', []),
 | 
			
		||||
            min_confidence=data.get('minConfidence', 0.5),
 | 
			
		||||
            crop=data.get('crop', False),
 | 
			
		||||
            crop_class=data.get('cropClass'),
 | 
			
		||||
            parallel=data.get('parallel', False),
 | 
			
		||||
            actions=actions,
 | 
			
		||||
            branches=branches
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class TrackingConfig:
 | 
			
		||||
    """Configuration for the tracking phase"""
 | 
			
		||||
    model_id: str
 | 
			
		||||
    model_file: str
 | 
			
		||||
    trigger_classes: List[str]
 | 
			
		||||
    min_confidence: float = 0.6
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_dict(cls, data: Dict[str, Any]) -> 'TrackingConfig':
 | 
			
		||||
        return cls(
 | 
			
		||||
            model_id=data['modelId'],
 | 
			
		||||
            model_file=data['modelFile'],
 | 
			
		||||
            trigger_classes=data.get('triggerClasses', []),
 | 
			
		||||
            min_confidence=data.get('minConfidence', 0.6)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class PipelineConfig:
 | 
			
		||||
    """Main pipeline configuration"""
 | 
			
		||||
    model_id: str
 | 
			
		||||
    model_file: str
 | 
			
		||||
    trigger_classes: List[str]
 | 
			
		||||
    min_confidence: float = 0.5
 | 
			
		||||
    crop: bool = False
 | 
			
		||||
    branches: List[ModelBranch] = field(default_factory=list)
 | 
			
		||||
    parallel_actions: List[Action] = field(default_factory=list)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_dict(cls, data: Dict[str, Any]) -> 'PipelineConfig':
 | 
			
		||||
        branches = [ModelBranch.from_dict(b) for b in data.get('branches', [])]
 | 
			
		||||
        parallel_actions = [Action.from_dict(a) for a in data.get('parallelActions', [])]
 | 
			
		||||
 | 
			
		||||
        return cls(
 | 
			
		||||
            model_id=data['modelId'],
 | 
			
		||||
            model_file=data['modelFile'],
 | 
			
		||||
            trigger_classes=data.get('triggerClasses', []),
 | 
			
		||||
            min_confidence=data.get('minConfidence', 0.5),
 | 
			
		||||
            crop=data.get('crop', False),
 | 
			
		||||
            branches=branches,
 | 
			
		||||
            parallel_actions=parallel_actions
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PipelineParser:
 | 
			
		||||
    """Parser for pipeline.json configuration files"""
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.redis_config: Optional[RedisConfig] = None
 | 
			
		||||
        self.postgresql_config: Optional[PostgreSQLConfig] = None
 | 
			
		||||
        self.tracking_config: Optional[TrackingConfig] = None
 | 
			
		||||
        self.pipeline_config: Optional[PipelineConfig] = None
 | 
			
		||||
        self._model_dependencies: Set[str] = set()
 | 
			
		||||
 | 
			
		||||
    def parse(self, config_path: Path) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Parse a pipeline.json configuration file
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            config_path: Path to the pipeline.json file
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            True if parsing was successful, False otherwise
 | 
			
		||||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            if not config_path.exists():
 | 
			
		||||
                logger.error(f"Pipeline config not found: {config_path}")
 | 
			
		||||
                return False
 | 
			
		||||
 | 
			
		||||
            with open(config_path, 'r') as f:
 | 
			
		||||
                data = json.load(f)
 | 
			
		||||
 | 
			
		||||
            return self.parse_dict(data)
 | 
			
		||||
 | 
			
		||||
        except json.JSONDecodeError as e:
 | 
			
		||||
            logger.error(f"Invalid JSON in pipeline config: {str(e)}")
 | 
			
		||||
            return False
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Failed to parse pipeline config: {str(e)}", exc_info=True)
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
    def parse_dict(self, data: Dict[str, Any]) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Parse a pipeline configuration from a dictionary
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            data: The configuration dictionary
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            True if parsing was successful, False otherwise
 | 
			
		||||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            # Parse Redis configuration
 | 
			
		||||
            if 'redis' in data:
 | 
			
		||||
                self.redis_config = RedisConfig.from_dict(data['redis'])
 | 
			
		||||
                logger.debug(f"Parsed Redis config: {self.redis_config.host}:{self.redis_config.port}")
 | 
			
		||||
 | 
			
		||||
            # Parse PostgreSQL configuration
 | 
			
		||||
            if 'postgresql' in data:
 | 
			
		||||
                self.postgresql_config = PostgreSQLConfig.from_dict(data['postgresql'])
 | 
			
		||||
                logger.debug(f"Parsed PostgreSQL config: {self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}")
 | 
			
		||||
 | 
			
		||||
            # Parse tracking configuration
 | 
			
		||||
            if 'tracking' in data:
 | 
			
		||||
                self.tracking_config = TrackingConfig.from_dict(data['tracking'])
 | 
			
		||||
                self._model_dependencies.add(self.tracking_config.model_file)
 | 
			
		||||
                logger.debug(f"Parsed tracking config: {self.tracking_config.model_id}")
 | 
			
		||||
 | 
			
		||||
            # Parse main pipeline configuration
 | 
			
		||||
            if 'pipeline' in data:
 | 
			
		||||
                self.pipeline_config = PipelineConfig.from_dict(data['pipeline'])
 | 
			
		||||
                self._collect_model_dependencies(self.pipeline_config)
 | 
			
		||||
                logger.debug(f"Parsed pipeline config: {self.pipeline_config.model_id}")
 | 
			
		||||
 | 
			
		||||
            logger.info(f"Successfully parsed pipeline configuration")
 | 
			
		||||
            logger.debug(f"Model dependencies: {self._model_dependencies}")
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        except KeyError as e:
 | 
			
		||||
            logger.error(f"Missing required field in pipeline config: {str(e)}")
 | 
			
		||||
            return False
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Failed to parse pipeline config: {str(e)}", exc_info=True)
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
    def _collect_model_dependencies(self, config: Any) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Recursively collect all model file dependencies
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            config: Pipeline or branch configuration
 | 
			
		||||
        """
 | 
			
		||||
        if hasattr(config, 'model_file'):
 | 
			
		||||
            self._model_dependencies.add(config.model_file)
 | 
			
		||||
 | 
			
		||||
        if hasattr(config, 'branches'):
 | 
			
		||||
            for branch in config.branches:
 | 
			
		||||
                self._collect_model_dependencies(branch)
 | 
			
		||||
 | 
			
		||||
    def get_model_dependencies(self) -> Set[str]:
 | 
			
		||||
        """
 | 
			
		||||
        Get all model file dependencies from the pipeline
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Set of model filenames required by the pipeline
 | 
			
		||||
        """
 | 
			
		||||
        return self._model_dependencies.copy()
 | 
			
		||||
 | 
			
		||||
    def validate(self) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Validate the parsed configuration
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            True if configuration is valid, False otherwise
 | 
			
		||||
        """
 | 
			
		||||
        if not self.pipeline_config:
 | 
			
		||||
            logger.error("No pipeline configuration found")
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        # Check that all required model files are specified
 | 
			
		||||
        if not self.pipeline_config.model_file:
 | 
			
		||||
            logger.error("Main pipeline model file not specified")
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        # Validate action configurations
 | 
			
		||||
        if not self._validate_actions(self.pipeline_config):
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        # Validate parallel actions
 | 
			
		||||
        for action in self.pipeline_config.parallel_actions:
 | 
			
		||||
            if action.type == ActionType.POSTGRESQL_UPDATE_COMBINED:
 | 
			
		||||
                wait_for = action.params.get('waitForBranches', [])
 | 
			
		||||
                if wait_for:
 | 
			
		||||
                    # Check that referenced branches exist
 | 
			
		||||
                    branch_ids = self._get_all_branch_ids(self.pipeline_config)
 | 
			
		||||
                    for branch_id in wait_for:
 | 
			
		||||
                        if branch_id not in branch_ids:
 | 
			
		||||
                            logger.error(f"Referenced branch '{branch_id}' in waitForBranches not found")
 | 
			
		||||
                            return False
 | 
			
		||||
 | 
			
		||||
        logger.info("Pipeline configuration validated successfully")
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def _validate_actions(self, config: Any) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Validate actions in a pipeline or branch configuration
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            config: Pipeline or branch configuration
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            True if valid, False otherwise
 | 
			
		||||
        """
 | 
			
		||||
        if hasattr(config, 'actions'):
 | 
			
		||||
            for action in config.actions:
 | 
			
		||||
                # Validate Redis actions need Redis config
 | 
			
		||||
                if action.type in [ActionType.REDIS_SAVE_IMAGE, ActionType.REDIS_PUBLISH]:
 | 
			
		||||
                    if not self.redis_config:
 | 
			
		||||
                        logger.error(f"Action {action.type} requires Redis configuration")
 | 
			
		||||
                        return False
 | 
			
		||||
 | 
			
		||||
                # Validate PostgreSQL actions need PostgreSQL config
 | 
			
		||||
                if action.type in [ActionType.POSTGRESQL_UPDATE, ActionType.POSTGRESQL_UPDATE_COMBINED, ActionType.POSTGRESQL_INSERT]:
 | 
			
		||||
                    if not self.postgresql_config:
 | 
			
		||||
                        logger.error(f"Action {action.type} requires PostgreSQL configuration")
 | 
			
		||||
                        return False
 | 
			
		||||
 | 
			
		||||
        # Recursively validate branches
 | 
			
		||||
        if hasattr(config, 'branches'):
 | 
			
		||||
            for branch in config.branches:
 | 
			
		||||
                if not self._validate_actions(branch):
 | 
			
		||||
                    return False
 | 
			
		||||
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def _get_all_branch_ids(self, config: Any, branch_ids: Set[str] = None) -> Set[str]:
 | 
			
		||||
        """
 | 
			
		||||
        Recursively collect all branch model IDs
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            config: Pipeline or branch configuration
 | 
			
		||||
            branch_ids: Set to collect IDs into
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Set of all branch model IDs
 | 
			
		||||
        """
 | 
			
		||||
        if branch_ids is None:
 | 
			
		||||
            branch_ids = set()
 | 
			
		||||
 | 
			
		||||
        if hasattr(config, 'branches'):
 | 
			
		||||
            for branch in config.branches:
 | 
			
		||||
                branch_ids.add(branch.model_id)
 | 
			
		||||
                self._get_all_branch_ids(branch, branch_ids)
 | 
			
		||||
 | 
			
		||||
        return branch_ids
 | 
			
		||||
 | 
			
		||||
    def get_redis_config(self) -> Optional[RedisConfig]:
 | 
			
		||||
        """Get the Redis configuration"""
 | 
			
		||||
        return self.redis_config
 | 
			
		||||
 | 
			
		||||
    def get_postgresql_config(self) -> Optional[PostgreSQLConfig]:
 | 
			
		||||
        """Get the PostgreSQL configuration"""
 | 
			
		||||
        return self.postgresql_config
 | 
			
		||||
 | 
			
		||||
    def get_tracking_config(self) -> Optional[TrackingConfig]:
 | 
			
		||||
        """Get the tracking configuration"""
 | 
			
		||||
        return self.tracking_config
 | 
			
		||||
 | 
			
		||||
    def get_pipeline_config(self) -> Optional[PipelineConfig]:
 | 
			
		||||
        """Get the main pipeline configuration"""
 | 
			
		||||
        return self.pipeline_config
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue