@dataclass
class PipelineValidationResult:
    """Outcome of a pipeline pre-validation run.

    Carries the overall verdict plus the branch and class lists that the
    validator collected while reaching it.
    """
    is_valid: bool
    missing_branches: List[str] = field(default_factory=list)
    required_branches: List[str] = field(default_factory=list)
    available_classes: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of this result.

        Each list is copied, so mutating the returned dictionary does not
        affect the lists stored on this instance.
        """
        snapshot: Dict[str, Any] = {"is_valid": self.is_valid}
        for list_attr in ("missing_branches", "required_branches", "available_classes"):
            snapshot[list_attr] = getattr(self, list_attr).copy()
        return snapshot
class StabilityValidator:
    """
    Validates detection stability and pipeline execution requirements.

    This class provides functionality for:
    - Pipeline execution pre-validation
    - Lightweight detection for validation phases
    - Detection threshold validation

    The validator holds no state of its own; every method works purely on
    the arguments it receives.
    """

    def __init__(self):
        """Initialize stability validator (stateless, nothing to set up)."""

    @staticmethod
    def _find_trigger(branch: Dict[str, Any],
                      regions_dict: Dict[str, Any]) -> Optional[Tuple[str, Any]]:
        """Return the first (class, confidence) in regions_dict that triggers branch, else None."""
        trigger_classes = branch.get("triggerClasses", [])
        min_conf = branch.get("minConfidence", 0)

        for det_class in regions_dict:
            det_confidence = regions_dict[det_class]["confidence"]
            if det_class in trigger_classes and det_confidence >= min_conf:
                return det_class, det_confidence

        return None

    def _iter_filtered_detections(self,
                                  results: Any,
                                  model: Any,
                                  min_confidence: float,
                                  trigger_classes: Optional[List[str]],
                                  trigger_class_indices: Optional[List[int]]):
        """Yield every box in results that survives _apply_detection_filters."""
        for result in results:
            boxes = result.boxes
            if boxes is None or len(boxes) == 0:
                continue

            for box in boxes:
                detection = self._apply_detection_filters(
                    box, model, min_confidence, trigger_classes, trigger_class_indices
                )
                if detection:
                    yield detection

    def validate_pipeline_execution(self,
                                    node: Dict[str, Any],
                                    regions_dict: Dict[str, Any]) -> PipelineValidationResult:
        """
        Pre-validate that all required branches will execute successfully before
        committing to Redis actions and database records.

        A branch is "required" when a ``postgresql_update_combined`` parallel
        action lists its model id under ``waitForBranches``. Only branches
        present in ``node["branches"]`` are checked.

        Args:
            node: Pipeline node configuration
            regions_dict: Dictionary mapping class names to detection regions;
                each region must carry a "confidence" entry

        Returns:
            PipelineValidationResult with validation status and details
        """
        # Collect every branch that a combined-update action is waiting for
        required_branches = set()
        for action in node.get("parallelActions", []):
            if action.get("type") == "postgresql_update_combined":
                required_branches.update(action.get("waitForBranches", []))

        if not required_branches:
            # No parallel actions requiring specific branches
            logger.debug("No parallel actions with waitForBranches - validation passes")
            return PipelineValidationResult(
                is_valid=True,
                required_branches=[],
                available_classes=list(regions_dict.keys())
            )

        logger.debug(f"Pre-validation: checking if required branches {list(required_branches)} will execute")

        missing_branches = []

        for branch in node.get("branches", []):
            branch_id = branch["modelId"]

            if branch_id not in required_branches:
                continue  # This branch is not required by parallel actions

            trigger_classes = branch.get("triggerClasses", [])
            min_conf = branch.get("minConfidence", 0)

            trigger = self._find_trigger(branch, regions_dict)
            if trigger is not None:
                det_class, det_confidence = trigger
                logger.debug(f"Pre-validation: branch {branch_id} WILL be triggered by {det_class} (conf={det_confidence:.3f} >= {min_conf})")
                continue

            missing_branches.append(branch_id)
            logger.warning(f"Pre-validation: branch {branch_id} will NOT be triggered - no matching classes or insufficient confidence")
            logger.debug(f" Required: {trigger_classes} with min_conf={min_conf}")
            logger.debug(f" Available: {[(cls, regions_dict[cls]['confidence']) for cls in regions_dict]}")

        is_valid = len(missing_branches) == 0

        if missing_branches:
            logger.error(f"Pipeline pre-validation FAILED: required branches {missing_branches} will not execute")
        else:
            logger.info(f"Pipeline pre-validation PASSED: all required branches {list(required_branches)} will execute")

        return PipelineValidationResult(
            is_valid=is_valid,
            missing_branches=missing_branches,
            required_branches=list(required_branches),
            available_classes=list(regions_dict.keys())
        )

    def _apply_detection_filters(self,
                                 box: Any,
                                 model: Any,
                                 min_confidence: float,
                                 trigger_classes: Optional[List[str]] = None,
                                 trigger_class_indices: Optional[List[int]] = None) -> Optional[Dict[str, Any]]:
        """
        Apply confidence and class filtering to a detection box.

        Args:
            box: Detection box exposing ``xyxy``, ``conf`` and ``cls`` whose
                first element supports ``.cpu().numpy()``
            model: Model object exposing a ``names`` lookup by class id
            min_confidence: Boxes below this confidence are rejected
            trigger_classes: If non-empty, only these class names are accepted
            trigger_class_indices: If non-empty, only these class ids are accepted

        Returns:
            Detection dictionary, or None when the box is filtered out or
            cannot be processed
        """
        try:
            # Extract detection info
            xyxy = box.xyxy[0].cpu().numpy()
            conf = box.conf[0].cpu().numpy()
            cls_id = int(box.cls[0].cpu().numpy())
            class_name = model.names[cls_id]

            # Apply confidence threshold
            if conf < min_confidence:
                return None

            # Apply trigger class filtering if specified
            if trigger_class_indices and cls_id not in trigger_class_indices:
                return None
            if trigger_classes and class_name not in trigger_classes:
                return None

            return {
                "class": class_name,
                "confidence": float(conf),
                "bbox": [int(x) for x in xyxy],
                "class_id": cls_id,
                "xyxy": xyxy
            }

        except Exception as e:
            # Best effort: one malformed box must not abort the whole frame
            logger.error(f"Error processing detection box: {e}")
            return None

    def run_lightweight_detection_with_validation(self,
                                                  frame: np.ndarray,
                                                  node: Dict[str, Any],
                                                  min_confidence: float = LIGHTWEIGHT_DETECTION_MIN_CONFIDENCE,
                                                  min_bbox_area_ratio: float = LIGHTWEIGHT_DETECTION_MIN_BBOX_AREA_RATIO) -> Dict[str, Any]:
        """
        Run lightweight detection with validation rules for session ID triggering.
        Returns detection info only if it passes validation thresholds.

        Args:
            frame: Input frame for detection
            node: Pipeline node configuration
            min_confidence: Minimum confidence threshold for validation
            min_bbox_area_ratio: Minimum bounding box area ratio for validation

        Returns:
            The highest-confidence detection that passed (with
            "bbox_area_ratio" and "validation_passed" added), or
            ``{"validation_passed": False}`` when nothing passed or an
            error occurred
        """
        # .get() so that a node without modelId cannot crash in a log line
        logger.debug(f"Running lightweight detection with validation: {node.get('modelId', 'unknown')} (conf>={min_confidence}, bbox_area>={min_bbox_area_ratio})")

        try:
            # Run basic detection only - no branches, no actions
            model = node["model"]
            trigger_classes = node.get("triggerClasses", [])
            trigger_class_indices = node.get("triggerClassIndices")

            # Run YOLO inference
            res = model(frame, verbose=False)

            best_detection = None
            frame_height, frame_width = frame.shape[:2]
            frame_area = frame_height * frame_width

            for detection in self._iter_filtered_detections(
                res, model, min_confidence, trigger_classes, trigger_class_indices
            ):
                # Calculate bbox area ratio
                x1, y1, x2, y2 = detection["xyxy"]
                bbox_area = (x2 - x1) * (y2 - y1)
                bbox_area_ratio = bbox_area / frame_area if frame_area > 0 else 0

                # Apply bbox area threshold
                if bbox_area_ratio < min_bbox_area_ratio:
                    logger.debug(f"Detection filtered out: bbox_area_ratio={bbox_area_ratio:.3f} < {min_bbox_area_ratio}")
                    continue

                # Validation passed
                detection["bbox_area_ratio"] = float(bbox_area_ratio)
                detection["validation_passed"] = True

                if not best_detection or detection["confidence"] > best_detection["confidence"]:
                    best_detection = detection

            if best_detection:
                logger.debug(f"Validation PASSED: {best_detection['class']} (conf: {best_detection['confidence']:.3f}, area: {best_detection['bbox_area_ratio']:.3f})")
                return best_detection
            else:
                logger.debug(f"Validation FAILED: No detection meets criteria (conf>={min_confidence}, area>={min_bbox_area_ratio})")
                return {"validation_passed": False}

        except Exception as e:
            logger.error(f"Error in lightweight detection with validation: {str(e)}", exc_info=True)
            return {"validation_passed": False}

    def run_lightweight_detection(self,
                                  frame: np.ndarray,
                                  node: Dict[str, Any]) -> LightweightDetectionResult:
        """
        Run lightweight detection for car presence validation only.
        Returns basic detection info without running branches or external actions.

        Args:
            frame: Input frame for detection
            node: Pipeline node configuration

        Returns:
            LightweightDetectionResult with car detection status; on error
            a result with ``car_detected=False`` is returned
        """
        # .get() so that a node without modelId cannot crash in a log line
        logger.debug(f"Running lightweight detection: {node.get('modelId', 'unknown')}")

        try:
            # Run basic detection only - no branches, no actions
            model = node["model"]
            min_confidence = node.get("minConfidence", DEFAULT_MIN_CONFIDENCE)
            trigger_classes = node.get("triggerClasses", [])
            trigger_class_indices = node.get("triggerClassIndices")

            # Run YOLO inference
            res = model(frame, verbose=False)

            car_detected = False
            best_detection = None

            for detection in self._iter_filtered_detections(
                res, model, min_confidence, trigger_classes, trigger_class_indices
            ):
                # Car detected
                car_detected = True
                if not best_detection or detection["confidence"] > best_detection["confidence"]:
                    best_detection = detection

            logger.debug(f"Lightweight detection result: car_detected={car_detected}")
            if best_detection:
                logger.debug(f"Best detection: {best_detection['class']} (conf: {best_detection['confidence']:.3f})")

            return LightweightDetectionResult(
                car_detected=car_detected,
                best_detection=best_detection,
                confidence=best_detection["confidence"] if best_detection else None
            )

        except Exception as e:
            logger.error(f"Error in lightweight detection: {str(e)}", exc_info=True)
            return LightweightDetectionResult(car_detected=False, best_detection=None)

    def validate_detection_thresholds(self,
                                      detection: Dict[str, Any],
                                      frame_shape: Tuple[int, int, int],
                                      min_confidence: float,
                                      min_bbox_area_ratio: float) -> bool:
        """
        Validate detection against confidence and bbox area thresholds.

        A detection without at least four bbox coordinates fails validation,
        because its area ratio cannot be checked.

        Args:
            detection: Detection dictionary with confidence and bbox
            frame_shape: Frame dimensions (height, width, channels)
            min_confidence: Minimum confidence threshold
            min_bbox_area_ratio: Minimum bbox area ratio threshold

        Returns:
            True if detection passes all validation thresholds
        """
        try:
            # Check confidence
            confidence = detection.get("confidence", 0.0)
            if confidence < min_confidence:
                logger.debug(f"Detection failed confidence threshold: {confidence:.3f} < {min_confidence}")
                return False

            # Check bbox area ratio. Previously a short/missing bbox fell through
            # to a log line that read an unassigned variable; the resulting
            # exception was swallowed and reported as an error. Fail explicitly
            # instead (same return value, no spurious error log).
            bbox = detection.get("bbox") or []
            if len(bbox) < 4:
                logger.debug("Detection failed bbox area threshold: bbox missing or incomplete")
                return False

            x1, y1, x2, y2 = bbox[:4]
            bbox_area = (x2 - x1) * (y2 - y1)

            frame_height, frame_width = frame_shape[:2]
            frame_area = frame_height * frame_width
            bbox_area_ratio = bbox_area / frame_area if frame_area > 0 else 0

            if bbox_area_ratio < min_bbox_area_ratio:
                logger.debug(f"Detection failed bbox area threshold: {bbox_area_ratio:.3f} < {min_bbox_area_ratio}")
                return False

            logger.debug(f"Detection passed validation thresholds: conf={confidence:.3f}, area_ratio={bbox_area_ratio:.3f}")
            return True

        except Exception as e:
            logger.error(f"Error validating detection thresholds: {e}")
            return False

    def check_branch_execution_requirements(self,
                                            branch: Dict[str, Any],
                                            regions_dict: Dict[str, Any]) -> bool:
        """
        Check if a branch will be executed based on trigger classes and confidence.

        Args:
            branch: Branch configuration
            regions_dict: Available detection regions

        Returns:
            True if branch will be executed
        """
        return self._find_trigger(branch, regions_dict) is not None

    def get_validation_summary(self,
                               node: Dict[str, Any],
                               regions_dict: Dict[str, Any]) -> Dict[str, Any]:
        """
        Get comprehensive validation summary for a pipeline node.

        Args:
            node: Pipeline node configuration
            regions_dict: Available detection regions

        Returns:
            Dictionary with validation summary information
        """
        summary = {
            "model_id": node.get("modelId", "unknown"),
            "available_classes": list(regions_dict.keys()),
            "branch_count": len(node.get("branches", [])),
            "parallel_action_count": len(node.get("parallelActions", [])),
            "branches_that_will_execute": [],
            "branches_that_will_not_execute": []
        }

        # Sort each branch into the "will" / "will not" execute bucket
        for branch in node.get("branches", []):
            branch_id = branch["modelId"]
            if self.check_branch_execution_requirements(branch, regions_dict):
                summary["branches_that_will_execute"].append(branch_id)
            else:
                summary["branches_that_will_not_execute"].append(branch_id)

        # Pipeline validation
        pipeline_validation = self.validate_pipeline_execution(node, regions_dict)
        summary["pipeline_validation"] = pipeline_validation.to_dict()

        return summary
def run_lightweight_detection_with_validation(frame: np.ndarray,
                                              node: Dict[str, Any],
                                              min_confidence: float = LIGHTWEIGHT_DETECTION_MIN_CONFIDENCE,
                                              min_bbox_area_ratio: float = LIGHTWEIGHT_DETECTION_MIN_BBOX_AREA_RATIO) -> Dict[str, Any]:
    """Delegate threshold-validated lightweight detection to the shared validator."""
    detection_info = stability_validator.run_lightweight_detection_with_validation(
        frame, node, min_confidence, min_bbox_area_ratio
    )
    return detection_info


def run_lightweight_detection(frame: np.ndarray, node: Dict[str, Any]) -> Dict[str, Any]:
    """Delegate presence-only detection to the shared validator and return it as a dict."""
    return stability_validator.run_lightweight_detection(frame, node).to_dict()


def validate_detection_thresholds(detection: Dict[str, Any],
                                  frame_shape: Tuple[int, int, int],
                                  min_confidence: float,
                                  min_bbox_area_ratio: float) -> bool:
    """Delegate the confidence / bbox-area threshold check to the shared validator."""
    passed = stability_validator.validate_detection_thresholds(
        detection, frame_shape, min_confidence, min_bbox_area_ratio
    )
    return passed


def get_validation_summary(node: Dict[str, Any], regions_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Delegate building the per-node validation summary to the shared validator."""
    summary = stability_validator.get_validation_summary(node, regions_dict)
    return summary
@dataclass
class SessionState:
    """Per-camera/model session flags used by the tracking manager."""
    active: bool = True
    waiting_for_backend_session: bool = False
    wait_start_time: float = 0.0
    reset_tracker_on_resume: bool = False
    occupancy_mode: bool = False
    occupancy_enabled_at: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the current field values as a plain dictionary."""
        exported = (
            "active",
            "waiting_for_backend_session",
            "wait_start_time",
            "reset_tracker_on_resume",
            "occupancy_mode",
            "occupancy_enabled_at",
        )
        return {name: getattr(self, name) for name in exported}

    def from_dict(self, data: Dict[str, Any]) -> None:
        """Overwrite every field from *data*; absent keys fall back to the field default."""
        fallbacks = (
            ("active", True),
            ("waiting_for_backend_session", False),
            ("wait_start_time", 0.0),
            ("reset_tracker_on_resume", False),
            ("occupancy_mode", False),
            ("occupancy_enabled_at", None),
        )
        for name, fallback in fallbacks:
            setattr(self, name, data.get(name, fallback))
class TrackingManager:
    """
    Manages object tracking state and session handling across cameras and models.

    This class provides centralized tracking state management including:
    - Track stability counters
    - Session state management
    - Backend session timeout handling
    - Occupancy detection mode

    All public methods take ``self._lock`` (a re-entrant lock) while they
    read or modify the per-camera tracking data.
    """

    def __init__(self, session_timeout_seconds: int = SESSION_TIMEOUT_SECONDS):
        """
        Initialize tracking manager.

        Args:
            session_timeout_seconds: Timeout for backend session waiting
        """
        import threading  # function-scope import, as the original lazy initialiser did

        self.session_timeout_seconds = session_timeout_seconds
        self._camera_tracking_data: Dict[str, Dict[str, TrackingData]] = {}
        # Create the lock eagerly. Creating it lazily on first use is itself a
        # race: two threads could each install a different lock and then both
        # enter the "protected" section at once.
        self._lock = threading.RLock()

    def _ensure_thread_safety(self):
        """Make sure a lock exists (kept for backward compatibility).

        The lock is normally created in ``__init__``; this only matters for
        instances whose ``_lock`` was reset to None or never set.
        """
        if getattr(self, "_lock", None) is None:
            import threading
            self._lock = threading.RLock()

    def get_camera_tracking_data(self, camera_id: str, model_id: str) -> TrackingData:
        """
        Get or create tracking data for a specific camera and model.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier

        Returns:
            Tracking data for the camera and model
        """
        self._ensure_thread_safety()

        with self._lock:
            per_model = self._camera_tracking_data.setdefault(camera_id, {})

            if model_id not in per_model:
                logger.warning(f"🔄 Camera {camera_id}: Creating NEW tracking data for {model_id} - this will reset any cooldown!")
                per_model[model_id] = TrackingData()

            return per_model[model_id]

    def check_stable_tracks(self,
                            camera_id: str,
                            model_id: str,
                            regions_dict: Dict[str, Any]) -> Tuple[bool, List[Dict[str, Any]]]:
        """
        Check if any stable tracks match the detected classes for a specific camera.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier
            regions_dict: Dictionary mapping class names to detection regions

        Returns:
            Tuple of (has_stable_tracks, stable_detections)
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            stable_tracks = tracking_data.stable_tracks

            if not stable_tracks:
                return False, []

            # Keep the detections whose track id is already marked stable
            stable_detections = []
            for region_data in regions_dict.values():
                detection = region_data.get("detection", {})
                if detection.get("id") in stable_tracks:
                    stable_detections.append(detection)

            has_stable = len(stable_detections) > 0
            logger.debug(f"Camera {camera_id}: Stable track check - stable_tracks: {list(stable_tracks)}, found: {has_stable}")

            return has_stable, stable_detections

    def reset_tracking_state(self,
                             camera_id: str,
                             model_id: str,
                             reason: str = "session ended") -> None:
        """
        Reset tracking state after session completion or timeout.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier
            reason: Reason for reset (for logging)
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            session_state = tracking_data.session_state

            # Clear all tracking data for fresh start
            tracking_data.track_stability_counters.clear()
            tracking_data.stable_tracks.clear()
            session_state.active = True
            session_state.waiting_for_backend_session = False
            session_state.wait_start_time = 0.0
            session_state.reset_tracker_on_resume = True

            logger.info(f"Camera {camera_id}: 🔄 Reset tracking state - {reason}")
            logger.info(f"Camera {camera_id}: 🧹 Cleared stability counters and stable tracks for fresh session")

    def is_camera_active(self, camera_id: str, model_id: str) -> bool:
        """
        Check if camera should be processing detections.

        While waiting for a backend session id the camera is inactive until
        ``session_timeout_seconds`` have elapsed; on timeout the tracking
        state is reset and the camera becomes active again.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier

        Returns:
            True if camera should process detections, False otherwise
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            session_state = tracking_data.session_state

            # Check if waiting for backend sessionId has timed out
            if session_state.waiting_for_backend_session:
                elapsed_time = time.time() - session_state.wait_start_time

                if elapsed_time >= self.session_timeout_seconds:
                    logger.warning(f"Camera {camera_id}: Backend sessionId timeout ({self.session_timeout_seconds}s) - resetting tracking")
                    self.reset_tracking_state(camera_id, model_id, "backend sessionId timeout")
                    return True

                remaining_time = self.session_timeout_seconds - elapsed_time
                logger.debug(f"Camera {camera_id}: Still waiting for backend sessionId ({remaining_time:.1f}s remaining)")
                return False

            return session_state.active

    def set_waiting_for_backend_session(self,
                                        camera_id: str,
                                        model_id: str,
                                        waiting: bool = True) -> None:
        """
        Set whether camera is waiting for backend session ID.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier
            waiting: Whether to wait for backend session
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            session_state = tracking_data.session_state

            session_state.waiting_for_backend_session = waiting
            if waiting:
                session_state.wait_start_time = time.time()
                logger.debug(f"Camera {camera_id}: Started waiting for backend sessionId")
            else:
                session_state.wait_start_time = 0.0
                logger.debug(f"Camera {camera_id}: Stopped waiting for backend sessionId")

    def cleanup_camera_tracking(self, camera_id: str) -> None:
        """
        Clean up tracking data when a camera is disconnected.

        Args:
            camera_id: Unique camera identifier
        """
        self._ensure_thread_safety()

        with self._lock:
            if camera_id in self._camera_tracking_data:
                del self._camera_tracking_data[camera_id]
                logger.info(f"Cleaned up tracking data for camera {camera_id}")

    def set_occupancy_mode(self,
                           camera_id: str,
                           model_id: str,
                           enable: bool = True) -> bool:
        """
        Enable or disable occupancy detection mode.

        Occupancy mode stops model inference after pipeline completion
        while backend session handling continues in background.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier
            enable: True to enable occupancy mode, False to disable

        Returns:
            Current occupancy mode state
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            session_state = tracking_data.session_state

            if enable:
                session_state.occupancy_mode = True
                session_state.occupancy_enabled_at = time.time()
                logger.debug(f"Camera {camera_id}: Occupancy mode ENABLED - model will stop after pipeline completion")
            else:
                session_state.occupancy_mode = False
                session_state.occupancy_enabled_at = None
                logger.debug(f"Camera {camera_id}: Occupancy mode DISABLED - model will continue running")

            return session_state.occupancy_mode

    def is_occupancy_mode_enabled(self, camera_id: str, model_id: str) -> bool:
        """
        Check if occupancy mode is enabled for a camera.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier

        Returns:
            True if occupancy mode is enabled
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            return tracking_data.session_state.occupancy_mode

    def get_occupancy_duration(self, camera_id: str, model_id: str) -> Optional[float]:
        """
        Get duration since occupancy mode was enabled.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier

        Returns:
            Duration in seconds since occupancy mode was enabled, or None if not enabled
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            session_state = tracking_data.session_state

            # Compare against None explicitly: a timestamp of 0.0 is falsy but
            # still a set value.
            if not session_state.occupancy_mode or session_state.occupancy_enabled_at is None:
                return None

            return time.time() - session_state.occupancy_enabled_at

    def get_session_state(self, camera_id: str, model_id: str) -> SessionState:
        """
        Get session state for a camera and model.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier

        Returns:
            Session state object (the live instance, not a copy)
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            return tracking_data.session_state

    def update_session_state(self,
                             camera_id: str,
                             model_id: str,
                             **kwargs) -> None:
        """
        Update session state properties.

        Only declared ``SessionState`` fields can be updated; any other key
        is logged as unknown and ignored.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier
            **kwargs: Session state properties to update
        """
        from dataclasses import fields as dataclass_fields

        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            session_state = tracking_data.session_state

            # hasattr() would also accept method names such as "to_dict" and
            # let a caller overwrite them; restrict updates to real fields.
            known_fields = {f.name for f in dataclass_fields(session_state)}

            for key, value in kwargs.items():
                if key in known_fields:
                    setattr(session_state, key, value)
                    logger.debug(f"Camera {camera_id}: Updated session state {key} = {value}")
                else:
                    logger.warning(f"Camera {camera_id}: Unknown session state property: {key}")

    def get_tracking_statistics(self, camera_id: str, model_id: str) -> Dict[str, Any]:
        """
        Get comprehensive tracking statistics for a camera and model.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier

        Returns:
            Dictionary with tracking statistics: the tracking data as a dict
            plus counts and, when applicable, backend-wait and occupancy
            durations
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            current_time = time.time()

            stats = tracking_data.to_dict()

            # Add computed statistics
            stats["total_tracked_objects"] = len(tracking_data.track_stability_counters)
            stats["stable_track_count"] = len(tracking_data.stable_tracks)

            session_state = tracking_data.session_state
            if session_state.waiting_for_backend_session and session_state.wait_start_time > 0:
                stats["backend_wait_duration"] = current_time - session_state.wait_start_time
                stats["backend_wait_remaining"] = max(0, self.session_timeout_seconds - stats["backend_wait_duration"])

            if session_state.occupancy_mode and session_state.occupancy_enabled_at is not None:
                stats["occupancy_duration"] = current_time - session_state.occupancy_enabled_at

            return stats

    def get_all_camera_stats(self) -> Dict[str, Dict[str, Dict[str, Any]]]:
        """
        Get tracking statistics for all monitored cameras.

        Returns:
            Nested dictionary: {camera_id: {model_id: stats}}
        """
        self._ensure_thread_safety()

        with self._lock:
            return {
                camera_id: {
                    model_id: self.get_tracking_statistics(camera_id, model_id)
                    for model_id in models
                }
                for camera_id, models in self._camera_tracking_data.items()
            }
def get_camera_tracking_data(camera_id: str, model_id: str) -> Dict[str, Any]:
    """Return the shared manager's tracking data for a camera/model pair as a dict."""
    return tracking_manager.get_camera_tracking_data(camera_id, model_id).to_dict()


def set_waiting_for_backend_session(camera_id: str, model_id: str, waiting: bool = True) -> None:
    """Forward the backend-session waiting flag to the shared tracking manager."""
    manager = tracking_manager
    manager.set_waiting_for_backend_session(camera_id, model_id, waiting)


def get_tracking_statistics(camera_id: str, model_id: str) -> Dict[str, Any]:
    """Return the shared manager's tracking statistics for a camera/model pair."""
    statistics = tracking_manager.get_tracking_statistics(camera_id, model_id)
    return statistics
@dataclass
class ValidationResult:
    """Outcome of a single track-stability validation step."""
    # Whether this step finished validation
    validation_complete: bool
    # Track ids reported as stable by this step
    stable_tracks: List[int] = field(default_factory=list)
    # Track ids seen in the frame that produced this result
    current_tracks: List[int] = field(default_factory=list)
    # Track ids that became stable during this step
    newly_stable_tracks: List[int] = field(default_factory=list)
    # Status flags; all default to False
    send_none_detection: bool = False
    branch_node: bool = False
    bypass_validation: bool = False
    awaiting_stable_tracks: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """
        Export the result as a plain dictionary.

        The three track lists are shallow-copied so callers can mutate the
        returned lists without touching this instance; flags are copied by value.

        Returns:
            Dictionary with one entry per field, in declaration order
        """
        exported: Dict[str, Any] = {"validation_complete": self.validation_complete}

        # List-valued fields: hand out copies, never the live lists.
        for list_field in ("stable_tracks", "current_tracks", "newly_stable_tracks"):
            exported[list_field] = getattr(self, list_field).copy()

        # Boolean flags are immutable, so they can be passed through directly.
        for flag_field in ("send_none_detection", "branch_node",
                           "bypass_validation", "awaiting_stable_tracks"):
            exported[flag_field] = getattr(self, flag_field)

        return exported
def update_single_track_stability(self,
                                detection: Optional[Dict[str, Any]],
                                camera_id: str,
                                model_id: str,
                                stability_threshold: int,
                                current_mode: str,
                                is_branch_node: bool = False) -> ValidationResult:
    """
    Update track stability validation for a single highest confidence detection.

    Behaviour, in the order the checks are made:

    1. Branch nodes return immediately with ``branch_node=True``; no state is
       read or modified.
    2. If ``current_mode`` is neither ``"validation_detecting"`` nor
       ``"full_pipeline"``, the stored counters are left untouched and a
       snapshot of the current stable tracks is returned.
    3. If the detection carries a track id (``detection["id"]``), the counter
       for that id is either restarted at 1 (no counters yet, a different id
       than the one being counted, or the id is already marked stable) or
       incremented. Reaching ``stability_threshold`` marks the track stable and
       returns ``validation_complete=True`` with ``send_none_detection=True``.
    4. If there is no detection, or the detection has no track id
       (``"id"`` missing or None), all counters and stable tracks for this
       camera/model are cleared.

    Because an id that is already stable triggers a restart, a track that just
    became stable starts counting from 1 again on the next call.

    Args:
        detection: Detection data with the track id under key "id", or None if no detection
        camera_id: Unique camera identifier
        model_id: Model identifier
        stability_threshold: Number of consecutive frames needed for stability
        current_mode: Current pipeline mode
        is_branch_node: Whether this is a branch node (skip validation)

    Returns:
        Validation result with track status
    """
    self._ensure_thread_safety()

    # The whole update runs under the tracker lock so counters and the stable
    # set are modified together.
    with self._lock:
        # Branch nodes should not do validation - only main pipeline should
        if is_branch_node:
            logger.debug(f"โญ๏ธ Camera {camera_id}: Skipping validation for branch node {model_id} - validation only done at main pipeline level")
            return ValidationResult(
                validation_complete=False,
                branch_node=True,
                stable_tracks=[],
                current_tracks=[]
            )

        # Check current mode - validation counters should increment in both validation_detecting and full_pipeline modes
        is_validation_mode = current_mode in ["validation_detecting", "full_pipeline"]

        # Get camera-specific stability data (created on first use)
        stability_data = self.get_camera_stability_data(camera_id, model_id)
        # These are the live containers, not copies: mutations below persist.
        track_counters = stability_data.track_stability_counters
        stable_tracks = stability_data.stable_tracks

        # A detection without an "id" entry is treated like no detection at all.
        current_track_id = detection.get("id") if detection else None

        # --- Mode-aware track validation ---
        logger.debug(f"๐Ÿ“‹ Camera {camera_id}: === TRACK VALIDATION ANALYSIS ===")
        logger.debug(f"๐Ÿ“‹ Camera {camera_id}: Current mode: {current_mode} (validation_mode={is_validation_mode})")
        logger.debug(f"๐Ÿ“‹ Camera {camera_id}: Current track_id: {current_track_id} (assigned by YOLO tracking - not sequential)")
        logger.debug(f"๐Ÿ“‹ Camera {camera_id}: Existing counters: {dict(track_counters)}")
        logger.debug(f"๐Ÿ“‹ Camera {camera_id}: Stable tracks: {list(stable_tracks)}")

        # IMPORTANT: Counters are only modified in the two modes accepted by
        # is_validation_mode above (validation_detecting and full_pipeline).
        # In any other mode, report the current state without changing it.
        if not is_validation_mode:
            logger.debug(f"๐Ÿšซ Camera {camera_id}: NOT in validation mode - skipping counter modifications")
            return ValidationResult(
                validation_complete=False,
                stable_tracks=list(stable_tracks),
                current_tracks=[current_track_id] if current_track_id is not None else []
            )

        if current_track_id is not None:
            # Check if this is a different track than we were tracking
            previous_track_ids = list(track_counters.keys())

            # VALIDATION MODE: Reset counter if different track OR if track was previously stable
            should_reset = (
                len(previous_track_ids) == 0 or  # No previous tracking
                current_track_id not in previous_track_ids or  # Different track ID
                current_track_id in stable_tracks  # Track was stable - start fresh validation
            )

            logger.debug(f"๐Ÿ“‹ Camera {camera_id}: Previous track_ids: {previous_track_ids}")
            logger.debug(f"๐Ÿ“‹ Camera {camera_id}: Track {current_track_id} was stable: {current_track_id in stable_tracks}")
            logger.debug(f"๐Ÿ“‹ Camera {camera_id}: Should reset counters: {should_reset}")

            if should_reset:
                # Clear all previous tracking - fresh validation needed.
                # Every previously counted id is dropped from both containers.
                if previous_track_ids:
                    for old_track_id in previous_track_ids:
                        old_count = track_counters.pop(old_track_id, 0)
                        stable_tracks.discard(old_track_id)
                        logger.info(f"๐Ÿ”„ Camera {camera_id}: VALIDATION RESET - track {old_track_id} counter from {old_count} to 0 (reason: {'stable_track_restart' if current_track_id == old_track_id else 'different_track'})")

                # Start fresh validation for this track
                old_count = track_counters.get(current_track_id, 0)  # Store old count for logging (0 here: all counters were popped above)
                track_counters[current_track_id] = 1
                current_count = 1
                logger.info(f"๐Ÿ†• Camera {camera_id}: FRESH VALIDATION - Track {current_track_id} starting at 1/{stability_threshold}")
            else:
                # Continue validation for same track
                old_count = track_counters.get(current_track_id, 0)
                track_counters[current_track_id] = old_count + 1
                current_count = track_counters[current_track_id]

            logger.debug(f"๐Ÿ”ข Camera {camera_id}: Track {current_track_id} counter: {old_count} โ†’ {current_count}")
            logger.info(f"๐Ÿ” Camera {camera_id}: Track ID {current_track_id} validation {current_count}/{stability_threshold}")

            # Check if track has reached stability threshold
            logger.debug(f"๐Ÿ“Š Camera {camera_id}: Checking stability: {current_count} >= {stability_threshold}? {current_count >= stability_threshold}")
            logger.debug(f"๐Ÿ“Š Camera {camera_id}: Already stable: {current_track_id in stable_tracks}")

            if current_count >= stability_threshold and current_track_id not in stable_tracks:
                stable_tracks.add(current_track_id)
                logger.info(f"โœ… Camera {camera_id}: Track ID {current_track_id} STABLE after {current_count} consecutive frames")
                logger.info(f"๐ŸŽฏ Camera {camera_id}: TRACK VALIDATION COMPLETE")
                logger.debug(f"๐ŸŽฏ Camera {camera_id}: Stable tracks now: {list(stable_tracks)}")
                # Only the newly validated track is reported, not the whole set.
                return ValidationResult(
                    validation_complete=True,
                    send_none_detection=True,
                    stable_tracks=[current_track_id],
                    newly_stable_tracks=[current_track_id],
                    current_tracks=[current_track_id]
                )
            elif current_count >= stability_threshold:
                # NOTE(review): this branch looks unreachable - an id that is
                # already in stable_tracks forces should_reset above, which
                # removes it from stable_tracks before this check. Confirm.
                logger.debug(f"๐Ÿ“Š Camera {camera_id}: Track {current_track_id} already stable - not re-adding")
        else:
            # No car detected - ALWAYS clear all tracking and reset counters
            logger.debug(f"๐Ÿšซ Camera {camera_id}: NO CAR DETECTED - clearing all tracking")
            if track_counters or stable_tracks:
                logger.debug(f"๐Ÿšซ Camera {camera_id}: Existing state before reset: counters={dict(track_counters)}, stable={list(stable_tracks)}")
                # Pop one by one so each discarded counter gets its own log line.
                for track_id in list(track_counters.keys()):
                    old_count = track_counters.pop(track_id, 0)
                    logger.info(f"๐Ÿ”„ Camera {camera_id}: No car detected - RESET track {track_id} counter from {old_count} to 0")
                track_counters.clear()  # Ensure complete reset
                stable_tracks.clear()  # Clear all stable tracks
                logger.info(f"โœ… Camera {camera_id}: RESET TO VALIDATION PHASE - All counters and stable tracks cleared")
            else:
                logger.debug(f"๐Ÿšซ Camera {camera_id}: No existing counters to clear")
            logger.debug(f"Camera {camera_id}: VALIDATION - no car detected (all counters reset)")

        # Final return - validation not complete
        result = ValidationResult(
            validation_complete=False,
            stable_tracks=list(stable_tracks),
            current_tracks=[current_track_id] if current_track_id is not None else []
        )

        logger.debug(f"๐Ÿ“‹ Camera {camera_id}: Track stability result: {result.to_dict()}")
        logger.debug(f"๐Ÿ“‹ Camera {camera_id}: Final counters: {dict(track_counters)}")
        logger.debug(f"๐Ÿ“‹ Camera {camera_id}: Final stable tracks: {list(stable_tracks)}")

        return result
def _reset_yolo_tracker_if_needed(self, node: Dict[str, Any], camera_id: str, model_id: str) -> None:
    """
    Reset the YOLO tracker if the stability data is flagged for a reset.

    The flag ``session_state["reset_tracker_on_resume"]`` is read from the
    stability data of the camera/model pair. When it is set, the tracker state
    held by ``node["model"]`` is reset and the flag is cleared. The flag is
    cleared even when no tracker could be found, so the reset is attempted at
    most once per request.

    Two tracker locations are supported:

    * ``model.trackers`` - a collection directly on the model, emptied with
      ``clear()``.
    * ``model.predictor.trackers`` - each tracker exposing a callable
      ``reset()`` is reset in place.

    Args:
        node: Pipeline node configuration; must contain the model under "model"
        camera_id: Unique camera identifier
        model_id: Model identifier
    """
    stability_data = self.stability_tracker.get_camera_stability_data(camera_id, model_id)
    session_state = stability_data.session_state

    if not session_state.get("reset_tracker_on_resume", False):
        return

    model = node["model"]
    tracker_was_reset = False

    direct_trackers = getattr(model, "trackers", None)
    if direct_trackers:
        # Tracker collection exposed directly on the model: clear tracker state.
        direct_trackers.clear()
        tracker_was_reset = True
    else:
        # NOTE(review): assumes node["model"] is an Ultralytics model, which
        # keeps its tracker list on the predictor after the first track() call.
        # The trackers are reset in place rather than removed, because the list
        # is presumably indexed again on the next persist=True call - confirm
        # against the installed Ultralytics version.
        predictor = getattr(model, "predictor", None)
        for tracker in getattr(predictor, "trackers", None) or ():
            reset = getattr(tracker, "reset", None)
            if callable(reset):
                reset()
                tracker_was_reset = True

    if tracker_was_reset:
        logger.info(f"Camera {camera_id}: ๐Ÿ”„ Reset YOLO tracker - new cars will get fresh track IDs")

    session_state["reset_tracker_on_resume"] = False  # Clear the flag
def _process_detections(self,
                        res: Any,
                        node: Dict[str, Any],
                        camera_id: str,
                        min_confidence: float) -> List[Dict[str, Any]]:
    """
    Process YOLO detection results into candidate detections.

    Every box in ``res.boxes`` is converted to a plain dictionary and kept only
    if its confidence is at least ``min_confidence``.

    Args:
        res: Inference result exposing ``boxes``; each box must provide
            ``cpu()`` returning an object with ``conf``, ``cls`` and ``xyxy``,
            and may provide ``id`` (the tracker-assigned id)
        node: Pipeline node configuration; ``node["model"].names`` maps class
            indices to class names
        camera_id: Unique camera identifier (used for logging only)
        min_confidence: Minimum confidence a detection needs to be kept

    Returns:
        List of detections with keys "class", "confidence", "id", "bbox"
        (an ``(x1, y1, x2, y2)`` tuple of ints) and "class_id"; empty when
        there are no boxes or none passes the confidence filter
    """
    candidate_detections: List[Dict[str, Any]] = []

    if res.boxes is None or len(res.boxes) == 0:
        logger.debug(f"๐Ÿšซ Camera {camera_id}: YOLO returned no detections")
        return candidate_detections

    logger.debug(f"๐Ÿ” Camera {camera_id}: YOLO detected {len(res.boxes)} raw objects - processing with tracking...")
    logger.debug(f"๐Ÿ” Camera {camera_id}: === DETECTION ANALYSIS ===")

    class_names = node["model"].names

    for i, box in enumerate(res.boxes):
        # Convert the box to its CPU form once and reuse it for all fields,
        # instead of calling cpu() separately for conf, cls and xyxy.
        cpu_box = box.cpu()

        # Extract detection data
        conf = float(cpu_box.conf[0])
        cls_id = int(cpu_box.cls[0])
        class_name = class_names[cls_id]

        # Extract bounding box
        x1, y1, x2, y2 = map(int, cpu_box.xyxy[0])
        bbox = (x1, y1, x2, y2)

        # Extract tracking ID if available (absent or None without tracking)
        track_id = None
        if hasattr(box, "id") and box.id is not None:
            track_id = int(box.id.item())

        logger.debug(f"๐Ÿ” Camera {camera_id}: Detection {i+1}: class='{class_name}' conf={conf:.3f} track_id={track_id} bbox={bbox}")

        # Apply confidence filtering
        if conf < min_confidence:
            logger.debug(f"โŒ Camera {camera_id}: Detection {i+1} REJECTED - confidence {conf:.3f} < {min_confidence}")
            continue

        # Create detection object
        detection = {
            "class": class_name,
            "confidence": conf,
            "id": track_id,
            "bbox": bbox,
            "class_id": cls_id
        }

        candidate_detections.append(detection)
        logger.debug(f"โœ… Camera {camera_id}: Detection {i+1} ACCEPTED as candidate: {class_name} (conf={conf:.3f}, track_id={track_id})")

    return candidate_detections
def run_detection_with_tracking(self,
                                frame,
                                node: Dict[str, Any],
                                context: Optional[Dict[str, Any]] = None) -> Tuple[List[Dict[str, Any]], Dict[str, Any], Dict[str, Any]]:
    """
    Run YOLO detection with BoT-SORT tracking and stability validation.

    Steps: extract the tracking configuration and confidence threshold from
    the node, reset the tracker if flagged, run inference, filter the boxes,
    pick the highest-confidence detection, update track stability for it,
    apply class mapping and finally run multi-class validation.

    Whenever ``context`` is not None, the validation result is stored in
    ``context["track_validation_result"]`` on every return path, including the
    multi-class validation failure path.

    Args:
        frame: Input frame/image
        node: Pipeline node configuration with model and settings
        context: Optional context information (camera info, session data, etc.)

    Returns:
        tuple: (all_detections, regions_dict, track_validation_result) where:
            - all_detections: List holding the selected detection, or empty
            - regions_dict: Dict mapping the (mapped) class name to the selected
              detection's bbox, confidence, detection and track_id, or empty
            - track_validation_result: Dict with validation status and stable tracks

    Raises:
        The exception built by ``create_detection_error`` for any failure
        inside the detection flow; the original exception is attached as its cause.
    """
    try:
        camera_id = context.get("camera_id", "unknown") if context else "unknown"
        model_id = node.get("modelId", "unknown")
        current_mode = context.get("current_mode", "unknown") if context else "unknown"

        # Extract configuration
        tracking_config = self._extract_tracking_config(node)
        min_confidence = self._determine_confidence_threshold(node)

        # Check if we need to reset tracker after cooldown
        self._reset_yolo_tracker_if_needed(node, camera_id, model_id)

        # Run YOLO inference
        res = self._run_yolo_inference(frame, node, tracking_config)

        # Process detection results
        candidate_detections = self._process_detections(res, node, camera_id, min_confidence)

        # Select best detection
        best_detection = self._select_best_detection(candidate_detections, camera_id)

        # Update track stability validation (also called when nothing was
        # detected, passing detection=None)
        is_branch_node = node.get("cropClass") is not None or node.get("parallel") is True
        track_validation_result = self.stability_tracker.update_single_track_stability(
            detection=best_detection,
            camera_id=camera_id,
            model_id=model_id,
            stability_threshold=tracking_config.stability_threshold,
            current_mode=current_mode,
            is_branch_node=is_branch_node
        )

        # Handle no detection case
        if best_detection is None:
            # Store validation state in context for pipeline decisions
            if context is not None:
                context["track_validation_result"] = track_validation_result.to_dict()

            return [], {}, track_validation_result.to_dict()

        # Apply class mapping
        best_detection = self._apply_class_mapping(best_detection, node)

        # Create regions dictionary
        mapped_class = best_detection["class"]
        track_id = best_detection["id"]

        all_detections = [best_detection]
        regions_dict = {
            mapped_class: {
                "bbox": best_detection["bbox"],
                "confidence": best_detection["confidence"],
                "detection": best_detection,
                "track_id": track_id
            }
        }

        # Multi-class validation
        if not self._validate_multi_class(regions_dict, node):
            # Store validation state here as well, so the context is updated
            # on this return path just like on the other two.
            if context is not None:
                context["track_validation_result"] = track_validation_result.to_dict()

            return [], {}, track_validation_result.to_dict()

        logger.info(f"โœ… Camera {camera_id}: DETECTION COMPLETE - tracking single car: track_id={track_id}, conf={best_detection['confidence']:.3f}")
        logger.debug(f"๐Ÿ“Š Camera {camera_id}: Detection summary: {len(res.boxes)} raw โ†’ {len(candidate_detections)} candidates โ†’ 1 selected")

        # Store validation state in context for pipeline decisions
        if context is not None:
            context["track_validation_result"] = track_validation_result.to_dict()

        return all_detections, regions_dict, track_validation_result.to_dict()

    except Exception as e:
        # Recomputed here because the failure may have happened before the
        # names were bound inside the try block.
        camera_id = context.get("camera_id", "unknown") if context else "unknown"
        model_id = node.get("modelId", "unknown")
        # Chain explicitly so the original failure is kept as the cause.
        raise create_detection_error(camera_id, model_id, "detection_with_tracking", e) from e
def update_single_track_stability(node: Dict[str, Any],
                                  detection: Optional[Dict[str, Any]],
                                  camera_id: str,
                                  frame_shape=None,
                                  stability_threshold: int = DEFAULT_STABILITY_THRESHOLD,
                                  context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Update track stability for one detection using the global stability tracker.

    Derives the model id, the current pipeline mode and the branch-node flag
    from ``node`` and ``context``, delegates to the module-level
    ``_global_stability_tracker`` and returns its result as a dictionary.

    Args:
        node: Pipeline node configuration ("modelId", "cropClass", "parallel" are read)
        detection: Highest confidence detection, or None if nothing was detected
        camera_id: Unique camera identifier
        frame_shape: Accepted but not used by this function
        stability_threshold: Number of consecutive frames needed for stability
        context: Optional context; "current_mode" is read when present

    Returns:
        The validation result converted with ``to_dict()``
    """
    # Fall back to "unknown" when no (or an empty) context is supplied.
    if context:
        mode = context.get("current_mode", "unknown")
    else:
        mode = "unknown"

    # A node counts as a branch when it crops a class or is marked parallel.
    branch_flag = node.get("cropClass") is not None or node.get("parallel") is True

    validation = _global_stability_tracker.update_single_track_stability(
        detection=detection,
        camera_id=camera_id,
        model_id=node.get("modelId", "unknown"),
        stability_threshold=stability_threshold,
        current_mode=mode,
        is_branch_node=branch_flag
    )
    return validation.to_dict()
@dataclass
class PipelineContext:
    """Per-execution context values handed through the pipeline."""
    # Camera the frame came from; "unknown" when not supplied
    camera_id: str = "unknown"
    # Session id provided by the backend, if any
    backend_session_id: Optional[str] = None
    display_id: Optional[str] = None
    # Current pipeline mode; "unknown" when not supplied
    current_mode: str = "unknown"
    # Detection regions keyed by class name, if already available
    regions_dict: Optional[Dict[str, Any]] = None
    session_id: Optional[str] = None
    timestamp: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Export the context as a plain dictionary.

        Values are passed through as-is (no copies), so ``regions_dict`` in the
        result is the same object held by this instance.

        Returns:
            Dictionary with one entry per field, in declaration order
        """
        exported_names = (
            "camera_id",
            "backend_session_id",
            "display_id",
            "current_mode",
            "regions_dict",
            "session_id",
            "timestamp",
        )
        return {name: getattr(self, name) for name in exported_names}
@dataclass
class PipelineResult:
    """Outcome of one pipeline execution."""
    success: bool
    # Main detection produced by the pipeline, if any
    primary_detection: Optional[Dict[str, Any]] = None
    # Bounding box belonging to the main detection, if any
    primary_bbox: Optional[List[int]] = None
    branch_results: Dict[str, Any] = field(default_factory=dict)
    session_id: Optional[str] = None
    awaiting_session_id: bool = False
    awaiting_stable_tracks: bool = False

    def to_tuple(self, return_bbox: bool = False) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]], Tuple[None, None]]:
        """
        Convert to the return format expected by the original function.

        Args:
            return_bbox: When True a ``(detection, bbox)`` pair is returned,
                otherwise only the detection

        Returns:
            - unsuccessful run or missing/empty detection: ``None``, or
              ``(None, None)`` when ``return_bbox`` is True
            - otherwise: the primary detection, or ``(detection, bbox)`` where a
              missing/empty bbox is replaced by ``[0, 0, 0, 0]``
        """
        has_payload = bool(self.success and self.primary_detection)

        if not return_bbox:
            return self.primary_detection if has_payload else None

        if not has_payload:
            return (None, None)

        # A missing or empty bbox falls back to an all-zero box.
        bbox = self.primary_bbox if self.primary_bbox else [0, 0, 0, 0]
        return (self.primary_detection, bbox)
def _extract_context(self, context: Optional[Dict[str, Any]]) -> PipelineContext:
    """
    Build a PipelineContext from an optional context dictionary.

    Args:
        context: Raw context dictionary; None or an empty dict yields a
            PipelineContext with all defaults

    Returns:
        PipelineContext populated from the dictionary; "camera_id" and
        "current_mode" fall back to "unknown", all other keys to None
    """
    if not context:
        return PipelineContext()

    lookup = context.get
    return PipelineContext(
        camera_id=lookup("camera_id", "unknown"),
        backend_session_id=lookup("backend_session_id"),
        display_id=lookup("display_id"),
        current_mode=lookup("current_mode", "unknown"),
        regions_dict=lookup("regions_dict"),
        session_id=lookup("session_id"),
        timestamp=lookup("timestamp")
    )
def _check_camera_active(self, camera_id: str, model_id: str, return_bbox: bool) -> Optional[Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]]]]:
    """
    Decide whether processing may continue for this camera/model pair.

    Args:
        camera_id: Unique camera identifier
        model_id: Model identifier
        return_bbox: Whether the caller expects a ``(detection, bbox)`` pair

    Returns:
        None when ``is_camera_active`` reports the camera as active. Otherwise
        a placeholder "none" detection (confidence 1.0, zero bbox, empty
        branch results), paired with ``[0, 0, 0, 0]`` when ``return_bbox`` is True.
    """
    if is_camera_active(camera_id, model_id):
        return None

    logger.debug(f"โฐ Camera {camera_id}: Waiting for backend sessionId, sending 'none' detection")
    placeholder = {
        "class": "none",
        "confidence": 1.0,
        "bbox": [0, 0, 0, 0],
        "branch_results": {}
    }
    if return_bbox:
        return (placeholder, [0, 0, 0, 0])
    return placeholder
track_validation_result = run_detection_with_tracking( + frame, node, pipeline_context.to_dict() + ) + + return all_detections, regions_dict, track_validation_result + + def _validate_tracking_requirements(self, + node: Dict[str, Any], + track_validation_result: Dict[str, Any], + pipeline_context: PipelineContext, + return_bbox: bool) -> Optional[Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]]]]: + """Validate tracking requirements for pipeline execution.""" + tracking_config = node.get("tracking", {}) + stability_threshold = tracking_config.get("stabilityThreshold", node.get("stabilityThreshold", 1)) + + if stability_threshold <= 1 or not tracking_config.get("enabled", True): + return None # No tracking requirements + + # Check if this is a branch node - branches should execute regardless of main validation state + is_branch_node = node.get("cropClass") is not None or node.get("parallel") is True + + if is_branch_node: + logger.debug(f"๐Ÿ” Camera {pipeline_context.camera_id}: Branch node {node.get('modelId')} executing during track validation phase") + return None + + # Main pipeline node during track validation - check for stable tracks + stable_tracks = track_validation_result.get("stable_tracks", []) + + if not stable_tracks: + logger.debug(f"๐Ÿ”’ Camera {pipeline_context.camera_id}: Main pipeline requires stable tracks - none found, skipping pipeline execution") + none_detection = { + "class": "none", + "confidence": 1.0, + "bbox": [0, 0, 0, 0], + "awaiting_stable_tracks": True + } + return (none_detection, [0, 0, 0, 0]) if return_bbox else none_detection + + logger.info(f"๐ŸŽฏ Camera {pipeline_context.camera_id}: STABLE TRACKS DETECTED - proceeding with full pipeline (tracks: {stable_tracks})") + return None + + def _handle_database_operations(self, + node: Dict[str, Any], + detection_result: Dict[str, Any], + regions_dict: Dict[str, Any], + pipeline_context: PipelineContext) -> None: + """Handle database operations if database manager is 
available.""" + if not (node.get("db_manager") and regions_dict): + return + + detected_classes = list(regions_dict.keys()) + logger.debug(f"Valid detections found: {detected_classes}") + + if pipeline_context.backend_session_id: + # Backend sessionId is available, proceed with database operations + display_id = pipeline_context.display_id or "unknown" + timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + + inserted_session_id = node["db_manager"].insert_initial_detection( + display_id=display_id, + captured_timestamp=timestamp, + session_id=pipeline_context.backend_session_id + ) + + if inserted_session_id: + detection_result["session_id"] = inserted_session_id + detection_result["timestamp"] = timestamp + logger.info(f"๐Ÿ’พ DATABASE RECORD CREATED with backend session_id: {inserted_session_id}") + logger.debug(f"Database record: display_id={display_id}, timestamp={timestamp}") + else: + logger.error(f"Failed to create database record with backend session_id: {pipeline_context.backend_session_id}") + else: + logger.info(f"๐Ÿ“ก Camera {pipeline_context.camera_id}: Full pipeline completed, detection data will be sent to backend. 
Database operations will occur when sessionId is received.") + # Store detection info for later database operations when sessionId arrives + detection_result["awaiting_session_id"] = True + detection_result["timestamp"] = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + + def _execute_node_actions(self, + node: Dict[str, Any], + frame: np.ndarray, + detection_result: Dict[str, Any], + regions_dict: Optional[Dict[str, Any]]) -> None: + """Execute actions for a node.""" + # This is a placeholder for action execution + # In the actual implementation, this would import and call execute_actions + pass + + def _execute_parallel_actions(self, + node: Dict[str, Any], + frame: np.ndarray, + detection_result: Dict[str, Any], + regions_dict: Dict[str, Any]) -> None: + """Execute parallel actions after branch completion.""" + # This is a placeholder for parallel action execution + # In the actual implementation, this would import and call execute_parallel_actions + pass + + def _crop_region_by_class(self, + frame: np.ndarray, + regions_dict: Dict[str, Any], + class_name: str) -> Optional[np.ndarray]: + """Crop a specific region from frame based on detected class.""" + if class_name not in regions_dict: + logger.warning(f"Class '{class_name}' not found in detected regions") + return None + + bbox = regions_dict[class_name]["bbox"] + x1, y1, x2, y2 = bbox + + # Validate bbox coordinates + if x2 <= x1 or y2 <= y1: + logger.warning(f"Invalid bbox for class {class_name}: {bbox}") + return None + + try: + cropped = frame[y1:y2, x1:x2] + if cropped.size == 0: + logger.warning(f"Empty crop for class {class_name}") + return None + return cropped + except Exception as e: + logger.error(f"Error cropping region for class {class_name}: {e}") + return None + + def _prepare_branch_context(self, + base_context: PipelineContext, + regions_dict: Dict[str, Any], + detection_result: Dict[str, Any]) -> Dict[str, Any]: + """Prepare context for branch execution.""" + branch_context = 
base_context.to_dict() + branch_context["regions_dict"] = regions_dict + + # Pass session_id from detection_result to branch context for Redis actions + if "session_id" in detection_result: + branch_context["session_id"] = detection_result["session_id"] + logger.debug(f"Added session_id to branch context: {detection_result['session_id']}") + elif base_context.backend_session_id: + branch_context["session_id"] = base_context.backend_session_id + logger.debug(f"Added backend_session_id to branch context: {base_context.backend_session_id}") + + return branch_context + + def _execute_single_branch(self, + frame: np.ndarray, + branch: Dict[str, Any], + branch_context: Dict[str, Any], + regions_dict: Dict[str, Any]) -> BranchResult: + """Execute a single branch.""" + model_id = branch["modelId"] + + try: + sub_frame = frame + crop_class = branch.get("cropClass") + + logger.info(f"Starting branch: {model_id}, cropClass: {crop_class}") + + # Handle cropping if required + if branch.get("crop", False) and crop_class: + if crop_class in regions_dict: + cropped = self._crop_region_by_class(frame, regions_dict, crop_class) + if cropped is not None: + sub_frame = cropped # Use cropped image without manual resizing + logger.debug(f"Successfully cropped {crop_class} region for {model_id} - model will handle resizing") + else: + return BranchResult( + model_id=model_id, + success=False, + error=f"Failed to crop {crop_class} region" + ) + else: + return BranchResult( + model_id=model_id, + success=False, + error=f"Crop class {crop_class} not found in detected regions" + ) + + # Execute branch pipeline + result, _ = self.run_pipeline(sub_frame, branch, True, branch_context) + + if result: + branch_result = BranchResult( + model_id=model_id, + success=True, + result=result + ) + + # Collect nested branch results if they exist + if "branch_results" in result: + branch_result.nested_results = result["branch_results"] + + logger.info(f"Branch {model_id} completed: {result}") + return 
branch_result + else: + return BranchResult( + model_id=model_id, + success=False, + error="Branch returned no result" + ) + + except Exception as e: + logger.error(f"Error in branch {model_id}: {e}") + return BranchResult( + model_id=model_id, + success=False, + error=str(e) + ) + + def _process_branches_parallel(self, + frame: np.ndarray, + active_branches: List[Dict[str, Any]], + branch_context: Dict[str, Any], + regions_dict: Dict[str, Any]) -> Dict[str, Any]: + """Process branches in parallel.""" + branch_results = {} + + with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_branches)) as executor: + futures = {} + + for branch in active_branches: + future = executor.submit( + self._execute_single_branch, + frame, branch, branch_context, regions_dict + ) + futures[future] = branch + + # Collect results + for future in concurrent.futures.as_completed(futures): + branch = futures[future] + try: + result = future.result() + if result.success and result.result: + branch_results[result.model_id] = result.result + + # Collect nested branch results + for nested_id, nested_result in result.nested_results.items(): + branch_results[nested_id] = nested_result + logger.info(f"Collected nested branch result: {nested_id} = {nested_result}") + else: + logger.error(f"Branch {result.model_id} failed: {result.error}") + except Exception as e: + logger.error(f"Branch {branch['modelId']} failed: {e}") + + return branch_results + + def _process_branches_sequential(self, + frame: np.ndarray, + active_branches: List[Dict[str, Any]], + branch_context: Dict[str, Any], + regions_dict: Dict[str, Any]) -> Dict[str, Any]: + """Process branches sequentially.""" + branch_results = {} + + for branch in active_branches: + result = self._execute_single_branch(frame, branch, branch_context, regions_dict) + + if result.success and result.result: + branch_results[result.model_id] = result.result + + # Collect nested branch results + for nested_id, nested_result in 
result.nested_results.items(): + branch_results[nested_id] = nested_result + logger.info(f"Collected nested branch result: {nested_id} = {nested_result}") + else: + logger.error(f"Branch {result.model_id} failed: {result.error}") + + return branch_results + + def _filter_active_branches(self, + node: Dict[str, Any], + regions_dict: Dict[str, Any]) -> List[Dict[str, Any]]: + """Filter branches that should be triggered based on detected regions.""" + active_branches = [] + + for br in node["branches"]: + trigger_classes = br.get("triggerClasses", []) + min_conf = br.get("minConfidence", 0) + + logger.debug(f"Evaluating branch {br['modelId']}: trigger_classes={trigger_classes}, min_conf={min_conf}") + + # Check if any detected class matches branch trigger + branch_triggered = False + for det_class in regions_dict: + det_confidence = regions_dict[det_class]["confidence"] + logger.debug(f" Checking detected class '{det_class}' (confidence={det_confidence:.3f}) against triggers {trigger_classes}") + + if (det_class in trigger_classes and det_confidence >= min_conf): + active_branches.append(br) + branch_triggered = True + logger.info(f"Branch {br['modelId']} activated by class '{det_class}' (conf={det_confidence:.3f} >= {min_conf})") + break + + if not branch_triggered: + logger.debug(f"Branch {br['modelId']} not triggered - no matching classes or insufficient confidence") + + return active_branches + + def _process_branches(self, + frame: np.ndarray, + node: Dict[str, Any], + detection_result: Dict[str, Any], + regions_dict: Dict[str, Any], + pipeline_context: PipelineContext) -> Dict[str, Any]: + """Process all branches for a node.""" + if not node.get("branches"): + return {} + + # Filter branches that should be triggered + active_branches = self._filter_active_branches(node, regions_dict) + + if not active_branches: + return {} + + # Prepare branch context + branch_context = self._prepare_branch_context(pipeline_context, regions_dict, detection_result) + + # Execute 
branches + if node.get("parallel", False) or any(br.get("parallel", False) for br in active_branches): + # Run branches in parallel + branch_results = self._process_branches_parallel(frame, active_branches, branch_context, regions_dict) + else: + # Run branches sequentially + branch_results = self._process_branches_sequential(frame, active_branches, branch_context, regions_dict) + + return branch_results + + def run_pipeline(self, + frame: np.ndarray, + node: Dict[str, Any], + return_bbox: bool = False, + context: Optional[Dict[str, Any]] = None, + validated_detection: Optional[Dict[str, Any]] = None) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]], Tuple[None, None]]: + """ + Run enhanced pipeline that supports: + - Multi-class detection (detecting multiple classes simultaneously) + - Parallel branch processing + - Region-based actions and cropping + - Context passing for session/camera information + + Args: + frame: Input frame for processing + node: Pipeline node configuration + return_bbox: Whether to return bounding box with result + context: Optional context information + validated_detection: Optional pre-validated detection to use + + Returns: + Detection result, optionally with bounding box + """ + try: + # Extract context information + pipeline_context = self._extract_context(context) + model_id = node.get("modelId", "unknown") + + if pipeline_context.backend_session_id: + logger.info(f"๐Ÿ”‘ PIPELINE USING BACKEND SESSION_ID: {pipeline_context.backend_session_id} for camera {pipeline_context.camera_id}") + + task = getattr(node["model"], "task", None) + + # โ”€โ”€โ”€ Classification stage โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + if task == "classify": + return self._handle_classification_task(frame, node, pipeline_context, return_bbox) + + # โ”€โ”€โ”€ Session management check 
โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + camera_inactive_result = self._check_camera_active(pipeline_context.camera_id, model_id, return_bbox) + if camera_inactive_result is not None: + return camera_inactive_result + + # โ”€โ”€โ”€ Detection stage โ”€โ”€โ”€ + all_detections, regions_dict, track_validation_result = self._run_detection_stage( + frame, node, pipeline_context, validated_detection + ) + + if not all_detections: + logger.debug("No detections from structured detection function - sending 'none' detection") + none_detection = { + "class": "none", + "confidence": 1.0, + "bbox": [0, 0, 0, 0], + "branch_results": {} + } + return (none_detection, [0, 0, 0, 0]) if return_bbox else none_detection + + # โ”€โ”€โ”€ Track-Based Validation System โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + tracking_validation_result = self._validate_tracking_requirements( + node, track_validation_result, pipeline_context, return_bbox + ) + if tracking_validation_result is not None: + return tracking_validation_result + + # โ”€โ”€โ”€ Pre-validate pipeline execution โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + pipeline_valid, missing_branches = validate_pipeline_execution(node, regions_dict) + + if not pipeline_valid: + logger.error(f"Pipeline execution validation FAILED - required branches {missing_branches} cannot execute") + logger.error("Aborting pipeline: no Redis actions or database records will be created") + return (None, None) if return_bbox else None + + # โ”€โ”€โ”€ Execute actions with region information โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + detection_result = { + "detections": all_detections, + "regions": regions_dict, + **pipeline_context.to_dict() + } + + # โ”€โ”€โ”€ Database operations โ”€โ”€โ”€โ”€ + self._handle_database_operations(node, detection_result, regions_dict, pipeline_context) + + # Execute actions for root 
node only if it doesn't have branches + # Branch nodes with actions will execute them after branch processing + if not node.get("branches") or node.get("modelId") == "yolo11n": + self._execute_node_actions(node, frame, detection_result, regions_dict) + + # โ”€โ”€โ”€ Branch processing โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + branch_results = self._process_branches(frame, node, detection_result, regions_dict, pipeline_context) + detection_result["branch_results"] = branch_results + + # โ”€โ”€โ”€ Execute Parallel Actions โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + if node.get("parallelActions") and "branch_results" in detection_result: + self._execute_parallel_actions(node, frame, detection_result, regions_dict) + + # โ”€โ”€โ”€ Auto-enable occupancy mode after successful pipeline completion โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + occupancy_detector(pipeline_context.camera_id, model_id, enable=True) + + logger.info(f"โœ… Camera {pipeline_context.camera_id}: Pipeline completed, detection data will be sent to backend") + logger.info(f"๐Ÿ›‘ Camera {pipeline_context.camera_id}: Model will stop inference for future frames") + logger.info(f"๐Ÿ“ก Backend sessionId will be handled when received via WebSocket") + + # โ”€โ”€โ”€ Execute actions after successful detection AND branch processing โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + # This ensures detection nodes (like frontal_detection_v1) execute their actions + # after completing both detection and branch processing + if node.get("actions") and regions_dict and node.get("modelId") != "yolo11n": + # Execute actions for branch detection nodes, skip root to avoid duplication + logger.debug(f"Executing post-detection actions for branch node {node.get('modelId')}") + self._execute_node_actions(node, frame, detection_result, regions_dict) + + # โ”€โ”€โ”€ Return detection result 
โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + primary_detection = max(all_detections, key=lambda x: x["confidence"]) + primary_bbox = primary_detection["bbox"] + + # Add branch results and session_id to primary detection for compatibility + if "branch_results" in detection_result: + primary_detection["branch_results"] = detection_result["branch_results"] + if "session_id" in detection_result: + primary_detection["session_id"] = detection_result["session_id"] + + return (primary_detection, primary_bbox) if return_bbox else primary_detection + + except Exception as e: + pipeline_id = node.get("modelId", "unknown") + raise create_pipeline_error(pipeline_id, "pipeline_execution", e) + + +# Global pipeline executor instance +pipeline_executor = PipelineExecutor() + + +# ===== CONVENIENCE FUNCTIONS ===== +# These provide the same interface as the original functions in pympta.py + +def run_pipeline(frame: np.ndarray, + node: Dict[str, Any], + return_bbox: bool = False, + context: Optional[Dict[str, Any]] = None, + validated_detection: Optional[Dict[str, Any]] = None) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]], Tuple[None, None]]: + """Run enhanced pipeline using global executor instance.""" + return pipeline_executor.run_pipeline(frame, node, return_bbox, context, validated_detection) + + +def crop_region_by_class(frame: np.ndarray, regions_dict: Dict[str, Any], class_name: str) -> Optional[np.ndarray]: + """Crop a specific region from frame based on detected class.""" + return pipeline_executor._crop_region_by_class(frame, regions_dict, class_name) \ No newline at end of file diff --git a/detector_worker/streams/camera_monitor.py b/detector_worker/streams/camera_monitor.py new file mode 100644 index 0000000..31432f6 --- /dev/null +++ b/detector_worker/streams/camera_monitor.py @@ -0,0 +1,345 @@ +""" +Camera connection state monitoring and management. 
+ +This module provides centralized tracking of camera connection states, +error handling, and disconnection notifications. +""" + +import time +import logging +from typing import Dict, Any, Optional +from dataclasses import dataclass, field +from datetime import datetime + +from ..core.exceptions import CameraConnectionError + +logger = logging.getLogger(__name__) + + +@dataclass +class CameraConnectionState: + """Represents the connection state of a single camera.""" + connected: bool = True + last_error: Optional[str] = None + last_error_time: Optional[float] = None + consecutive_failures: int = 0 + disconnection_notified: bool = False + last_successful_frame: Optional[float] = None + connection_start_time: Optional[float] = field(default_factory=time.time) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + return { + "connected": self.connected, + "last_error": self.last_error, + "last_error_time": self.last_error_time, + "consecutive_failures": self.consecutive_failures, + "disconnection_notified": self.disconnection_notified, + "last_successful_frame": self.last_successful_frame, + "connection_start_time": self.connection_start_time + } + + +class CameraMonitor: + """ + Monitors camera connection states and handles disconnection notifications. + + This class provides a centralized way to track camera connection status, + handle error states, and determine when to notify about disconnections. + """ + + def __init__(self, failure_threshold: int = 3): + """ + Initialize camera monitor. 
+ + Args: + failure_threshold: Number of consecutive failures before considering disconnected + """ + self.failure_threshold = failure_threshold + self._camera_states: Dict[str, CameraConnectionState] = {} + self._lock = None # Will be initialized when needed for thread safety + + def _ensure_thread_safety(self): + """Initialize thread safety if not already done.""" + if self._lock is None: + import threading + self._lock = threading.RLock() + + def get_camera_state(self, camera_id: str) -> CameraConnectionState: + """ + Get or create camera connection state. + + Args: + camera_id: Unique camera identifier + + Returns: + Current connection state for the camera + """ + self._ensure_thread_safety() + + with self._lock: + if camera_id not in self._camera_states: + self._camera_states[camera_id] = CameraConnectionState() + logger.debug(f"Initialized connection state for camera {camera_id}") + + return self._camera_states[camera_id] + + def set_camera_connected(self, + camera_id: str, + connected: bool = True, + error_msg: Optional[str] = None) -> None: + """ + Set camera connection state and track error information. 
+ + Args: + camera_id: Unique camera identifier + connected: Whether camera is connected + error_msg: Error message if disconnected + """ + self._ensure_thread_safety() + + current_time = time.time() + + with self._lock: + state = self.get_camera_state(camera_id) + + if connected: + # Camera is now connected + if not state.connected: + logger.info(f"Camera {camera_id} reconnected successfully") + + state.connected = True + state.last_successful_frame = current_time + state.consecutive_failures = 0 + state.disconnection_notified = False + state.last_error = None + state.last_error_time = None + + else: + # Camera is now disconnected + state.connected = False + state.consecutive_failures += 1 + state.last_error = error_msg + state.last_error_time = current_time + + if state.consecutive_failures == 1: + logger.warning(f"Camera {camera_id} connection lost: {error_msg}") + elif state.consecutive_failures >= self.failure_threshold: + logger.error(f"Camera {camera_id} has {state.consecutive_failures} consecutive failures") + + logger.debug(f"Camera {camera_id} state updated - failures: {state.consecutive_failures}") + + def is_camera_connected(self, camera_id: str) -> bool: + """ + Check if camera is currently connected. + + Args: + camera_id: Unique camera identifier + + Returns: + True if camera is connected, False otherwise + """ + self._ensure_thread_safety() + + with self._lock: + state = self._camera_states.get(camera_id) + return state.connected if state else True # Default to connected for new cameras + + def should_notify_disconnection(self, camera_id: str) -> bool: + """ + Check if we should notify backend about disconnection. + + A disconnection notification should be sent when: + 1. Camera is disconnected + 2. We haven't already notified about this disconnection + 3. 
We have enough consecutive failures + + Args: + camera_id: Unique camera identifier + + Returns: + True if disconnection notification should be sent + """ + self._ensure_thread_safety() + + with self._lock: + state = self._camera_states.get(camera_id) + if not state: + return False + + is_disconnected = not state.connected + not_yet_notified = not state.disconnection_notified + has_enough_failures = state.consecutive_failures >= self.failure_threshold + + should_notify = is_disconnected and not_yet_notified and has_enough_failures + + if should_notify: + logger.info(f"Camera {camera_id} qualifies for disconnection notification - " + f"failures: {state.consecutive_failures}, error: {state.last_error}") + + return should_notify + + def mark_disconnection_notified(self, camera_id: str) -> None: + """ + Mark that we've notified backend about this disconnection. + + Args: + camera_id: Unique camera identifier + """ + self._ensure_thread_safety() + + with self._lock: + if camera_id in self._camera_states: + self._camera_states[camera_id].disconnection_notified = True + logger.debug(f"Marked disconnection notification sent for camera {camera_id}") + + def get_connection_stats(self, camera_id: str) -> Dict[str, Any]: + """ + Get comprehensive connection statistics for a camera. 
+ + Args: + camera_id: Unique camera identifier + + Returns: + Dictionary with connection statistics + """ + self._ensure_thread_safety() + + with self._lock: + state = self._camera_states.get(camera_id) + if not state: + return {"error": "Camera not found"} + + current_time = time.time() + stats = state.to_dict() + + # Add computed stats + if state.connection_start_time: + stats["uptime_seconds"] = current_time - state.connection_start_time + + if state.last_successful_frame: + stats["seconds_since_last_frame"] = current_time - state.last_successful_frame + + if state.last_error_time: + stats["seconds_since_last_error"] = current_time - state.last_error_time + + return stats + + def reset_camera_state(self, camera_id: str) -> None: + """ + Reset camera state to initial connected state. + + Args: + camera_id: Unique camera identifier + """ + self._ensure_thread_safety() + + with self._lock: + if camera_id in self._camera_states: + del self._camera_states[camera_id] + logger.info(f"Reset connection state for camera {camera_id}") + + def cleanup_inactive_cameras(self, inactive_threshold_seconds: int = 3600) -> int: + """ + Remove states for cameras inactive for too long. 
+ + Args: + inactive_threshold_seconds: Seconds of inactivity before cleanup + + Returns: + Number of camera states cleaned up + """ + self._ensure_thread_safety() + + current_time = time.time() + cleanup_count = 0 + + with self._lock: + camera_ids_to_remove = [] + + for camera_id, state in self._camera_states.items(): + last_activity = max( + state.connection_start_time or 0, + state.last_successful_frame or 0, + state.last_error_time or 0 + ) + + if current_time - last_activity > inactive_threshold_seconds: + camera_ids_to_remove.append(camera_id) + + for camera_id in camera_ids_to_remove: + del self._camera_states[camera_id] + cleanup_count += 1 + logger.debug(f"Cleaned up inactive camera state for {camera_id}") + + if cleanup_count > 0: + logger.info(f"Cleaned up {cleanup_count} inactive camera states") + + return cleanup_count + + def get_all_camera_states(self) -> Dict[str, Dict[str, Any]]: + """ + Get connection states for all monitored cameras. + + Returns: + Dictionary mapping camera IDs to their connection states + """ + self._ensure_thread_safety() + + with self._lock: + return { + camera_id: state.to_dict() + for camera_id, state in self._camera_states.items() + } + + def get_disconnected_cameras(self) -> list[str]: + """ + Get list of currently disconnected camera IDs. 
+ + Returns: + List of camera IDs that are currently disconnected + """ + self._ensure_thread_safety() + + with self._lock: + return [ + camera_id for camera_id, state in self._camera_states.items() + if not state.connected + ] + + +# Global camera monitor instance +camera_monitor = CameraMonitor() + + +# ===== CONVENIENCE FUNCTIONS ===== +# These provide the same interface as the original functions in app.py + +def set_camera_connected(camera_id: str, connected: bool = True, error_msg: Optional[str] = None) -> None: + """Set camera connection state and track error information.""" + camera_monitor.set_camera_connected(camera_id, connected, error_msg) + + +def is_camera_connected(camera_id: str) -> bool: + """Check if camera is currently connected.""" + return camera_monitor.is_camera_connected(camera_id) + + +def should_notify_disconnection(camera_id: str) -> bool: + """Check if we should notify backend about disconnection.""" + return camera_monitor.should_notify_disconnection(camera_id) + + +def mark_disconnection_notified(camera_id: str) -> None: + """Mark that we've notified backend about this disconnection.""" + camera_monitor.mark_disconnection_notified(camera_id) + + +def get_connection_stats(camera_id: str) -> Dict[str, Any]: + """Get comprehensive connection statistics for a camera.""" + return camera_monitor.get_connection_stats(camera_id) + + +def reset_camera_state(camera_id: str) -> None: + """Reset camera state to initial connected state.""" + camera_monitor.reset_camera_state(camera_id) \ No newline at end of file diff --git a/detector_worker/streams/frame_reader.py b/detector_worker/streams/frame_reader.py new file mode 100644 index 0000000..f8b26af --- /dev/null +++ b/detector_worker/streams/frame_reader.py @@ -0,0 +1,476 @@ +""" +Frame reading implementations for RTSP and HTTP snapshot streams. + +This module provides thread-safe frame readers for different camera stream types. 
+""" + +import cv2 +import time +import queue +import logging +import requests +import threading +from typing import Optional, Any +import numpy as np + +from ..core.constants import ( + DEFAULT_RECONNECT_INTERVAL_SEC, + DEFAULT_MAX_RETRIES, + HTTP_SNAPSHOT_TIMEOUT, + SHARED_STREAM_BUFFER_SIZE +) +from ..core.exceptions import ( + CameraConnectionError, + FrameReadError, + create_stream_error +) + +logger = logging.getLogger(__name__) + + +def fetch_snapshot(url: str, timeout: int = HTTP_SNAPSHOT_TIMEOUT) -> Optional[np.ndarray]: + """ + Fetch a single snapshot from HTTP/HTTPS URL. + + Args: + url: HTTP/HTTPS URL to fetch snapshot from + timeout: Request timeout in seconds + + Returns: + Decoded image frame or None if fetch failed + """ + try: + logger.debug(f"Fetching snapshot from {url}") + response = requests.get(url, timeout=timeout, stream=True) + response.raise_for_status() + + # Check if response has content + if not response.content: + logger.warning(f"Empty response from snapshot URL: {url}") + return None + + # Decode image from response bytes + img_array = np.frombuffer(response.content, dtype=np.uint8) + frame = cv2.imdecode(img_array, cv2.IMREAD_COLOR) + + if frame is None: + logger.warning(f"Failed to decode image from snapshot URL: {url}") + return None + + logger.debug(f"Successfully fetched snapshot from {url}, shape: {frame.shape}") + return frame + + except requests.exceptions.Timeout: + logger.warning(f"Timeout fetching snapshot from {url}") + return None + except requests.exceptions.ConnectionError: + logger.warning(f"Connection error fetching snapshot from {url}") + return None + except requests.exceptions.HTTPError as e: + logger.warning(f"HTTP error fetching snapshot from {url}: {e}") + return None + except Exception as e: + logger.error(f"Unexpected error fetching snapshot from {url}: {e}") + return None + + +class RTSPFrameReader: + """Thread-safe RTSP frame reader.""" + + def __init__(self, + camera_id: str, + rtsp_url: str, + buffer: 
def _reconnect(self) -> bool:
    """
    Attempt to reconnect to RTSP stream.

    Releases the current capture handle (if any), waits
    ``reconnect_interval`` seconds and then tries to open the stream again
    through ``_initialize_capture``.

    The wait is done on ``stop_event`` instead of ``time.sleep`` so that a
    shutdown request interrupts the back-off immediately; otherwise the
    reader thread would stay blocked for the full interval and then re-open
    a camera nobody wants anymore.

    Returns:
        True if the capture was re-opened, False if opening failed or a
        stop was requested while waiting.
    """
    if self.cap:
        self.cap.release()
        self.cap = None

    logger.info(f"Attempting to reconnect RTSP stream for camera: {self.camera_id}")

    # Event.wait() returns True as soon as the event is set, which lets the
    # caller's loop observe the stop request without any extra delay.
    if self.stop_event.wait(self.reconnect_interval):
        return False

    return self._initialize_capture()
logger.error(f"Camera {self.camera_id}: Max retries reached, stopping") + self._set_connection_state(False, "Max retries reached") + break + + # Attempt reconnection + if not self._reconnect(): + continue + + # Reset retry counter on successful reconnection + self.retries = 0 + continue + + # Successfully read frame + logger.debug(f"Camera {self.camera_id}: Read frame, shape: {frame.shape}") + self.retries = 0 + self._set_connection_state(True) + self._put_frame(frame) + + # Short sleep to avoid CPU overuse + time.sleep(0.01) + + except FrameReadError as e: + self.retries += 1 + error_msg = f"Frame read error: {e}" + logger.error(f"Camera {self.camera_id}: {error_msg}") + self._set_connection_state(False, error_msg) + + # Check max retries + if self.retries > self.max_retries and self.max_retries != -1: + logger.error(f"Camera {self.camera_id}: Max retries reached after error") + self._set_connection_state(False, "Max retries reached after error") + break + + # Attempt reconnection + if not self._reconnect(): + continue + + except Exception as e: + error_msg = f"Unexpected error: {str(e)}" + logger.error(f"Camera {self.camera_id}: {error_msg}", exc_info=True) + self._set_connection_state(False, error_msg) + break + + except Exception as e: + logger.error(f"Error in RTSP frame reader for camera {self.camera_id}: {str(e)}", exc_info=True) + finally: + logger.info(f"RTSP frame reader for camera {self.camera_id} is exiting") + if self.cap and self.cap.isOpened(): + self.cap.release() + + +class SnapshotFrameReader: + """Thread-safe HTTP snapshot frame reader.""" + + def __init__(self, + camera_id: str, + snapshot_url: str, + snapshot_interval: int, + buffer: queue.Queue, + stop_event: threading.Event, + max_retries: int = DEFAULT_MAX_RETRIES, + connection_callback=None): + """ + Initialize snapshot frame reader. 
def _calculate_backoff_delay(self) -> float:
    """
    Compute how long to wait before the next snapshot attempt.

    The delay doubles with every consecutive failure (exponent capped at 6),
    is limited to twice the configured snapshot interval, and is finally
    clamped to the range 1..30 seconds.

    Returns:
        Back-off delay in seconds.
    """
    # snapshot_interval is stored in milliseconds.
    ceiling = (self.snapshot_interval / 1000.0) * 2
    exponent = min(self.consecutive_failures - 1, 6)
    capped = min(2 ** exponent, ceiling)
    at_least_one = max(1, capped)
    return min(30, at_least_one)
queue.Empty: + pass + + self.buffer.put(frame) + logger.debug(f"Added snapshot to buffer for camera {self.camera_id}, buffer size: {self.buffer.qsize()}") + + def run(self): + """Main snapshot reading loop.""" + logger.info(f"Starting snapshot reader for camera {self.camera_id} from {self.snapshot_url}") + + interval_seconds = self.snapshot_interval / 1000.0 + logger.info(f"Snapshot interval for camera {self.camera_id}: {interval_seconds}s") + + # Initialize connection state + self._set_connection_state(True) + + try: + while not self.stop_event.is_set(): + try: + start_time = time.time() + frame = fetch_snapshot(self.snapshot_url) + + if frame is None: + # Failed to fetch snapshot + self.consecutive_failures += 1 + self.retries += 1 + error_msg = f"Failed to fetch snapshot, consecutive failures: {self.consecutive_failures}" + logger.warning(f"Camera {self.camera_id}: {error_msg}") + self._set_connection_state(False, error_msg) + + # Test connectivity periodically + self._test_connectivity() + + # Check max retries + if self.retries > self.max_retries and self.max_retries != -1: + logger.error(f"Camera {self.camera_id}: Max retries reached for snapshot, stopping") + self._set_connection_state(False, "Max retries reached for snapshot") + break + + # Exponential backoff + backoff_delay = self._calculate_backoff_delay() + logger.debug(f"Camera {self.camera_id}: Backing off for {backoff_delay:.1f}s") + if self.stop_event.wait(backoff_delay): + break # Exit if stop event set during backoff + continue + + # Successfully fetched snapshot + self.consecutive_failures = 0 # Reset on success + self.retries = 0 + self.frame_count += 1 + current_time = time.time() + + # Log frame stats every 5 seconds + if current_time - self.last_log_time > 5: + elapsed = current_time - self.last_log_time + logger.info(f"Camera {self.camera_id}: Fetched {self.frame_count} snapshots in {elapsed:.1f}s") + self.frame_count = 0 + self.last_log_time = current_time + + logger.debug(f"Camera 
def create_frame_reader_thread(camera_id: str,
                               rtsp_url: Optional[str] = None,
                               snapshot_url: Optional[str] = None,
                               snapshot_interval: Optional[int] = None,
                               buffer: Optional[queue.Queue] = None,
                               stop_event: Optional[threading.Event] = None,
                               connection_callback=None) -> Optional[threading.Thread]:
    """
    Create appropriate frame reader thread based on stream type.

    A snapshot reader is chosen when both ``snapshot_url`` and a non-zero
    ``snapshot_interval`` are given; otherwise an RTSP reader is chosen when
    ``rtsp_url`` is given. The returned thread is a daemon thread and has
    not been started.

    Args:
        camera_id: Unique camera identifier
        rtsp_url: RTSP stream URL (for RTSP streams)
        snapshot_url: HTTP snapshot URL (for snapshot streams)
        snapshot_interval: Snapshot interval in milliseconds
        buffer: Frame buffer queue; a new bounded queue is created when omitted
        stop_event: Thread stop event; a new event is created when omitted
        connection_callback: Connection state callback

    Returns:
        Configured thread ready to start, or None if invalid parameters
    """
    # Compare against None explicitly: a caller-supplied buffer or event must
    # never be replaced just because it evaluates as falsy (e.g. an empty
    # container-like buffer that defines __len__).
    if buffer is None:
        buffer = queue.Queue(maxsize=SHARED_STREAM_BUFFER_SIZE)
    if stop_event is None:
        stop_event = threading.Event()

    if snapshot_url and snapshot_interval:
        # Create snapshot reader
        reader = SnapshotFrameReader(
            camera_id=camera_id,
            snapshot_url=snapshot_url,
            snapshot_interval=snapshot_interval,
            buffer=buffer,
            stop_event=stop_event,
            connection_callback=connection_callback
        )
        thread_name = f"snapshot-{camera_id}"
    elif rtsp_url:
        # Create RTSP reader
        reader = RTSPFrameReader(
            camera_id=camera_id,
            rtsp_url=rtsp_url,
            buffer=buffer,
            stop_event=stop_event,
            connection_callback=connection_callback
        )
        thread_name = f"rtsp-{camera_id}"
    else:
        logger.error(f"No valid URL provided for camera {camera_id}")
        return None

    thread = threading.Thread(target=reader.run, name=thread_name)
    # Daemon so a stuck reader can never keep the process alive on exit.
    thread.daemon = True
    return thread
@dataclass
class StreamInfo:
    """
    Information about a single camera stream.

    Holds the stream configuration together with the runtime objects used by
    the reader thread (frame buffer, stop event, thread handle) and simple
    usage counters.
    """
    camera_id: str
    stream_url: str
    stream_type: str  # "rtsp" or "snapshot"
    snapshot_interval: Optional[int] = None  # only set for snapshot streams
    buffer: Optional[queue.Queue] = None  # queue the reader thread puts frames into
    stop_event: Optional[threading.Event] = None  # set to ask the reader thread to exit
    thread: Optional[threading.Thread] = None  # reader thread, None until started
    subscribers: Set[str] = field(default_factory=set)  # subscription ids using this stream
    created_at: float = field(default_factory=time.time)
    last_frame_time: Optional[float] = None  # time of the last frame handed out
    frame_count: int = 0  # number of frames handed out

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to dictionary format.

        Runtime objects (buffer, stop event, thread) are not included; the
        thread is summarised by the ``is_active`` flag. Subscribers are
        returned sorted so repeated calls produce a stable result instead of
        an arbitrary set iteration order.
        """
        return {
            "camera_id": self.camera_id,
            "stream_url": self.stream_url,
            "stream_type": self.stream_type,
            "snapshot_interval": self.snapshot_interval,
            "subscriber_count": len(self.subscribers),
            "subscribers": sorted(self.subscribers),
            "created_at": self.created_at,
            "last_frame_time": self.last_frame_time,
            "frame_count": self.frame_count,
            "is_active": self.thread is not None and self.thread.is_alive()
        }
class StreamManager:
    """
    Manages camera stream lifecycle and resource allocation.

    This class provides centralized management of camera streams including:
    - Stream lifecycle management (start/stop/restart)
    - Resource allocation and sharing
    - Subscriber management
    - Connection state monitoring

    All public methods serialise access to the internal dictionaries through
    a single re-entrant lock.
    """

    def __init__(self, max_streams: int = DEFAULT_MAX_STREAMS):
        """
        Initialize stream manager.

        Args:
            max_streams: Maximum number of concurrent streams
        """
        self.max_streams = max_streams
        self._streams: Dict[str, StreamInfo] = {}
        self._subscriptions: Dict[str, StreamSubscription] = {}
        # Created eagerly: a lazily created lock needs an unsynchronised
        # check-then-set, so two threads calling in for the first time could
        # each install (and then hold) a different lock object.
        self._lock = threading.RLock()

    def _ensure_thread_safety(self):
        """Make sure the lock exists (kept for callers that reset ``_lock`` to None)."""
        if self._lock is None:
            self._lock = threading.RLock()

    def _connection_state_callback(self, camera_id: str, connected: bool, error_msg: Optional[str] = None):
        """Callback for connection state changes."""
        set_camera_connected(camera_id, connected, error_msg)

    def _create_stream_info(self,
                            camera_id: str,
                            rtsp_url: Optional[str] = None,
                            snapshot_url: Optional[str] = None,
                            snapshot_interval: Optional[int] = None) -> StreamInfo:
        """
        Create StreamInfo object based on stream type.

        Raises:
            ValueError: If neither a snapshot URL with interval nor an RTSP
                URL is provided.
        """
        if snapshot_url and snapshot_interval:
            return StreamInfo(
                camera_id=camera_id,
                stream_url=snapshot_url,
                stream_type="snapshot",
                snapshot_interval=snapshot_interval,
                buffer=queue.Queue(maxsize=SHARED_STREAM_BUFFER_SIZE),
                stop_event=threading.Event()
            )
        elif rtsp_url:
            return StreamInfo(
                camera_id=camera_id,
                stream_url=rtsp_url,
                stream_type="rtsp",
                buffer=queue.Queue(maxsize=SHARED_STREAM_BUFFER_SIZE),
                stop_event=threading.Event()
            )
        else:
            raise ValueError("Must provide either RTSP URL or snapshot URL with interval")

    def _launch_reader(self, stream_info: StreamInfo) -> bool:
        """Create and start the frame reader thread for ``stream_info``; True if started."""
        stream_type = stream_info.stream_type
        stream_url = stream_info.stream_url
        thread = create_frame_reader_thread(
            camera_id=stream_info.camera_id,
            rtsp_url=stream_url if stream_type == "rtsp" else None,
            snapshot_url=stream_url if stream_type == "snapshot" else None,
            snapshot_interval=stream_info.snapshot_interval,
            buffer=stream_info.buffer,
            stop_event=stream_info.stop_event,
            connection_callback=self._connection_state_callback
        )
        if thread is None:
            return False

        stream_info.thread = thread
        thread.start()
        return True

    def _discard_failed_stream(self, camera_id: str) -> None:
        """Drop a half-built stream entry and tell its reader (if any) to stop."""
        failed = self._streams.pop(camera_id, None)
        if failed is not None and failed.stop_event:
            failed.stop_event.set()

    def create_subscription(self,
                            subscription_id: str,
                            camera_id: str,
                            subscriber_id: str,
                            rtsp_url: Optional[str] = None,
                            snapshot_url: Optional[str] = None,
                            snapshot_interval: Optional[int] = None) -> bool:
        """
        Create a stream subscription.

        The stream for ``camera_id`` is created and started on first use and
        shared by later subscriptions to the same camera.

        Args:
            subscription_id: Unique subscription identifier
            camera_id: Camera identifier
            subscriber_id: Subscriber identifier
            rtsp_url: RTSP stream URL (for RTSP streams)
            snapshot_url: HTTP snapshot URL (for snapshot streams)
            snapshot_interval: Snapshot interval in milliseconds

        Returns:
            True if subscription was created successfully
        """
        self._ensure_thread_safety()

        with self._lock:
            # Check if subscription already exists
            if subscription_id in self._subscriptions:
                logger.warning(f"Subscription {subscription_id} already exists")
                return False

            is_new_stream = camera_id not in self._streams

            # Check stream limit
            if is_new_stream and len(self._streams) >= self.max_streams:
                logger.error(f"Maximum streams ({self.max_streams}) reached, cannot create new stream for camera {camera_id}")
                return False

            try:
                if is_new_stream:
                    stream_info = self._create_stream_info(
                        camera_id, rtsp_url, snapshot_url, snapshot_interval
                    )
                    self._streams[camera_id] = stream_info

                    if not self._launch_reader(stream_info):
                        # Clean up failed stream
                        del self._streams[camera_id]
                        return False
                    logger.info(f"Created new {stream_info.stream_type} stream for camera {camera_id}")

                # Add subscriber to stream
                stream_info = self._streams[camera_id]
                stream_info.subscribers.add(subscription_id)

                # Create subscription record
                self._subscriptions[subscription_id] = StreamSubscription(
                    subscription_id=subscription_id,
                    camera_id=camera_id,
                    subscriber_id=subscriber_id
                )

                logger.info(f"Created subscription {subscription_id} for camera {camera_id}, subscribers: {len(stream_info.subscribers)}")
                return True

            except Exception as e:
                logger.error(f"Error creating subscription {subscription_id}: {e}")
                # Roll back partial state so a failed call leaves no
                # subscriber-less stream or dangling subscription behind.
                self._subscriptions.pop(subscription_id, None)
                if is_new_stream:
                    self._discard_failed_stream(camera_id)
                elif camera_id in self._streams:
                    self._streams[camera_id].subscribers.discard(subscription_id)
                return False

    def remove_subscription(self, subscription_id: str) -> bool:
        """
        Remove a stream subscription.

        The underlying stream is stopped when its last subscriber is removed.

        Args:
            subscription_id: Unique subscription identifier

        Returns:
            True if subscription was removed successfully
        """
        self._ensure_thread_safety()

        with self._lock:
            if subscription_id not in self._subscriptions:
                logger.warning(f"Subscription {subscription_id} not found")
                return False

            subscription = self._subscriptions[subscription_id]
            camera_id = subscription.camera_id

            # Remove subscription
            del self._subscriptions[subscription_id]

            # Remove subscriber from stream if stream exists
            if camera_id in self._streams:
                stream_info = self._streams[camera_id]
                stream_info.subscribers.discard(subscription_id)

                logger.info(f"Removed subscription {subscription_id} for camera {camera_id}, remaining subscribers: {len(stream_info.subscribers)}")

                # Stop stream if no more subscribers
                if not stream_info.subscribers:
                    self._stop_stream(camera_id)

            return True

    def _stop_stream(self, camera_id: str) -> None:
        """Stop a stream and clean up resources (caller must hold the lock)."""
        if camera_id not in self._streams:
            return

        stream_info = self._streams[camera_id]

        # Signal thread to stop
        if stream_info.stop_event:
            stream_info.stop_event.set()

        # Wait for thread to finish
        if stream_info.thread and stream_info.thread.is_alive():
            stream_info.thread.join(timeout=5)
            if stream_info.thread.is_alive():
                logger.warning(f"Stream thread for camera {camera_id} did not stop gracefully")

        # Clean up
        del self._streams[camera_id]
        logger.info(f"Stopped {stream_info.stream_type} stream for camera {camera_id}")

    def get_frame(self, subscription_id: str, timeout: float = 0.1) -> Optional[Any]:
        """
        Get the latest frame for a subscription.

        Args:
            subscription_id: Unique subscription identifier
            timeout: Timeout for frame retrieval in seconds

        Returns:
            Latest frame or None if not available
        """
        self._ensure_thread_safety()

        with self._lock:
            subscription = self._subscriptions.get(subscription_id)
            if subscription is None:
                return None

            stream_info = self._streams.get(subscription.camera_id)
            if stream_info is None:
                return None

            subscription.last_access = time.time()
            frame_buffer = stream_info.buffer

        # Wait for the frame WITHOUT holding the manager lock; otherwise every
        # other manager call would stall for up to `timeout` per waiting
        # subscriber. The queue does its own locking.
        try:
            frame = frame_buffer.get(timeout=timeout)
        except queue.Empty:
            return None
        except Exception as e:
            logger.error(f"Error getting frame for subscription {subscription_id}: {e}")
            return None

        with self._lock:
            stream_info.last_frame_time = time.time()
            stream_info.frame_count += 1
        return frame

    def is_stream_active(self, camera_id: str) -> bool:
        """
        Check if a stream is active.

        Args:
            camera_id: Camera identifier

        Returns:
            True if stream is active
        """
        self._ensure_thread_safety()

        with self._lock:
            if camera_id not in self._streams:
                return False

            stream_info = self._streams[camera_id]
            return stream_info.thread is not None and stream_info.thread.is_alive()

    def get_stream_stats(self, camera_id: str) -> Optional[Dict[str, Any]]:
        """
        Get statistics for a stream.

        Args:
            camera_id: Camera identifier

        Returns:
            Stream statistics or None if stream not found
        """
        self._ensure_thread_safety()

        with self._lock:
            if camera_id not in self._streams:
                return None

            stream_info = self._streams[camera_id]
            current_time = time.time()

            stats = stream_info.to_dict()
            stats["uptime_seconds"] = current_time - stream_info.created_at

            if stream_info.last_frame_time:
                stats["seconds_since_last_frame"] = current_time - stream_info.last_frame_time

            return stats

    def get_subscription_info(self, subscription_id: str) -> Optional[Dict[str, Any]]:
        """
        Get information about a subscription.

        Args:
            subscription_id: Unique subscription identifier

        Returns:
            Subscription information or None if not found
        """
        self._ensure_thread_safety()

        with self._lock:
            if subscription_id not in self._subscriptions:
                return None

            return self._subscriptions[subscription_id].to_dict()

    def get_all_streams(self) -> Dict[str, Dict[str, Any]]:
        """
        Get information about all active streams.

        Returns:
            Dictionary mapping camera IDs to stream information
        """
        self._ensure_thread_safety()

        with self._lock:
            return {
                camera_id: stream_info.to_dict()
                for camera_id, stream_info in self._streams.items()
            }

    def get_all_subscriptions(self) -> Dict[str, Dict[str, Any]]:
        """
        Get information about all active subscriptions.

        Returns:
            Dictionary mapping subscription IDs to subscription information
        """
        self._ensure_thread_safety()

        with self._lock:
            return {
                sub_id: subscription.to_dict()
                for sub_id, subscription in self._subscriptions.items()
            }

    def cleanup_inactive_streams(self, inactive_threshold_seconds: int = 3600) -> int:
        """
        Clean up streams that have been inactive for too long.

        Only streams without subscribers are considered; activity is the
        later of creation time and the last delivered frame.

        Args:
            inactive_threshold_seconds: Seconds of inactivity before cleanup

        Returns:
            Number of streams cleaned up
        """
        self._ensure_thread_safety()

        current_time = time.time()
        cleanup_count = 0

        with self._lock:
            # Collect first: _stop_stream mutates the dict being iterated.
            streams_to_remove = []

            for camera_id, stream_info in self._streams.items():
                # Check if stream has subscribers
                if stream_info.subscribers:
                    continue

                # Check if stream has been inactive
                last_activity = max(
                    stream_info.created_at,
                    stream_info.last_frame_time or 0
                )

                if current_time - last_activity > inactive_threshold_seconds:
                    streams_to_remove.append(camera_id)

            for camera_id in streams_to_remove:
                self._stop_stream(camera_id)
                cleanup_count += 1
                logger.info(f"Cleaned up inactive stream for camera {camera_id}")

        if cleanup_count > 0:
            logger.info(f"Cleaned up {cleanup_count} inactive streams")

        return cleanup_count

    def restart_stream(self, camera_id: str) -> bool:
        """
        Restart a stream.

        The existing stream is stopped and a new one with the same URL, type
        and subscribers is started. Subscription records are left untouched.

        Args:
            camera_id: Camera identifier

        Returns:
            True if stream was restarted successfully
        """
        self._ensure_thread_safety()

        with self._lock:
            if camera_id not in self._streams:
                logger.warning(f"Cannot restart stream for camera {camera_id}: stream not found")
                return False

            stream_info = self._streams[camera_id]
            subscribers = stream_info.subscribers.copy()
            stream_url = stream_info.stream_url
            stream_type = stream_info.stream_type
            snapshot_interval = stream_info.snapshot_interval

            # Stop current stream
            self._stop_stream(camera_id)

            # Recreate stream
            try:
                new_stream_info = self._create_stream_info(
                    camera_id,
                    rtsp_url=stream_url if stream_type == "rtsp" else None,
                    snapshot_url=stream_url if stream_type == "snapshot" else None,
                    snapshot_interval=snapshot_interval
                )
                new_stream_info.subscribers = subscribers
                self._streams[camera_id] = new_stream_info

                if self._launch_reader(new_stream_info):
                    logger.info(f"Restarted {stream_type} stream for camera {camera_id}")
                    return True

                # Clean up failed restart
                del self._streams[camera_id]
                return False

            except Exception as e:
                logger.error(f"Error restarting stream for camera {camera_id}: {e}")
                # Do not leave a thread-less entry behind that would count
                # against max_streams and block a later re-creation.
                self._discard_failed_stream(camera_id)
                return False

    def shutdown_all(self) -> None:
        """Shutdown all streams and clean up resources."""
        self._ensure_thread_safety()

        with self._lock:
            logger.info("Shutting down all streams...")

            # Stop all streams (iterate over a copy: _stop_stream deletes entries)
            camera_ids = list(self._streams.keys())
            for camera_id in camera_ids:
                self._stop_stream(camera_id)

            # Clear all subscriptions
            self._subscriptions.clear()

            logger.info("All streams shut down successfully")
def get_stream_statistics() -> Dict[str, Any]:
    """
    Get comprehensive stream statistics.

    The totals are derived from the snapshots returned by the manager's
    locked accessors instead of reading its private dictionaries directly,
    so the counts always agree with the ``streams`` and ``subscriptions``
    entries of the same result.

    Returns:
        Dictionary with per-stream and per-subscription information, their
        totals and the configured stream limit.
    """
    streams = stream_manager.get_all_streams()
    subscriptions = stream_manager.get_all_subscriptions()
    return {
        "streams": streams,
        "subscriptions": subscriptions,
        "total_streams": len(streams),
        "total_subscriptions": len(subscriptions),
        "max_streams": stream_manager.max_streams
    }