Refactor: PHASE 2: Core Module Extraction
parent 96bedae80a
commit 4e9ae6bcc4
7 changed files with 3684 additions and 0 deletions

detector_worker/detection/stability_validator.py (Normal file, 483 additions)
@@ -0,0 +1,483 @@
"""
|
||||||
|
Detection stability validation and lightweight detection functionality.
|
||||||
|
|
||||||
|
This module provides validation functionality for detection pipelines including
|
||||||
|
pipeline execution pre-validation and lightweight detection for validation phases.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Any, Optional, Tuple, Union
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..core.constants import (
|
||||||
|
LIGHTWEIGHT_DETECTION_MIN_CONFIDENCE,
|
||||||
|
LIGHTWEIGHT_DETECTION_MIN_BBOX_AREA_RATIO,
|
||||||
|
DEFAULT_MIN_CONFIDENCE
|
||||||
|
)
|
||||||
|
from ..core.exceptions import ValidationError, create_detection_error
|
||||||
|
from ..detection.detection_result import LightweightDetectionResult, BoundingBox
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineValidationResult:
|
||||||
|
"""Result of pipeline execution validation."""
|
||||||
|
is_valid: bool
|
||||||
|
missing_branches: List[str] = field(default_factory=list)
|
||||||
|
required_branches: List[str] = field(default_factory=list)
|
||||||
|
available_classes: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
return {
|
||||||
|
"is_valid": self.is_valid,
|
||||||
|
"missing_branches": self.missing_branches.copy(),
|
||||||
|
"required_branches": self.required_branches.copy(),
|
||||||
|
"available_classes": self.available_classes.copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LightweightDetectionResult:
|
||||||
|
"""Result from lightweight detection."""
|
||||||
|
car_detected: bool
|
||||||
|
best_detection: Optional[Dict[str, Any]] = None
|
||||||
|
validation_passed: bool = False
|
||||||
|
confidence: Optional[float] = None
|
||||||
|
bbox_area_ratio: Optional[float] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
result = {
|
||||||
|
"car_detected": self.car_detected,
|
||||||
|
"validation_passed": self.validation_passed
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.best_detection:
|
||||||
|
result["best_detection"] = self.best_detection.copy()
|
||||||
|
if self.confidence is not None:
|
||||||
|
result["confidence"] = self.confidence
|
||||||
|
if self.bbox_area_ratio is not None:
|
||||||
|
result["bbox_area_ratio"] = self.bbox_area_ratio
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityValidator:
|
||||||
|
"""
|
||||||
|
Validates detection stability and pipeline execution requirements.
|
||||||
|
|
||||||
|
This class provides functionality for:
|
||||||
|
- Pipeline execution pre-validation
|
||||||
|
- Lightweight detection for validation phases
|
||||||
|
- Detection threshold validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize stability validator."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def validate_pipeline_execution(self,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
regions_dict: Dict[str, Any]) -> PipelineValidationResult:
|
||||||
|
"""
|
||||||
|
Pre-validate that all required branches will execute successfully before
|
||||||
|
committing to Redis actions and database records.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Pipeline node configuration
|
||||||
|
regions_dict: Dictionary mapping class names to detection regions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PipelineValidationResult with validation status and details
|
||||||
|
"""
|
||||||
|
# Get all branches that parallel actions are waiting for
|
||||||
|
required_branches = set()
|
||||||
|
|
||||||
|
for action in node.get("parallelActions", []):
|
||||||
|
if action.get("type") == "postgresql_update_combined":
|
||||||
|
wait_for_branches = action.get("waitForBranches", [])
|
||||||
|
required_branches.update(wait_for_branches)
|
||||||
|
|
||||||
|
if not required_branches:
|
||||||
|
# No parallel actions requiring specific branches
|
||||||
|
logger.debug("No parallel actions with waitForBranches - validation passes")
|
||||||
|
return PipelineValidationResult(
|
||||||
|
is_valid=True,
|
||||||
|
required_branches=[],
|
||||||
|
available_classes=list(regions_dict.keys())
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Pre-validation: checking if required branches {list(required_branches)} will execute")
|
||||||
|
|
||||||
|
# Check each required branch
|
||||||
|
missing_branches = []
|
||||||
|
|
||||||
|
for branch in node.get("branches", []):
|
||||||
|
branch_id = branch["modelId"]
|
||||||
|
|
||||||
|
if branch_id not in required_branches:
|
||||||
|
continue # This branch is not required by parallel actions
|
||||||
|
|
||||||
|
# Check if this branch would be triggered
|
||||||
|
trigger_classes = branch.get("triggerClasses", [])
|
||||||
|
min_conf = branch.get("minConfidence", 0)
|
||||||
|
|
||||||
|
branch_triggered = False
|
||||||
|
for det_class in regions_dict:
|
||||||
|
det_confidence = regions_dict[det_class]["confidence"]
|
||||||
|
|
||||||
|
if (det_class in trigger_classes and det_confidence >= min_conf):
|
||||||
|
branch_triggered = True
|
||||||
|
logger.debug(f"Pre-validation: branch {branch_id} WILL be triggered by {det_class} (conf={det_confidence:.3f} >= {min_conf})")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not branch_triggered:
|
||||||
|
missing_branches.append(branch_id)
|
||||||
|
logger.warning(f"Pre-validation: branch {branch_id} will NOT be triggered - no matching classes or insufficient confidence")
|
||||||
|
logger.debug(f" Required: {trigger_classes} with min_conf={min_conf}")
|
||||||
|
logger.debug(f" Available: {[(cls, regions_dict[cls]['confidence']) for cls in regions_dict]}")
|
||||||
|
|
||||||
|
is_valid = len(missing_branches) == 0
|
||||||
|
|
||||||
|
if missing_branches:
|
||||||
|
logger.error(f"Pipeline pre-validation FAILED: required branches {missing_branches} will not execute")
|
||||||
|
else:
|
||||||
|
logger.info(f"Pipeline pre-validation PASSED: all required branches {list(required_branches)} will execute")
|
||||||
|
|
||||||
|
return PipelineValidationResult(
|
||||||
|
is_valid=is_valid,
|
||||||
|
missing_branches=missing_branches,
|
||||||
|
required_branches=list(required_branches),
|
||||||
|
available_classes=list(regions_dict.keys())
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_detection_filters(self,
|
||||||
|
box: Any,
|
||||||
|
model: Any,
|
||||||
|
min_confidence: float,
|
||||||
|
trigger_classes: Optional[List[str]] = None,
|
||||||
|
trigger_class_indices: Optional[List[int]] = None) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Apply confidence and class filtering to a detection box."""
|
||||||
|
try:
|
||||||
|
# Extract detection info
|
||||||
|
xyxy = box.xyxy[0].cpu().numpy()
|
||||||
|
conf = box.conf[0].cpu().numpy()
|
||||||
|
cls_id = int(box.cls[0].cpu().numpy())
|
||||||
|
class_name = model.names[cls_id]
|
||||||
|
|
||||||
|
# Apply confidence threshold
|
||||||
|
if conf < min_confidence:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Apply trigger class filtering if specified
|
||||||
|
if trigger_class_indices and cls_id not in trigger_class_indices:
|
||||||
|
return None
|
||||||
|
if trigger_classes and class_name not in trigger_classes:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"class": class_name,
|
||||||
|
"confidence": float(conf),
|
||||||
|
"bbox": [int(x) for x in xyxy],
|
||||||
|
"class_id": cls_id,
|
||||||
|
"xyxy": xyxy
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing detection box: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def run_lightweight_detection_with_validation(self,
|
||||||
|
frame: np.ndarray,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
min_confidence: float = LIGHTWEIGHT_DETECTION_MIN_CONFIDENCE,
|
||||||
|
min_bbox_area_ratio: float = LIGHTWEIGHT_DETECTION_MIN_BBOX_AREA_RATIO) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Run lightweight detection with validation rules for session ID triggering.
|
||||||
|
Returns detection info only if it passes validation thresholds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: Input frame for detection
|
||||||
|
node: Pipeline node configuration
|
||||||
|
min_confidence: Minimum confidence threshold for validation
|
||||||
|
min_bbox_area_ratio: Minimum bounding box area ratio for validation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with validation results and detection info
|
||||||
|
"""
|
||||||
|
logger.debug(f"Running lightweight detection with validation: {node['modelId']} (conf>={min_confidence}, bbox_area>={min_bbox_area_ratio})")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Run basic detection only - no branches, no actions
|
||||||
|
model = node["model"]
|
||||||
|
trigger_classes = node.get("triggerClasses", [])
|
||||||
|
trigger_class_indices = node.get("triggerClassIndices")
|
||||||
|
|
||||||
|
# Run YOLO inference
|
||||||
|
res = model(frame, verbose=False)
|
||||||
|
|
||||||
|
best_detection = None
|
||||||
|
frame_height, frame_width = frame.shape[:2]
|
||||||
|
frame_area = frame_height * frame_width
|
||||||
|
|
||||||
|
for r in res:
|
||||||
|
boxes = r.boxes
|
||||||
|
if boxes is None or len(boxes) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for box in boxes:
|
||||||
|
detection = self._apply_detection_filters(
|
||||||
|
box, model, min_confidence, trigger_classes, trigger_class_indices
|
||||||
|
)
|
||||||
|
|
||||||
|
if not detection:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Calculate bbox area ratio
|
||||||
|
x1, y1, x2, y2 = detection["xyxy"]
|
||||||
|
bbox_area = (x2 - x1) * (y2 - y1)
|
||||||
|
bbox_area_ratio = bbox_area / frame_area if frame_area > 0 else 0
|
||||||
|
|
||||||
|
# Apply bbox area threshold
|
||||||
|
if bbox_area_ratio < min_bbox_area_ratio:
|
||||||
|
logger.debug(f"Detection filtered out: bbox_area_ratio={bbox_area_ratio:.3f} < {min_bbox_area_ratio}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Validation passed
|
||||||
|
detection["bbox_area_ratio"] = float(bbox_area_ratio)
|
||||||
|
detection["validation_passed"] = True
|
||||||
|
|
||||||
|
if not best_detection or detection["confidence"] > best_detection["confidence"]:
|
||||||
|
best_detection = detection
|
||||||
|
|
||||||
|
if best_detection:
|
||||||
|
logger.debug(f"Validation PASSED: {best_detection['class']} (conf: {best_detection['confidence']:.3f}, area: {best_detection['bbox_area_ratio']:.3f})")
|
||||||
|
return best_detection
|
||||||
|
else:
|
||||||
|
logger.debug(f"Validation FAILED: No detection meets criteria (conf>={min_confidence}, area>={min_bbox_area_ratio})")
|
||||||
|
return {"validation_passed": False}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in lightweight detection with validation: {str(e)}", exc_info=True)
|
||||||
|
return {"validation_passed": False}
|
||||||
|
|
||||||
|
def run_lightweight_detection(self,
|
||||||
|
frame: np.ndarray,
|
||||||
|
node: Dict[str, Any]) -> LightweightDetectionResult:
|
||||||
|
"""
|
||||||
|
Run lightweight detection for car presence validation only.
|
||||||
|
Returns basic detection info without running branches or external actions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: Input frame for detection
|
||||||
|
node: Pipeline node configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LightweightDetectionResult with car detection status
|
||||||
|
"""
|
||||||
|
logger.debug(f"Running lightweight detection: {node['modelId']}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Run basic detection only - no branches, no actions
|
||||||
|
model = node["model"]
|
||||||
|
min_confidence = node.get("minConfidence", DEFAULT_MIN_CONFIDENCE)
|
||||||
|
trigger_classes = node.get("triggerClasses", [])
|
||||||
|
trigger_class_indices = node.get("triggerClassIndices")
|
||||||
|
|
||||||
|
# Run YOLO inference
|
||||||
|
res = model(frame, verbose=False)
|
||||||
|
|
||||||
|
car_detected = False
|
||||||
|
best_detection = None
|
||||||
|
|
||||||
|
for r in res:
|
||||||
|
boxes = r.boxes
|
||||||
|
if boxes is None or len(boxes) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for box in boxes:
|
||||||
|
detection = self._apply_detection_filters(
|
||||||
|
box, model, min_confidence, trigger_classes, trigger_class_indices
|
||||||
|
)
|
||||||
|
|
||||||
|
if not detection:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Car detected
|
||||||
|
car_detected = True
|
||||||
|
if not best_detection or detection["confidence"] > best_detection["confidence"]:
|
||||||
|
best_detection = detection
|
||||||
|
|
||||||
|
logger.debug(f"Lightweight detection result: car_detected={car_detected}")
|
||||||
|
if best_detection:
|
||||||
|
logger.debug(f"Best detection: {best_detection['class']} (conf: {best_detection['confidence']:.3f})")
|
||||||
|
|
||||||
|
return LightweightDetectionResult(
|
||||||
|
car_detected=car_detected,
|
||||||
|
best_detection=best_detection,
|
||||||
|
confidence=best_detection["confidence"] if best_detection else None
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in lightweight detection: {str(e)}", exc_info=True)
|
||||||
|
return LightweightDetectionResult(car_detected=False, best_detection=None)
|
||||||
|
|
||||||
|
def validate_detection_thresholds(self,
|
||||||
|
detection: Dict[str, Any],
|
||||||
|
frame_shape: Tuple[int, int, int],
|
||||||
|
min_confidence: float,
|
||||||
|
min_bbox_area_ratio: float) -> bool:
|
||||||
|
"""
|
||||||
|
Validate detection against confidence and bbox area thresholds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detection: Detection dictionary with confidence and bbox
|
||||||
|
frame_shape: Frame dimensions (height, width, channels)
|
||||||
|
min_confidence: Minimum confidence threshold
|
||||||
|
min_bbox_area_ratio: Minimum bbox area ratio threshold
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if detection passes all validation thresholds
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check confidence
|
||||||
|
confidence = detection.get("confidence", 0.0)
|
||||||
|
if confidence < min_confidence:
|
||||||
|
logger.debug(f"Detection failed confidence threshold: {confidence:.3f} < {min_confidence}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check bbox area ratio
|
||||||
|
bbox = detection.get("bbox", [])
|
||||||
|
if len(bbox) >= 4:
|
||||||
|
x1, y1, x2, y2 = bbox[:4]
|
||||||
|
bbox_area = (x2 - x1) * (y2 - y1)
|
||||||
|
|
||||||
|
frame_height, frame_width = frame_shape[:2]
|
||||||
|
frame_area = frame_height * frame_width
|
||||||
|
bbox_area_ratio = bbox_area / frame_area if frame_area > 0 else 0
|
||||||
|
|
||||||
|
if bbox_area_ratio < min_bbox_area_ratio:
|
||||||
|
logger.debug(f"Detection failed bbox area threshold: {bbox_area_ratio:.3f} < {min_bbox_area_ratio}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.debug(f"Detection passed validation thresholds: conf={confidence:.3f}, area_ratio={bbox_area_ratio:.3f}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error validating detection thresholds: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def check_branch_execution_requirements(self,
|
||||||
|
branch: Dict[str, Any],
|
||||||
|
regions_dict: Dict[str, Any]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a branch will be executed based on trigger classes and confidence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
branch: Branch configuration
|
||||||
|
regions_dict: Available detection regions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if branch will be executed
|
||||||
|
"""
|
||||||
|
trigger_classes = branch.get("triggerClasses", [])
|
||||||
|
min_conf = branch.get("minConfidence", 0)
|
||||||
|
|
||||||
|
for det_class in regions_dict:
|
||||||
|
det_confidence = regions_dict[det_class]["confidence"]
|
||||||
|
|
||||||
|
if det_class in trigger_classes and det_confidence >= min_conf:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_validation_summary(self,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
regions_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get comprehensive validation summary for a pipeline node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Pipeline node configuration
|
||||||
|
regions_dict: Available detection regions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with validation summary information
|
||||||
|
"""
|
||||||
|
summary = {
|
||||||
|
"model_id": node.get("modelId", "unknown"),
|
||||||
|
"available_classes": list(regions_dict.keys()),
|
||||||
|
"branch_count": len(node.get("branches", [])),
|
||||||
|
"parallel_action_count": len(node.get("parallelActions", [])),
|
||||||
|
"branches_that_will_execute": [],
|
||||||
|
"branches_that_will_not_execute": []
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check each branch
|
||||||
|
for branch in node.get("branches", []):
|
||||||
|
branch_id = branch["modelId"]
|
||||||
|
will_execute = self.check_branch_execution_requirements(branch, regions_dict)
|
||||||
|
|
||||||
|
if will_execute:
|
||||||
|
summary["branches_that_will_execute"].append(branch_id)
|
||||||
|
else:
|
||||||
|
summary["branches_that_will_not_execute"].append(branch_id)
|
||||||
|
|
||||||
|
# Pipeline validation
|
||||||
|
pipeline_validation = self.validate_pipeline_execution(node, regions_dict)
|
||||||
|
summary["pipeline_validation"] = pipeline_validation.to_dict()
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
# Global stability validator instance
|
||||||
|
stability_validator = StabilityValidator()
|
||||||
|
|
||||||
|
|
||||||
|
# ===== CONVENIENCE FUNCTIONS =====
|
||||||
|
# These provide the same interface as the original functions in pympta.py
|
||||||
|
|
||||||
|
def validate_pipeline_execution(node: Dict[str, Any], regions_dict: Dict[str, Any]) -> Tuple[bool, List[str]]:
|
||||||
|
"""
|
||||||
|
Pre-validate that all required branches will execute successfully.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- (True, []) if pipeline can execute completely
|
||||||
|
- (False, missing_branches) if some required branches won't execute
|
||||||
|
"""
|
||||||
|
result = stability_validator.validate_pipeline_execution(node, regions_dict)
|
||||||
|
return result.is_valid, result.missing_branches
|
||||||
|
|
||||||
|
|
||||||
|
def run_lightweight_detection_with_validation(frame: np.ndarray,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
min_confidence: float = LIGHTWEIGHT_DETECTION_MIN_CONFIDENCE,
|
||||||
|
min_bbox_area_ratio: float = LIGHTWEIGHT_DETECTION_MIN_BBOX_AREA_RATIO) -> Dict[str, Any]:
|
||||||
|
"""Run lightweight detection with validation rules for session ID triggering."""
|
||||||
|
return stability_validator.run_lightweight_detection_with_validation(
|
||||||
|
frame, node, min_confidence, min_bbox_area_ratio
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_lightweight_detection(frame: np.ndarray, node: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Run lightweight detection for car presence validation only."""
|
||||||
|
result = stability_validator.run_lightweight_detection(frame, node)
|
||||||
|
return result.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
def validate_detection_thresholds(detection: Dict[str, Any],
|
||||||
|
frame_shape: Tuple[int, int, int],
|
||||||
|
min_confidence: float,
|
||||||
|
min_bbox_area_ratio: float) -> bool:
|
||||||
|
"""Validate detection against confidence and bbox area thresholds."""
|
||||||
|
return stability_validator.validate_detection_thresholds(
|
||||||
|
detection, frame_shape, min_confidence, min_bbox_area_ratio
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_validation_summary(node: Dict[str, Any], regions_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Get comprehensive validation summary for a pipeline node."""
|
||||||
|
return stability_validator.get_validation_summary(node, regions_dict)
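
Usage sketch: the convenience interface above can be exercised roughly as follows. The node layout, branch names, and confidence values are illustrative assumptions, not taken from the repository; only the function signature and the dict keys it reads (modelId, branches, parallelActions, triggerClasses, minConfidence, waitForBranches) come from the code above.

# Hypothetical pipeline node and detection regions, shaped the way
# validate_pipeline_execution() reads them.
node = {
    "modelId": "car_detection",  # illustrative model id
    "branches": [
        {"modelId": "car_brand_cls", "triggerClasses": ["car"], "minConfidence": 0.7},
    ],
    "parallelActions": [
        {"type": "postgresql_update_combined", "waitForBranches": ["car_brand_cls"]},
    ],
}
regions_dict = {"car": {"confidence": 0.82, "bbox": [120, 80, 640, 420]}}

is_valid, missing = validate_pipeline_execution(node, regions_dict)
# is_valid == True: "car" at 0.82 satisfies car_brand_cls's 0.7 threshold.
# A detection below 0.7 would instead yield (False, ["car_brand_cls"]).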

detector_worker/detection/tracking_manager.py (Normal file, 481 additions)
@@ -0,0 +1,481 @@
"""
|
||||||
|
Object tracking management and state handling.
|
||||||
|
|
||||||
|
This module provides tracking state management, session handling,
|
||||||
|
and occupancy detection functionality for the detection pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Any, Optional, Tuple, Set
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from ..core.constants import SESSION_TIMEOUT_SECONDS
|
||||||
|
from ..core.exceptions import TrackingError, create_detection_error
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SessionState:
|
||||||
|
"""Session state for tracking management."""
|
||||||
|
active: bool = True
|
||||||
|
waiting_for_backend_session: bool = False
|
||||||
|
wait_start_time: float = 0.0
|
||||||
|
reset_tracker_on_resume: bool = False
|
||||||
|
occupancy_mode: bool = False
|
||||||
|
occupancy_enabled_at: Optional[float] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
return {
|
||||||
|
"active": self.active,
|
||||||
|
"waiting_for_backend_session": self.waiting_for_backend_session,
|
||||||
|
"wait_start_time": self.wait_start_time,
|
||||||
|
"reset_tracker_on_resume": self.reset_tracker_on_resume,
|
||||||
|
"occupancy_mode": self.occupancy_mode,
|
||||||
|
"occupancy_enabled_at": self.occupancy_enabled_at
|
||||||
|
}
|
||||||
|
|
||||||
|
def from_dict(self, data: Dict[str, Any]) -> None:
|
||||||
|
"""Update from dictionary data."""
|
||||||
|
self.active = data.get("active", True)
|
||||||
|
self.waiting_for_backend_session = data.get("waiting_for_backend_session", False)
|
||||||
|
self.wait_start_time = data.get("wait_start_time", 0.0)
|
||||||
|
self.reset_tracker_on_resume = data.get("reset_tracker_on_resume", False)
|
||||||
|
self.occupancy_mode = data.get("occupancy_mode", False)
|
||||||
|
self.occupancy_enabled_at = data.get("occupancy_enabled_at")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrackingData:
|
||||||
|
"""Complete tracking data for a camera and model."""
|
||||||
|
track_stability_counters: Dict[int, int] = field(default_factory=dict)
|
||||||
|
stable_tracks: Set[int] = field(default_factory=set)
|
||||||
|
session_state: SessionState = field(default_factory=SessionState)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
return {
|
||||||
|
"track_stability_counters": dict(self.track_stability_counters),
|
||||||
|
"stable_tracks": list(self.stable_tracks),
|
||||||
|
"session_state": self.session_state.to_dict()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TrackingManager:
|
||||||
|
"""
|
||||||
|
Manages object tracking state and session handling across cameras and models.
|
||||||
|
|
||||||
|
This class provides centralized tracking state management including:
|
||||||
|
- Track stability counters
|
||||||
|
- Session state management
|
||||||
|
- Backend session timeout handling
|
||||||
|
- Occupancy detection mode
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, session_timeout_seconds: int = SESSION_TIMEOUT_SECONDS):
|
||||||
|
"""
|
||||||
|
Initialize tracking manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_timeout_seconds: Timeout for backend session waiting
|
||||||
|
"""
|
||||||
|
self.session_timeout_seconds = session_timeout_seconds
|
||||||
|
self._camera_tracking_data: Dict[str, Dict[str, TrackingData]] = {}
|
||||||
|
self._lock = None
|
||||||
|
|
||||||
|
def _ensure_thread_safety(self):
|
||||||
|
"""Initialize thread safety if not already done."""
|
||||||
|
if self._lock is None:
|
||||||
|
import threading
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
def get_camera_tracking_data(self, camera_id: str, model_id: str) -> TrackingData:
|
||||||
|
"""
|
||||||
|
Get or create tracking data for a specific camera and model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
model_id: Model identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tracking data for the camera and model
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if camera_id not in self._camera_tracking_data:
|
||||||
|
self._camera_tracking_data[camera_id] = {}
|
||||||
|
|
||||||
|
if model_id not in self._camera_tracking_data[camera_id]:
|
||||||
|
logger.warning(f"🔄 Camera {camera_id}: Creating NEW tracking data for {model_id} - this will reset any cooldown!")
|
||||||
|
self._camera_tracking_data[camera_id][model_id] = TrackingData()
|
||||||
|
|
||||||
|
return self._camera_tracking_data[camera_id][model_id]
|
||||||
|
|
||||||
|
def check_stable_tracks(self,
|
||||||
|
camera_id: str,
|
||||||
|
model_id: str,
|
||||||
|
regions_dict: Dict[str, Any]) -> Tuple[bool, List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Check if any stable tracks match the detected classes for a specific camera.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
model_id: Model identifier
|
||||||
|
regions_dict: Dictionary mapping class names to detection regions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (has_stable_tracks, stable_detections)
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
|
||||||
|
stable_tracks = tracking_data.stable_tracks
|
||||||
|
|
||||||
|
if not stable_tracks:
|
||||||
|
return False, []
|
||||||
|
|
||||||
|
# Check for track-based stability
|
||||||
|
stable_detections = []
|
||||||
|
|
||||||
|
for class_name, region_data in regions_dict.items():
|
||||||
|
detection = region_data.get("detection", {})
|
||||||
|
track_id = detection.get("id")
|
||||||
|
|
||||||
|
if track_id in stable_tracks:
|
||||||
|
stable_detections.append(detection)
|
||||||
|
|
||||||
|
has_stable = len(stable_detections) > 0
|
||||||
|
logger.debug(f"Camera {camera_id}: Stable track check - stable_tracks: {list(stable_tracks)}, found: {has_stable}")
|
||||||
|
|
||||||
|
return has_stable, stable_detections
|
||||||
|
|
||||||
|
def reset_tracking_state(self,
|
||||||
|
camera_id: str,
|
||||||
|
model_id: str,
|
||||||
|
reason: str = "session ended") -> None:
|
||||||
|
"""
|
||||||
|
Reset tracking state after session completion or timeout.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
model_id: Model identifier
|
||||||
|
reason: Reason for reset (for logging)
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
|
||||||
|
session_state = tracking_data.session_state
|
||||||
|
|
||||||
|
# Clear all tracking data for fresh start
|
||||||
|
tracking_data.track_stability_counters.clear()
|
||||||
|
tracking_data.stable_tracks.clear()
|
||||||
|
session_state.active = True
|
||||||
|
session_state.waiting_for_backend_session = False
|
||||||
|
session_state.wait_start_time = 0.0
|
||||||
|
session_state.reset_tracker_on_resume = True
|
||||||
|
|
||||||
|
logger.info(f"Camera {camera_id}: 🔄 Reset tracking state - {reason}")
|
||||||
|
logger.info(f"Camera {camera_id}: 🧹 Cleared stability counters and stable tracks for fresh session")
|
||||||
|
|
||||||
|
def is_camera_active(self, camera_id: str, model_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if camera should be processing detections.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
model_id: Model identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if camera should process detections, False otherwise
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
|
||||||
|
session_state = tracking_data.session_state
|
||||||
|
|
||||||
|
# Check if waiting for backend sessionId has timed out
|
||||||
|
if session_state.waiting_for_backend_session:
|
||||||
|
current_time = time.time()
|
||||||
|
wait_start_time = session_state.wait_start_time
|
||||||
|
elapsed_time = current_time - wait_start_time
|
||||||
|
|
||||||
|
if elapsed_time >= self.session_timeout_seconds:
|
||||||
|
logger.warning(f"Camera {camera_id}: Backend sessionId timeout ({self.session_timeout_seconds}s) - resetting tracking")
|
||||||
|
self.reset_tracking_state(camera_id, model_id, "backend sessionId timeout")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
remaining_time = self.session_timeout_seconds - elapsed_time
|
||||||
|
logger.debug(f"Camera {camera_id}: Still waiting for backend sessionId ({remaining_time:.1f}s remaining)")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return session_state.active
|
||||||
|
|
||||||
|
def set_waiting_for_backend_session(self,
|
||||||
|
camera_id: str,
|
||||||
|
model_id: str,
|
||||||
|
waiting: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
Set whether camera is waiting for backend session ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
model_id: Model identifier
|
||||||
|
waiting: Whether to wait for backend session
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
|
||||||
|
session_state = tracking_data.session_state
|
||||||
|
|
||||||
|
session_state.waiting_for_backend_session = waiting
|
||||||
|
if waiting:
|
||||||
|
session_state.wait_start_time = time.time()
|
||||||
|
logger.debug(f"Camera {camera_id}: Started waiting for backend sessionId")
|
||||||
|
else:
|
||||||
|
session_state.wait_start_time = 0.0
|
||||||
|
logger.debug(f"Camera {camera_id}: Stopped waiting for backend sessionId")
|
||||||
|
|
||||||
|
def cleanup_camera_tracking(self, camera_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Clean up tracking data when a camera is disconnected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if camera_id in self._camera_tracking_data:
|
||||||
|
del self._camera_tracking_data[camera_id]
|
||||||
|
logger.info(f"Cleaned up tracking data for camera {camera_id}")
|
||||||
|
|
||||||
|
def set_occupancy_mode(self,
|
||||||
|
camera_id: str,
|
||||||
|
model_id: str,
|
||||||
|
enable: bool = True) -> bool:
|
||||||
|
"""
|
||||||
|
Enable or disable occupancy detection mode.
|
||||||
|
|
||||||
|
Occupancy mode stops model inference after pipeline completion
|
||||||
|
while backend session handling continues in background.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
model_id: Model identifier
|
||||||
|
enable: True to enable occupancy mode, False to disable
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Current occupancy mode state
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
|
||||||
|
session_state = tracking_data.session_state
|
||||||
|
|
||||||
|
if enable:
|
||||||
|
session_state.occupancy_mode = True
|
||||||
|
session_state.occupancy_enabled_at = time.time()
|
||||||
|
logger.debug(f"Camera {camera_id}: Occupancy mode ENABLED - model will stop after pipeline completion")
|
||||||
|
else:
|
||||||
|
session_state.occupancy_mode = False
|
||||||
|
session_state.occupancy_enabled_at = None
|
||||||
|
logger.debug(f"Camera {camera_id}: Occupancy mode DISABLED - model will continue running")
|
||||||
|
|
||||||
|
return session_state.occupancy_mode
|
||||||
|
|
||||||
|
def is_occupancy_mode_enabled(self, camera_id: str, model_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if occupancy mode is enabled for a camera.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
model_id: Model identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if occupancy mode is enabled
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
|
||||||
|
return tracking_data.session_state.occupancy_mode
|
||||||
|
|
||||||
|
def get_occupancy_duration(self, camera_id: str, model_id: str) -> Optional[float]:
|
||||||
|
"""
|
||||||
|
Get duration since occupancy mode was enabled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
model_id: Model identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Duration in seconds since occupancy mode was enabled, or None if not enabled
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
|
||||||
|
session_state = tracking_data.session_state
|
||||||
|
|
||||||
|
if not session_state.occupancy_mode or not session_state.occupancy_enabled_at:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return time.time() - session_state.occupancy_enabled_at
|
||||||
|
|
||||||
|
def get_session_state(self, camera_id: str, model_id: str) -> SessionState:
|
||||||
|
"""
|
||||||
|
Get session state for a camera and model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
model_id: Model identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Session state object
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
|
||||||
|
return tracking_data.session_state
|
||||||
|
|
||||||
|
def update_session_state(self,
|
||||||
|
camera_id: str,
|
||||||
|
model_id: str,
|
||||||
|
**kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Update session state properties.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
model_id: Model identifier
|
||||||
|
**kwargs: Session state properties to update
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
|
||||||
|
session_state = tracking_data.session_state
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(session_state, key):
|
||||||
|
setattr(session_state, key, value)
|
||||||
|
logger.debug(f"Camera {camera_id}: Updated session state {key} = {value}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Camera {camera_id}: Unknown session state property: {key}")
|
||||||
|
|
||||||
|
def get_tracking_statistics(self, camera_id: str, model_id: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get comprehensive tracking statistics for a camera and model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
model_id: Model identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with tracking statistics
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
stats = tracking_data.to_dict()
|
||||||
|
|
||||||
|
# Add computed statistics
|
||||||
|
stats["total_tracked_objects"] = len(tracking_data.track_stability_counters)
|
||||||
|
stats["stable_track_count"] = len(tracking_data.stable_tracks)
|
||||||
|
|
||||||
|
session_state = tracking_data.session_state
|
||||||
|
if session_state.waiting_for_backend_session and session_state.wait_start_time > 0:
|
||||||
|
stats["backend_wait_duration"] = current_time - session_state.wait_start_time
|
||||||
|
stats["backend_wait_remaining"] = max(0, self.session_timeout_seconds - stats["backend_wait_duration"])
|
||||||
|
|
||||||
|
if session_state.occupancy_mode and session_state.occupancy_enabled_at:
|
||||||
|
stats["occupancy_duration"] = current_time - session_state.occupancy_enabled_at
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def get_all_camera_stats(self) -> Dict[str, Dict[str, Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Get tracking statistics for all monitored cameras.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Nested dictionary: {camera_id: {model_id: stats}}
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
all_stats = {}
|
||||||
|
for camera_id, models in self._camera_tracking_data.items():
|
||||||
|
all_stats[camera_id] = {}
|
||||||
|
for model_id in models.keys():
|
||||||
|
all_stats[camera_id][model_id] = self.get_tracking_statistics(camera_id, model_id)
|
||||||
|
|
||||||
|
return all_stats
|
||||||
|
|
||||||
|
|
||||||
|
# Global tracking manager instance
|
||||||
|
tracking_manager = TrackingManager()
|
||||||
|
|
||||||
|
|
||||||
|
# ===== CONVENIENCE FUNCTIONS =====
|
||||||
|
# These provide the same interface as the original functions in pympta.py
|
||||||
|
|
||||||
|
def check_stable_tracks(camera_id: str, model_id: str, regions_dict: Dict[str, Any]) -> Tuple[bool, List[Dict[str, Any]]]:
|
||||||
|
"""Check if any stable tracks match the detected classes for a specific camera."""
|
||||||
|
return tracking_manager.check_stable_tracks(camera_id, model_id, regions_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_tracking_state(camera_id: str, model_id: str, reason: str = "session ended") -> None:
|
||||||
|
"""Reset tracking state after session completion or timeout."""
|
||||||
|
tracking_manager.reset_tracking_state(camera_id, model_id, reason)
|
||||||
|
|
||||||
|
|
||||||
|
def is_camera_active(camera_id: str, model_id: str) -> bool:
|
||||||
|
"""Check if camera should be processing detections."""
|
||||||
|
return tracking_manager.is_camera_active(camera_id, model_id)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_camera_stability(camera_id: str) -> None:
|
||||||
|
"""Clean up tracking data when a camera is disconnected."""
|
||||||
|
tracking_manager.cleanup_camera_tracking(camera_id)
|
||||||
|
|
||||||
|
|
||||||
|
def occupancy_detector(camera_id: str, model_id: str, enable: bool = True) -> bool:
|
||||||
|
"""
|
||||||
|
Enable or disable occupancy detection mode.
|
||||||
|
|
||||||
|
When enabled:
|
||||||
|
- Model stops inference after completing full pipeline
|
||||||
|
- Backend sessionId handling continues in background
|
||||||
|
|
||||||
|
Note: This is a temporary function that will be changed in the future.
|
||||||
|
"""
|
||||||
|
return tracking_manager.set_occupancy_mode(camera_id, model_id, enable)
|
||||||
|
|
||||||
|
|
||||||
|
def get_camera_tracking_data(camera_id: str, model_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get tracking data for a specific camera and model."""
|
||||||
|
tracking_data = tracking_manager.get_camera_tracking_data(camera_id, model_id)
|
||||||
|
return tracking_data.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
def set_waiting_for_backend_session(camera_id: str, model_id: str, waiting: bool = True) -> None:
|
||||||
|
"""Set whether camera is waiting for backend session ID."""
|
||||||
|
tracking_manager.set_waiting_for_backend_session(camera_id, model_id, waiting)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tracking_statistics(camera_id: str, model_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get comprehensive tracking statistics for a camera and model."""
|
||||||
|
return tracking_manager.get_tracking_statistics(camera_id, model_id)
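
Usage sketch: a minimal view of how the tracking helpers above fit together around one session. The camera id, model id, and regions_dict contents are illustrative assumptions; the call order mirrors the behaviour documented in the methods above (occupancy mode pauses inference, the backend-session wait is bounded by SESSION_TIMEOUT_SECONDS, and reset_tracking_state clears counters and stable tracks).

camera_id, model_id = "camera_001", "car_detection"  # illustrative identifiers
regions_dict = {"car": {"confidence": 0.9, "detection": {"id": 7, "confidence": 0.9}}}

# Only process frames while the camera is active (not stuck waiting on the backend).
if is_camera_active(camera_id, model_id):
    has_stable, stable_detections = check_stable_tracks(camera_id, model_id, regions_dict)
    if has_stable:
        # Pipeline completed for this car: pause inference and wait for the backend sessionId.
        occupancy_detector(camera_id, model_id, enable=True)
        set_waiting_for_backend_session(camera_id, model_id, waiting=True)

# When the backend session ends (or times out), clear counters and stable tracks.
reset_tracking_state(camera_id, model_id, reason="session ended")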

detector_worker/detection/yolo_detector.py (Normal file, 633 additions)
@@ -0,0 +1,633 @@
"""
|
||||||
|
YOLO detection with BoT-SORT tracking integration.
|
||||||
|
|
||||||
|
This module provides the main detection functionality combining YOLO object detection
|
||||||
|
with BoT-SORT tracking for stable track validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Any, Optional, Tuple, Set
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from ..core.constants import DEFAULT_STABILITY_THRESHOLD, DEFAULT_MIN_CONFIDENCE
|
||||||
|
from ..core.exceptions import DetectionError, create_detection_error
|
||||||
|
from ..detection.detection_result import DetectionResult, BoundingBox, DetectionSession
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrackingConfig:
|
||||||
|
"""Configuration for object tracking."""
|
||||||
|
enabled: bool = True
|
||||||
|
method: str = "botsort"
|
||||||
|
reid_config_path: str = "botsort.yaml"
|
||||||
|
stability_threshold: int = DEFAULT_STABILITY_THRESHOLD
|
||||||
|
use_tracking_for_model: bool = True # False for frontal detection models
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StabilityData:
|
||||||
|
"""Track stability data for a camera and model."""
|
||||||
|
track_stability_counters: Dict[int, int] = field(default_factory=dict)
|
||||||
|
stable_tracks: Set[int] = field(default_factory=set)
|
||||||
|
session_state: Dict[str, Any] = field(default_factory=lambda: {
|
||||||
|
"active": True,
|
||||||
|
"waiting_for_backend_session": False,
|
||||||
|
"wait_start_time": 0.0,
|
||||||
|
"reset_tracker_on_resume": False
|
||||||
|
})
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
return {
|
||||||
|
"track_stability_counters": dict(self.track_stability_counters),
|
||||||
|
"stable_tracks": list(self.stable_tracks),
|
||||||
|
"session_state": self.session_state.copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidationResult:
|
||||||
|
"""Result of track stability validation."""
|
||||||
|
validation_complete: bool
|
||||||
|
stable_tracks: List[int] = field(default_factory=list)
|
||||||
|
current_tracks: List[int] = field(default_factory=list)
|
||||||
|
newly_stable_tracks: List[int] = field(default_factory=list)
|
||||||
|
send_none_detection: bool = False
|
||||||
|
branch_node: bool = False
|
||||||
|
bypass_validation: bool = False
|
||||||
|
awaiting_stable_tracks: bool = False
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
return {
|
||||||
|
"validation_complete": self.validation_complete,
|
||||||
|
"stable_tracks": self.stable_tracks.copy(),
|
||||||
|
"current_tracks": self.current_tracks.copy(),
|
||||||
|
"newly_stable_tracks": self.newly_stable_tracks.copy(),
|
||||||
|
"send_none_detection": self.send_none_detection,
|
||||||
|
"branch_node": self.branch_node,
|
||||||
|
"bypass_validation": self.bypass_validation,
|
||||||
|
"awaiting_stable_tracks": self.awaiting_stable_tracks
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityTracker:
|
||||||
|
"""Manages track stability validation for cameras and models."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize stability tracker."""
|
||||||
|
self._camera_stability_tracking: Dict[str, Dict[str, StabilityData]] = {}
|
||||||
|
self._lock = None
|
||||||
|
|
||||||
|
def _ensure_thread_safety(self):
|
||||||
|
"""Initialize thread safety if not already done."""
|
||||||
|
if self._lock is None:
|
||||||
|
import threading
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
def get_camera_stability_data(self, camera_id: str, model_id: str) -> StabilityData:
|
||||||
|
"""
|
||||||
|
Get or create stability tracking data for a specific camera and model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
model_id: Model identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Stability data for the camera and model
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if camera_id not in self._camera_stability_tracking:
|
||||||
|
self._camera_stability_tracking[camera_id] = {}
|
||||||
|
|
||||||
|
if model_id not in self._camera_stability_tracking[camera_id]:
|
||||||
|
logger.warning(f"🔄 Camera {camera_id}: Creating NEW stability data for {model_id} - this will reset any cooldown!")
|
||||||
|
self._camera_stability_tracking[camera_id][model_id] = StabilityData()
|
||||||
|
|
||||||
|
return self._camera_stability_tracking[camera_id][model_id]
|
||||||
|
|
||||||
|
def reset_camera_stability_tracking(self, camera_id: str, model_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Reset all stability tracking data for a specific camera and model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
model_id: Model identifier
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if camera_id in self._camera_stability_tracking and model_id in self._camera_stability_tracking[camera_id]:
|
||||||
|
stability_data = self._camera_stability_tracking[camera_id][model_id]
|
||||||
|
|
||||||
|
# Clear all tracking data
|
||||||
|
old_counters = dict(stability_data.track_stability_counters)
|
||||||
|
old_stable = list(stability_data.stable_tracks)
|
||||||
|
|
||||||
|
stability_data.track_stability_counters.clear()
|
||||||
|
stability_data.stable_tracks.clear()
|
||||||
|
|
||||||
|
# IMPORTANT: Set flag to reset YOLO tracker on next detection run
|
||||||
|
stability_data.session_state["reset_tracker_on_resume"] = True
|
||||||
|
|
||||||
|
logger.info(f"🧹 Camera {camera_id}: CLEARED stability tracking - old_counters={old_counters}, old_stable={old_stable}")
|
||||||
|
logger.info(f"🔄 Camera {camera_id}: YOLO tracker will be reset on next detection - fresh track IDs will start from 1")
|
||||||
|
|
||||||
|
def update_single_track_stability(self,
|
||||||
|
detection: Optional[Dict[str, Any]],
|
||||||
|
camera_id: str,
|
||||||
|
model_id: str,
|
||||||
|
stability_threshold: int,
|
||||||
|
current_mode: str,
|
||||||
|
is_branch_node: bool = False) -> ValidationResult:
|
||||||
|
"""
|
||||||
|
Update track stability validation for a single highest confidence detection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detection: Detection data with track_id, or None if no detection
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
model_id: Model identifier
|
||||||
|
stability_threshold: Number of consecutive frames needed for stability
|
||||||
|
current_mode: Current pipeline mode
|
||||||
|
is_branch_node: Whether this is a branch node (skip validation)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validation result with track status
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
# Branch nodes should not do validation - only main pipeline should
|
||||||
|
if is_branch_node:
|
||||||
|
logger.debug(f"⏭️ Camera {camera_id}: Skipping validation for branch node {model_id} - validation only done at main pipeline level")
|
||||||
|
return ValidationResult(
|
||||||
|
validation_complete=False,
|
||||||
|
branch_node=True,
|
||||||
|
stable_tracks=[],
|
||||||
|
current_tracks=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check current mode - validation counters should increment in both validation_detecting and full_pipeline modes
|
||||||
|
is_validation_mode = current_mode in ["validation_detecting", "full_pipeline"]
|
||||||
|
|
||||||
|
# Get camera-specific stability data
|
||||||
|
stability_data = self.get_camera_stability_data(camera_id, model_id)
|
||||||
|
track_counters = stability_data.track_stability_counters
|
||||||
|
stable_tracks = stability_data.stable_tracks
|
||||||
|
|
||||||
|
current_track_id = detection.get("id") if detection else None
|
||||||
|
|
||||||
|
# ═══ MODE-AWARE TRACK VALIDATION ═══
|
||||||
|
logger.debug(f"📋 Camera {camera_id}: === TRACK VALIDATION ANALYSIS ===")
|
||||||
|
logger.debug(f"📋 Camera {camera_id}: Current mode: {current_mode} (validation_mode={is_validation_mode})")
|
||||||
|
logger.debug(f"📋 Camera {camera_id}: Current track_id: {current_track_id} (assigned by YOLO tracking - not sequential)")
|
||||||
|
logger.debug(f"📋 Camera {camera_id}: Existing counters: {dict(track_counters)}")
|
||||||
|
logger.debug(f"📋 Camera {camera_id}: Stable tracks: {list(stable_tracks)}")
|
||||||
|
|
||||||
|
# IMPORTANT: Only modify validation counters during validation_detecting mode
|
||||||
|
if not is_validation_mode:
|
||||||
|
logger.debug(f"🚫 Camera {camera_id}: NOT in validation mode - skipping counter modifications")
|
||||||
|
return ValidationResult(
|
||||||
|
validation_complete=False,
|
||||||
|
stable_tracks=list(stable_tracks),
|
||||||
|
current_tracks=[current_track_id] if current_track_id is not None else []
|
||||||
|
)
|
||||||
|
|
||||||
|
if current_track_id is not None:
|
||||||
|
# Check if this is a different track than we were tracking
|
||||||
|
previous_track_ids = list(track_counters.keys())
|
||||||
|
|
||||||
|
# VALIDATION MODE: Reset counter if different track OR if track was previously stable
|
||||||
|
should_reset = (
|
||||||
|
len(previous_track_ids) == 0 or # No previous tracking
|
||||||
|
current_track_id not in previous_track_ids or # Different track ID
|
||||||
|
current_track_id in stable_tracks # Track was stable - start fresh validation
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"📋 Camera {camera_id}: Previous track_ids: {previous_track_ids}")
|
||||||
|
logger.debug(f"📋 Camera {camera_id}: Track {current_track_id} was stable: {current_track_id in stable_tracks}")
|
||||||
|
logger.debug(f"📋 Camera {camera_id}: Should reset counters: {should_reset}")
|
||||||
|
|
||||||
|
if should_reset:
|
||||||
|
# Clear all previous tracking - fresh validation needed
|
||||||
|
if previous_track_ids:
|
||||||
|
for old_track_id in previous_track_ids:
|
||||||
|
old_count = track_counters.pop(old_track_id, 0)
|
||||||
|
stable_tracks.discard(old_track_id)
|
||||||
|
logger.info(f"🔄 Camera {camera_id}: VALIDATION RESET - track {old_track_id} counter from {old_count} to 0 (reason: {'stable_track_restart' if current_track_id == old_track_id else 'different_track'})")
|
||||||
|
|
||||||
|
# Start fresh validation for this track
|
||||||
|
old_count = track_counters.get(current_track_id, 0) # Store old count for logging
|
||||||
|
track_counters[current_track_id] = 1
|
||||||
|
current_count = 1
|
||||||
|
logger.info(f"🆕 Camera {camera_id}: FRESH VALIDATION - Track {current_track_id} starting at 1/{stability_threshold}")
|
||||||
|
else:
|
||||||
|
# Continue validation for same track
|
||||||
|
old_count = track_counters.get(current_track_id, 0)
|
||||||
|
track_counters[current_track_id] = old_count + 1
|
||||||
|
current_count = track_counters[current_track_id]
|
||||||
|
|
||||||
|
logger.debug(f"🔢 Camera {camera_id}: Track {current_track_id} counter: {old_count} → {current_count}")
|
||||||
|
logger.info(f"🔍 Camera {camera_id}: Track ID {current_track_id} validation {current_count}/{stability_threshold}")
|
||||||
|
|
||||||
|
# Check if track has reached stability threshold
|
||||||
|
logger.debug(f"📊 Camera {camera_id}: Checking stability: {current_count} >= {stability_threshold}? {current_count >= stability_threshold}")
|
||||||
|
logger.debug(f"📊 Camera {camera_id}: Already stable: {current_track_id in stable_tracks}")
|
||||||
|
|
||||||
|
if current_count >= stability_threshold and current_track_id not in stable_tracks:
|
||||||
|
stable_tracks.add(current_track_id)
|
||||||
|
logger.info(f"✅ Camera {camera_id}: Track ID {current_track_id} STABLE after {current_count} consecutive frames")
|
||||||
|
logger.info(f"🎯 Camera {camera_id}: TRACK VALIDATION COMPLETE")
|
||||||
|
logger.debug(f"🎯 Camera {camera_id}: Stable tracks now: {list(stable_tracks)}")
|
||||||
|
return ValidationResult(
|
||||||
|
validation_complete=True,
|
||||||
|
send_none_detection=True,
|
||||||
|
stable_tracks=[current_track_id],
|
||||||
|
newly_stable_tracks=[current_track_id],
|
||||||
|
current_tracks=[current_track_id]
|
||||||
|
)
|
||||||
|
elif current_count >= stability_threshold:
|
||||||
|
logger.debug(f"📊 Camera {camera_id}: Track {current_track_id} already stable - not re-adding")
|
||||||
|
else:
|
||||||
|
# No car detected - ALWAYS clear all tracking and reset counters
|
||||||
|
logger.debug(f"🚫 Camera {camera_id}: NO CAR DETECTED - clearing all tracking")
|
||||||
|
if track_counters or stable_tracks:
|
||||||
|
logger.debug(f"🚫 Camera {camera_id}: Existing state before reset: counters={dict(track_counters)}, stable={list(stable_tracks)}")
|
||||||
|
for track_id in list(track_counters.keys()):
|
||||||
|
old_count = track_counters.pop(track_id, 0)
|
||||||
|
logger.info(f"🔄 Camera {camera_id}: No car detected - RESET track {track_id} counter from {old_count} to 0")
|
||||||
|
track_counters.clear() # Ensure complete reset
|
||||||
|
stable_tracks.clear() # Clear all stable tracks
|
||||||
|
logger.info(f"✅ Camera {camera_id}: RESET TO VALIDATION PHASE - All counters and stable tracks cleared")
|
||||||
|
else:
|
||||||
|
logger.debug(f"🚫 Camera {camera_id}: No existing counters to clear")
|
||||||
|
logger.debug(f"Camera {camera_id}: VALIDATION - no car detected (all counters reset)")
|
||||||
|
|
||||||
|
# Final return - validation not complete
|
||||||
|
result = ValidationResult(
|
||||||
|
validation_complete=False,
|
||||||
|
stable_tracks=list(stable_tracks),
|
||||||
|
current_tracks=[current_track_id] if current_track_id is not None else []
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"📋 Camera {camera_id}: Track stability result: {result.to_dict()}")
|
||||||
|
logger.debug(f"📋 Camera {camera_id}: Final counters: {dict(track_counters)}")
|
||||||
|
logger.debug(f"📋 Camera {camera_id}: Final stable tracks: {list(stable_tracks)}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class YOLODetector:
|
||||||
|
"""
|
||||||
|
YOLO object detector with BoT-SORT tracking integration.
|
||||||
|
|
||||||
|
Provides structured detection functionality with track stability validation
|
||||||
|
and multi-class detection support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, stability_tracker: Optional[StabilityTracker] = None):
|
||||||
|
"""
|
||||||
|
Initialize YOLO detector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stability_tracker: Optional stability tracker instance
|
||||||
|
"""
|
||||||
|
self.stability_tracker = stability_tracker or StabilityTracker()
|
||||||
|
|
||||||
|
def _extract_tracking_config(self, node: Dict[str, Any]) -> TrackingConfig:
|
||||||
|
"""Extract tracking configuration from pipeline node."""
|
||||||
|
tracking_config = node.get("tracking", {})
|
||||||
|
model_id = node.get("modelId", "")
|
||||||
|
|
||||||
|
# Don't use tracking for frontal detection models
|
||||||
|
use_tracking = not ("frontal" in model_id.lower() or "detection" in model_id.lower())
|
||||||
|
|
||||||
|
return TrackingConfig(
|
||||||
|
enabled=tracking_config.get("enabled", True),
|
||||||
|
method=tracking_config.get("method", "botsort"),
|
||||||
|
reid_config_path=tracking_config.get("reidConfig", tracking_config.get("reidConfigPath", "botsort.yaml")),
|
||||||
|
stability_threshold=tracking_config.get("stabilityThreshold", node.get("stabilityThreshold", DEFAULT_STABILITY_THRESHOLD)),
|
||||||
|
use_tracking_for_model=use_tracking
|
||||||
|
)
|
||||||
|
|
||||||
|
def _determine_confidence_threshold(self, node: Dict[str, Any]) -> float:
|
||||||
|
"""Determine confidence threshold based on model type."""
|
||||||
|
model_id = node.get("modelId", "")
|
||||||
|
|
||||||
|
# Use frontalMinConfidence for frontal detection models
|
||||||
|
if "frontal" in model_id.lower() and "frontalMinConfidence" in node:
|
||||||
|
min_confidence = node.get("frontalMinConfidence", DEFAULT_MIN_CONFIDENCE)
|
||||||
|
logger.debug(f"Using frontalMinConfidence={min_confidence} for {model_id}")
|
||||||
|
else:
|
||||||
|
min_confidence = node.get("minConfidence", DEFAULT_MIN_CONFIDENCE)
|
||||||
|
|
||||||
|
return min_confidence
|
||||||
|
|
||||||
|
def _reset_yolo_tracker_if_needed(self, node: Dict[str, Any], camera_id: str, model_id: str) -> None:
|
||||||
|
"""Reset YOLO tracker if flagged for reset."""
|
||||||
|
stability_data = self.stability_tracker.get_camera_stability_data(camera_id, model_id)
|
||||||
|
session_state = stability_data.session_state
|
||||||
|
|
||||||
|
if session_state.get("reset_tracker_on_resume", False):
|
||||||
|
# Reset YOLO tracker to get fresh track IDs
|
||||||
|
if hasattr(node["model"], 'trackers') and node["model"].trackers:
|
||||||
|
node["model"].trackers.clear() # Clear tracker state
|
||||||
|
logger.info(f"Camera {camera_id}: 🔄 Reset YOLO tracker - new cars will get fresh track IDs")
|
||||||
|
session_state["reset_tracker_on_resume"] = False # Clear the flag
|
||||||
|
|
||||||
|
def _run_yolo_inference(self,
|
||||||
|
frame,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
tracking_config: TrackingConfig) -> Any:
|
||||||
|
"""Run YOLO inference with or without tracking."""
|
||||||
|
# Prepare class filtering
|
||||||
|
trigger_class_indices = node.get("triggerClassIndices")
|
||||||
|
class_filter = {"classes": trigger_class_indices} if trigger_class_indices else {}
|
||||||
|
|
||||||
|
model_id = node.get("modelId", "")
|
||||||
|
use_tracking = tracking_config.enabled and tracking_config.use_tracking_for_model
|
||||||
|
|
||||||
|
logger.debug(f"Running detection for {model_id} - tracking: {use_tracking}, stability_threshold: {tracking_config.stability_threshold}, classes: {node.get('triggerClasses', 'all')}")
|
||||||
|
|
||||||
|
if use_tracking:
|
||||||
|
# Use tracking for main detection models (yolo11m, etc.)
|
||||||
|
logger.debug(f"Using tracking for {model_id}")
|
||||||
|
res = node["model"].track(
|
||||||
|
frame,
|
||||||
|
stream=False,
|
||||||
|
persist=True,
|
||||||
|
**class_filter
|
||||||
|
)[0]
|
||||||
|
else:
|
||||||
|
# Use detection only for frontal detection and other detection-only models
|
||||||
|
logger.debug(f"Using prediction only for {model_id}")
|
||||||
|
res = node["model"].predict(
|
||||||
|
frame,
|
||||||
|
stream=False,
|
||||||
|
**class_filter
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _process_detections(self,
|
||||||
|
res: Any,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
camera_id: str,
|
||||||
|
min_confidence: float) -> List[Dict[str, Any]]:
|
||||||
|
"""Process YOLO detection results into candidate detections."""
|
||||||
|
candidate_detections = []
|
||||||
|
|
||||||
|
if res.boxes is None or len(res.boxes) == 0:
|
||||||
|
logger.debug(f"🚫 Camera {camera_id}: YOLO returned no detections")
|
||||||
|
return candidate_detections
|
||||||
|
|
||||||
|
logger.debug(f"🔍 Camera {camera_id}: YOLO detected {len(res.boxes)} raw objects - processing with tracking...")
|
||||||
|
logger.debug(f"🔍 Camera {camera_id}: === DETECTION ANALYSIS ===")
|
||||||
|
|
||||||
|
for i, box in enumerate(res.boxes):
|
||||||
|
# Extract detection data
|
||||||
|
conf = float(box.cpu().conf[0])
|
||||||
|
cls_id = int(box.cpu().cls[0])
|
||||||
|
class_name = node["model"].names[cls_id]
|
||||||
|
|
||||||
|
# Extract bounding box
|
||||||
|
xy = box.cpu().xyxy[0]
|
||||||
|
x1, y1, x2, y2 = map(int, xy)
|
||||||
|
bbox = (x1, y1, x2, y2)
|
||||||
|
|
||||||
|
# Extract tracking ID if available
|
||||||
|
track_id = None
|
||||||
|
if hasattr(box, "id") and box.id is not None:
|
||||||
|
track_id = int(box.id.item())
|
||||||
|
|
||||||
|
logger.debug(f"🔍 Camera {camera_id}: Detection {i+1}: class='{class_name}' conf={conf:.3f} track_id={track_id} bbox={bbox}")
|
||||||
|
|
||||||
|
# Apply confidence filtering
|
||||||
|
if conf < min_confidence:
|
||||||
|
logger.debug(f"❌ Camera {camera_id}: Detection {i+1} REJECTED - confidence {conf:.3f} < {min_confidence}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create detection object
|
||||||
|
detection = {
|
||||||
|
"class": class_name,
|
||||||
|
"confidence": conf,
|
||||||
|
"id": track_id,
|
||||||
|
"bbox": bbox,
|
||||||
|
"class_id": cls_id
|
||||||
|
}
|
||||||
|
|
||||||
|
candidate_detections.append(detection)
|
||||||
|
logger.debug(f"✅ Camera {camera_id}: Detection {i+1} ACCEPTED as candidate: {class_name} (conf={conf:.3f}, track_id={track_id})")
|
||||||
|
|
||||||
|
return candidate_detections
|
||||||
|
|
||||||
|
def _select_best_detection(self,
|
||||||
|
candidate_detections: List[Dict[str, Any]],
|
||||||
|
camera_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Select the highest confidence detection from candidates."""
|
||||||
|
if not candidate_detections:
|
||||||
|
logger.debug(f"🚫 Camera {camera_id}: No valid candidates after filtering - no car will be tracked")
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug(f"🏆 Camera {camera_id}: === SELECTING HIGHEST CONFIDENCE CAR ===")
|
||||||
|
for i, detection in enumerate(candidate_detections):
|
||||||
|
logger.debug(f"🏆 Camera {camera_id}: Candidate {i+1}: {detection['class']} conf={detection['confidence']:.3f} track_id={detection['id']}")
|
||||||
|
|
||||||
|
# Find the single highest confidence detection across all detected classes
|
||||||
|
best_detection = max(candidate_detections, key=lambda x: x["confidence"])
|
||||||
|
|
||||||
|
logger.info(f"🎯 Camera {camera_id}: SELECTED WINNER: {best_detection['class']} (conf={best_detection['confidence']:.3f}, track_id={best_detection['id']}, bbox={best_detection['bbox']})")
|
||||||
|
|
||||||
|
# Show which cars were NOT selected
|
||||||
|
for detection in candidate_detections:
|
||||||
|
if detection != best_detection:
|
||||||
|
logger.debug(f"🚫 Camera {camera_id}: NOT SELECTED: {detection['class']} (conf={detection['confidence']:.3f}, track_id={detection['id']}) - lower confidence")
|
||||||
|
|
||||||
|
return best_detection
|
||||||
|
|
||||||
|
def _apply_class_mapping(self, detection: Dict[str, Any], node: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Apply class mapping if configured."""
|
||||||
|
original_class = detection["class"]
|
||||||
|
class_mapping = node.get("classMapping", {})
|
||||||
|
|
||||||
|
if original_class in class_mapping:
|
||||||
|
mapped_class = class_mapping[original_class]
|
||||||
|
logger.info(f"Class mapping applied: {original_class} → {mapped_class}")
|
||||||
|
# Update the detection object with mapped class
|
||||||
|
detection["class"] = mapped_class
|
||||||
|
detection["original_class"] = original_class # Keep original for reference
|
||||||
|
|
||||||
|
return detection
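    # Example (hypothetical mapping): with node["classMapping"] = {"car": "sedan"},
    # a detection {"class": "car", ...} comes back as
    # {"class": "sedan", "original_class": "car", ...}; classes without a mapping
    # entry pass through unchanged.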
|
||||||
|
|
||||||
|
def _validate_multi_class(self,
|
||||||
|
regions_dict: Dict[str, Any],
|
||||||
|
node: Dict[str, Any]) -> bool:
|
||||||
|
"""Validate multi-class detection requirements."""
|
||||||
|
if not (node.get("multiClass", False) and node.get("expectedClasses")):
|
||||||
|
return True
|
||||||
|
|
||||||
|
expected_classes = node["expectedClasses"]
|
||||||
|
detected_classes = list(regions_dict.keys())
|
||||||
|
|
||||||
|
logger.debug(f"Multi-class validation: expected={expected_classes}, detected={detected_classes}")
|
||||||
|
|
||||||
|
# Check for required classes (flexible - at least one must match)
|
||||||
|
matching_classes = [cls for cls in expected_classes if cls in detected_classes]
|
||||||
|
if not matching_classes:
|
||||||
|
logger.warning(f"Multi-class validation failed: no expected classes detected")
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info(f"Multi-class validation passed: {matching_classes} detected")
|
||||||
|
return True
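    # Example (hypothetical config): a node carrying
    #   {"multiClass": True, "expectedClasses": ["car", "license_plate"]}
    # passes as long as regions_dict contains at least one of those keys
    # (e.g. {"car": {...}}); it only fails when none of the expected classes
    # were detected.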
|
||||||
|
|
||||||
|
def run_detection_with_tracking(self,
|
||||||
|
frame,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
context: Optional[Dict[str, Any]] = None) -> Tuple[List[Dict[str, Any]], Dict[str, Any], Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Run YOLO detection with BoT-SORT tracking and stability validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: Input frame/image
|
||||||
|
node: Pipeline node configuration with model and settings
|
||||||
|
context: Optional context information (camera info, session data, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (all_detections, regions_dict, track_validation_result) where:
|
||||||
|
- all_detections: List of all detection objects
|
||||||
|
- regions_dict: Dict mapping class names to highest confidence detections
|
||||||
|
- track_validation_result: Dict with validation status and stable tracks
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
camera_id = context.get("camera_id", "unknown") if context else "unknown"
|
||||||
|
model_id = node.get("modelId", "unknown")
|
||||||
|
current_mode = context.get("current_mode", "unknown") if context else "unknown"
|
||||||
|
|
||||||
|
# Extract configuration
|
||||||
|
tracking_config = self._extract_tracking_config(node)
|
||||||
|
min_confidence = self._determine_confidence_threshold(node)
|
||||||
|
|
||||||
|
# Check if we need to reset tracker after cooldown
|
||||||
|
self._reset_yolo_tracker_if_needed(node, camera_id, model_id)
|
||||||
|
|
||||||
|
# Run YOLO inference
|
||||||
|
res = self._run_yolo_inference(frame, node, tracking_config)
|
||||||
|
|
||||||
|
# Process detection results
|
||||||
|
candidate_detections = self._process_detections(res, node, camera_id, min_confidence)
|
||||||
|
|
||||||
|
# Select best detection
|
||||||
|
best_detection = self._select_best_detection(candidate_detections, camera_id)
|
||||||
|
|
||||||
|
# Update track stability validation
|
||||||
|
is_branch_node = node.get("cropClass") is not None or node.get("parallel") is True
|
||||||
|
track_validation_result = self.stability_tracker.update_single_track_stability(
|
||||||
|
detection=best_detection,
|
||||||
|
camera_id=camera_id,
|
||||||
|
model_id=model_id,
|
||||||
|
stability_threshold=tracking_config.stability_threshold,
|
||||||
|
current_mode=current_mode,
|
||||||
|
is_branch_node=is_branch_node
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle no detection case
|
||||||
|
if best_detection is None:
|
||||||
|
# Store validation state in context for pipeline decisions
|
||||||
|
if context is not None:
|
||||||
|
context["track_validation_result"] = track_validation_result.to_dict()
|
||||||
|
|
||||||
|
return [], {}, track_validation_result.to_dict()
|
||||||
|
|
||||||
|
# Apply class mapping
|
||||||
|
best_detection = self._apply_class_mapping(best_detection, node)
|
||||||
|
|
||||||
|
# Create regions dictionary
|
||||||
|
mapped_class = best_detection["class"]
|
||||||
|
track_id = best_detection["id"]
|
||||||
|
|
||||||
|
all_detections = [best_detection]
|
||||||
|
regions_dict = {
|
||||||
|
mapped_class: {
|
||||||
|
"bbox": best_detection["bbox"],
|
||||||
|
"confidence": best_detection["confidence"],
|
||||||
|
"detection": best_detection,
|
||||||
|
"track_id": track_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Multi-class validation
|
||||||
|
if not self._validate_multi_class(regions_dict, node):
|
||||||
|
return [], {}, track_validation_result.to_dict()
|
||||||
|
|
||||||
|
logger.info(f"✅ Camera {camera_id}: DETECTION COMPLETE - tracking single car: track_id={track_id}, conf={best_detection['confidence']:.3f}")
|
||||||
|
logger.debug(f"📊 Camera {camera_id}: Detection summary: {len(res.boxes)} raw → {len(candidate_detections)} candidates → 1 selected")
|
||||||
|
|
||||||
|
# Store validation state in context for pipeline decisions
|
||||||
|
if context is not None:
|
||||||
|
context["track_validation_result"] = track_validation_result.to_dict()
|
||||||
|
|
||||||
|
return all_detections, regions_dict, track_validation_result.to_dict()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
camera_id = context.get("camera_id", "unknown") if context else "unknown"
|
||||||
|
model_id = node.get("modelId", "unknown")
|
||||||
|
raise create_detection_error(camera_id, model_id, "detection_with_tracking", e)
|
||||||
|
|
||||||
|
|
||||||
|
# Global instances for backward compatibility
|
||||||
|
_global_stability_tracker = StabilityTracker()
|
||||||
|
_global_yolo_detector = YOLODetector(_global_stability_tracker)
|
||||||
|
|
||||||
|
|
||||||
|
# ===== CONVENIENCE FUNCTIONS =====
|
||||||
|
# These provide the same interface as the original functions in pympta.py
|
||||||
|
|
||||||
|
def run_detection_with_tracking(frame, node: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Tuple[List[Dict[str, Any]], Dict[str, Any], Dict[str, Any]]:
|
||||||
|
"""Run YOLO detection with BoT-SORT tracking using global detector instance."""
|
||||||
|
return _global_yolo_detector.run_detection_with_tracking(frame, node, context)
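# Usage sketch (assumes node["model"] holds a loaded ultralytics YOLO model and
# frame is a BGR numpy array; config keys shown are illustrative):
#
#   node = {
#       "modelId": "yolo11m",
#       "model": yolo_model,
#       "minConfidence": 0.5,
#       "tracking": {"enabled": True, "stabilityThreshold": 4},
#   }
#   detections, regions, validation = run_detection_with_tracking(
#       frame, node, context={"camera_id": "cam-01"}
#   )
#   # detections -> [] or a list with the single best detection dict
#   # regions    -> {class_name: {"bbox": ..., "confidence": ..., "track_id": ...}}
#   # validation -> ValidationResult.to_dict() with stable/current track lists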
|
||||||
|
|
||||||
|
|
||||||
|
def get_camera_stability_data(camera_id: str, model_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get stability tracking data for a specific camera and model."""
|
||||||
|
stability_data = _global_stability_tracker.get_camera_stability_data(camera_id, model_id)
|
||||||
|
return stability_data.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
def reset_camera_stability_tracking(camera_id: str, model_id: str) -> None:
|
||||||
|
"""Reset all stability tracking data for a specific camera and model."""
|
||||||
|
_global_stability_tracker.reset_camera_stability_tracking(camera_id, model_id)
|
||||||
|
|
||||||
|
|
||||||
|
def update_single_track_stability(node: Dict[str, Any],
|
||||||
|
detection: Optional[Dict[str, Any]],
|
||||||
|
camera_id: str,
|
||||||
|
frame_shape=None,
|
||||||
|
stability_threshold: int = DEFAULT_STABILITY_THRESHOLD,
|
||||||
|
context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||||
|
"""Update track stability validation for a single highest confidence detection."""
|
||||||
|
model_id = node.get("modelId", "unknown")
|
||||||
|
current_mode = context.get("current_mode", "unknown") if context else "unknown"
|
||||||
|
is_branch_node = node.get("cropClass") is not None or node.get("parallel") is True
|
||||||
|
|
||||||
|
result = _global_stability_tracker.update_single_track_stability(
|
||||||
|
detection=detection,
|
||||||
|
camera_id=camera_id,
|
||||||
|
model_id=model_id,
|
||||||
|
stability_threshold=stability_threshold,
|
||||||
|
current_mode=current_mode,
|
||||||
|
is_branch_node=is_branch_node
|
||||||
|
)
|
||||||
|
|
||||||
|
return result.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
def reset_tracking_state(camera_id: str, model_id: str, reason: str = "reset_requested") -> None:
|
||||||
|
"""Reset tracking state for a camera and model."""
|
||||||
|
logger.info(f"🔄 Camera {camera_id}: Resetting tracking state for {model_id} - reason: {reason}")
|
||||||
|
reset_camera_stability_tracking(camera_id, model_id)
694
detector_worker/pipeline/pipeline_executor.py
Normal file
@@ -0,0 +1,694 @@
"""
|
||||||
|
Pipeline execution engine for computer vision detection workflows.
|
||||||
|
|
||||||
|
This module provides the main pipeline execution functionality including:
|
||||||
|
- Multi-class detection coordination
|
||||||
|
- Branch processing (parallel and sequential)
|
||||||
|
- Action execution and database operations
|
||||||
|
- Session state management
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import concurrent.futures
|
||||||
|
from typing import Dict, List, Any, Optional, Tuple, Union
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..core.constants import (
|
||||||
|
DEFAULT_THREAD_POOL_SIZE,
|
||||||
|
CLASSIFICATION_TIMEOUT_SECONDS,
|
||||||
|
PIPELINE_EXECUTION_TIMEOUT
|
||||||
|
)
|
||||||
|
from ..core.exceptions import PipelineError, create_pipeline_error
|
||||||
|
from ..detection.detection_result import DetectionResult, DetectionSession
|
||||||
|
from ..detection.yolo_detector import run_detection_with_tracking
|
||||||
|
from ..detection.stability_validator import validate_pipeline_execution
|
||||||
|
from ..detection.tracking_manager import is_camera_active, occupancy_detector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineContext:
|
||||||
|
"""Context information passed through pipeline execution."""
|
||||||
|
camera_id: str = "unknown"
|
||||||
|
backend_session_id: Optional[str] = None
|
||||||
|
display_id: Optional[str] = None
|
||||||
|
current_mode: str = "unknown"
|
||||||
|
regions_dict: Optional[Dict[str, Any]] = None
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
timestamp: Optional[str] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
return {
|
||||||
|
"camera_id": self.camera_id,
|
||||||
|
"backend_session_id": self.backend_session_id,
|
||||||
|
"display_id": self.display_id,
|
||||||
|
"current_mode": self.current_mode,
|
||||||
|
"regions_dict": self.regions_dict,
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"timestamp": self.timestamp
|
||||||
|
}
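    # Illustrative construction (all field values are placeholders); the executor
    # normally builds this via _extract_context() from the raw context dict:
    #
    #   ctx = PipelineContext(camera_id="cam-01", display_id="display-1",
    #                         backend_session_id="sess-123")
    #   ctx.to_dict()["camera_id"]  # -> "cam-01"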
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BranchResult:
|
||||||
|
"""Result from branch execution."""
|
||||||
|
model_id: str
|
||||||
|
success: bool
|
||||||
|
result: Optional[Dict[str, Any]] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
nested_results: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
data = {
|
||||||
|
"model_id": self.model_id,
|
||||||
|
"success": self.success
|
||||||
|
}
|
||||||
|
if self.result:
|
||||||
|
data["result"] = self.result
|
||||||
|
if self.error:
|
||||||
|
data["error"] = self.error
|
||||||
|
if self.nested_results:
|
||||||
|
data["nested_results"] = self.nested_results
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineResult:
|
||||||
|
"""Result from pipeline execution."""
|
||||||
|
success: bool
|
||||||
|
primary_detection: Optional[Dict[str, Any]] = None
|
||||||
|
primary_bbox: Optional[List[int]] = None
|
||||||
|
branch_results: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
awaiting_session_id: bool = False
|
||||||
|
awaiting_stable_tracks: bool = False
|
||||||
|
|
||||||
|
def to_tuple(self, return_bbox: bool = False) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]], Tuple[None, None]]:
|
||||||
|
"""Convert to return format expected by original function."""
|
||||||
|
if not self.success or not self.primary_detection:
|
||||||
|
return (None, None) if return_bbox else None
|
||||||
|
|
||||||
|
if return_bbox:
|
||||||
|
return (self.primary_detection, self.primary_bbox or [0, 0, 0, 0])
|
||||||
|
else:
|
||||||
|
return self.primary_detection
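    # Behaviour sketch: successful results collapse to the detection dict, or to
    # (detection, bbox) when return_bbox=True; failures become None / (None, None)
    # so callers keep the original pympta-style return contract.
    #
    #   PipelineResult(success=True,
    #                  primary_detection={"class": "car"},
    #                  primary_bbox=[10, 20, 110, 220]).to_tuple(return_bbox=True)
    #   # -> ({"class": "car"}, [10, 20, 110, 220])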
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineExecutor:
|
||||||
|
"""
|
||||||
|
Main pipeline execution engine for computer vision detection workflows.
|
||||||
|
|
||||||
|
This class handles the complete pipeline including detection, tracking,
|
||||||
|
branch processing, action execution, and database operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, thread_pool_size: int = DEFAULT_THREAD_POOL_SIZE):
|
||||||
|
"""
|
||||||
|
Initialize pipeline executor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_pool_size: Maximum number of threads for parallel processing
|
||||||
|
"""
|
||||||
|
self.thread_pool_size = thread_pool_size
|
||||||
|
|
||||||
|
def _extract_context(self, context: Optional[Dict[str, Any]]) -> PipelineContext:
|
||||||
|
"""Extract pipeline context from input dictionary."""
|
||||||
|
if not context:
|
||||||
|
return PipelineContext()
|
||||||
|
|
||||||
|
return PipelineContext(
|
||||||
|
camera_id=context.get("camera_id", "unknown"),
|
||||||
|
backend_session_id=context.get("backend_session_id"),
|
||||||
|
display_id=context.get("display_id"),
|
||||||
|
current_mode=context.get("current_mode", "unknown"),
|
||||||
|
regions_dict=context.get("regions_dict"),
|
||||||
|
session_id=context.get("session_id"),
|
||||||
|
timestamp=context.get("timestamp")
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_classification_task(self,
|
||||||
|
frame: np.ndarray,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
pipeline_context: PipelineContext,
|
||||||
|
return_bbox: bool) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]], Tuple[None, None]]:
|
||||||
|
"""Handle classification-only pipeline nodes."""
|
||||||
|
try:
|
||||||
|
results = node["model"].predict(frame, stream=False)
|
||||||
|
if not results:
|
||||||
|
return (None, None) if return_bbox else None
|
||||||
|
|
||||||
|
r = results[0]
|
||||||
|
probs = r.probs
|
||||||
|
if probs is None:
|
||||||
|
return (None, None) if return_bbox else None
|
||||||
|
|
||||||
|
top1_idx = int(probs.top1)
|
||||||
|
top1_conf = float(probs.top1conf)
|
||||||
|
class_name = node["model"].names[top1_idx]
|
||||||
|
|
||||||
|
det = {
|
||||||
|
"class": class_name,
|
||||||
|
"confidence": top1_conf,
|
||||||
|
"id": None,
|
||||||
|
class_name: class_name # Add class name as key for backward compatibility
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add specific field mappings for database operations based on model type
|
||||||
|
model_id = node.get("modelId", "").lower()
|
||||||
|
if "brand" in model_id or "brand_cls" in model_id:
|
||||||
|
det["brand"] = class_name
|
||||||
|
elif "bodytype" in model_id or "body" in model_id:
|
||||||
|
det["body_type"] = class_name
|
||||||
|
elif "color" in model_id:
|
||||||
|
det["color"] = class_name
|
||||||
|
|
||||||
|
# Execute actions for classification nodes
|
||||||
|
self._execute_node_actions(node, frame, det, pipeline_context.regions_dict)
|
||||||
|
|
||||||
|
return (det, None) if return_bbox else det
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in classification task for {node.get('modelId')}: {e}")
|
||||||
|
return (None, None) if return_bbox else None
|
||||||
|
|
||||||
|
def _check_camera_active(self, camera_id: str, model_id: str, return_bbox: bool) -> Optional[Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]]]]:
|
||||||
|
"""Check if camera is active for processing."""
|
||||||
|
if not is_camera_active(camera_id, model_id):
|
||||||
|
logger.debug(f"⏰ Camera {camera_id}: Waiting for backend sessionId, sending 'none' detection")
|
||||||
|
none_detection = {
|
||||||
|
"class": "none",
|
||||||
|
"confidence": 1.0,
|
||||||
|
"bbox": [0, 0, 0, 0],
|
||||||
|
"branch_results": {}
|
||||||
|
}
|
||||||
|
return (none_detection, [0, 0, 0, 0]) if return_bbox else none_detection
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _run_detection_stage(self,
|
||||||
|
frame: np.ndarray,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
pipeline_context: PipelineContext,
|
||||||
|
validated_detection: Optional[Dict[str, Any]] = None) -> Tuple[List[Dict[str, Any]], Dict[str, Any], Dict[str, Any]]:
|
||||||
|
"""Run the detection stage of the pipeline."""
|
||||||
|
if validated_detection:
|
||||||
|
track_id = validated_detection.get('track_id')
|
||||||
|
logger.info(f"🔄 PIPELINE: Using validated detection from validation phase - track_id={track_id}")
|
||||||
|
|
||||||
|
# Convert validated detection back to all_detections format for branch processing
|
||||||
|
all_detections = [validated_detection]
|
||||||
|
|
||||||
|
# Create regions_dict based on validated detection class with proper structure
|
||||||
|
class_name = validated_detection.get("class", "car")
|
||||||
|
regions_dict = {
|
||||||
|
class_name: {
|
||||||
|
"confidence": validated_detection.get("confidence"),
|
||||||
|
"bbox": validated_detection.get("bbox", [0, 0, 0, 0]),
|
||||||
|
"detection": validated_detection
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Bypass track validation completely - force pipeline execution
|
||||||
|
track_validation_result = {
|
||||||
|
"validation_complete": True,
|
||||||
|
"stable_tracks": ["cached"], # Use dummy stable track to force pipeline execution
|
||||||
|
"current_tracks": ["cached"],
|
||||||
|
"bypass_validation": True
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Normal detection stage - Using structured detection function
|
||||||
|
all_detections, regions_dict, track_validation_result = run_detection_with_tracking(
|
||||||
|
frame, node, pipeline_context.to_dict()
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_detections, regions_dict, track_validation_result
|
||||||
|
|
||||||
|
def _validate_tracking_requirements(self,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
track_validation_result: Dict[str, Any],
|
||||||
|
pipeline_context: PipelineContext,
|
||||||
|
return_bbox: bool) -> Optional[Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]]]]:
|
||||||
|
"""Validate tracking requirements for pipeline execution."""
|
||||||
|
tracking_config = node.get("tracking", {})
|
||||||
|
stability_threshold = tracking_config.get("stabilityThreshold", node.get("stabilityThreshold", 1))
|
||||||
|
|
||||||
|
if stability_threshold <= 1 or not tracking_config.get("enabled", True):
|
||||||
|
return None # No tracking requirements
|
||||||
|
|
||||||
|
# Check if this is a branch node - branches should execute regardless of main validation state
|
||||||
|
is_branch_node = node.get("cropClass") is not None or node.get("parallel") is True
|
||||||
|
|
||||||
|
if is_branch_node:
|
||||||
|
logger.debug(f"🔍 Camera {pipeline_context.camera_id}: Branch node {node.get('modelId')} executing during track validation phase")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Main pipeline node during track validation - check for stable tracks
|
||||||
|
stable_tracks = track_validation_result.get("stable_tracks", [])
|
||||||
|
|
||||||
|
if not stable_tracks:
|
||||||
|
logger.debug(f"🔒 Camera {pipeline_context.camera_id}: Main pipeline requires stable tracks - none found, skipping pipeline execution")
|
||||||
|
none_detection = {
|
||||||
|
"class": "none",
|
||||||
|
"confidence": 1.0,
|
||||||
|
"bbox": [0, 0, 0, 0],
|
||||||
|
"awaiting_stable_tracks": True
|
||||||
|
}
|
||||||
|
return (none_detection, [0, 0, 0, 0]) if return_bbox else none_detection
|
||||||
|
|
||||||
|
logger.info(f"🎯 Camera {pipeline_context.camera_id}: STABLE TRACKS DETECTED - proceeding with full pipeline (tracks: {stable_tracks})")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _handle_database_operations(self,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
detection_result: Dict[str, Any],
|
||||||
|
regions_dict: Dict[str, Any],
|
||||||
|
pipeline_context: PipelineContext) -> None:
|
||||||
|
"""Handle database operations if database manager is available."""
|
||||||
|
if not (node.get("db_manager") and regions_dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
detected_classes = list(regions_dict.keys())
|
||||||
|
logger.debug(f"Valid detections found: {detected_classes}")
|
||||||
|
|
||||||
|
if pipeline_context.backend_session_id:
|
||||||
|
# Backend sessionId is available, proceed with database operations
|
||||||
|
display_id = pipeline_context.display_id or "unknown"
|
||||||
|
timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
||||||
|
|
||||||
|
inserted_session_id = node["db_manager"].insert_initial_detection(
|
||||||
|
display_id=display_id,
|
||||||
|
captured_timestamp=timestamp,
|
||||||
|
session_id=pipeline_context.backend_session_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if inserted_session_id:
|
||||||
|
detection_result["session_id"] = inserted_session_id
|
||||||
|
detection_result["timestamp"] = timestamp
|
||||||
|
logger.info(f"💾 DATABASE RECORD CREATED with backend session_id: {inserted_session_id}")
|
||||||
|
logger.debug(f"Database record: display_id={display_id}, timestamp={timestamp}")
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to create database record with backend session_id: {pipeline_context.backend_session_id}")
|
||||||
|
else:
|
||||||
|
logger.info(f"📡 Camera {pipeline_context.camera_id}: Full pipeline completed, detection data will be sent to backend. Database operations will occur when sessionId is received.")
|
||||||
|
# Store detection info for later database operations when sessionId arrives
|
||||||
|
detection_result["awaiting_session_id"] = True
|
||||||
|
detection_result["timestamp"] = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
||||||
|
|
||||||
|
def _execute_node_actions(self,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
frame: np.ndarray,
|
||||||
|
detection_result: Dict[str, Any],
|
||||||
|
regions_dict: Optional[Dict[str, Any]]) -> None:
|
||||||
|
"""Execute actions for a node."""
|
||||||
|
# This is a placeholder for action execution
|
||||||
|
# In the actual implementation, this would import and call execute_actions
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _execute_parallel_actions(self,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
frame: np.ndarray,
|
||||||
|
detection_result: Dict[str, Any],
|
||||||
|
regions_dict: Dict[str, Any]) -> None:
|
||||||
|
"""Execute parallel actions after branch completion."""
|
||||||
|
# This is a placeholder for parallel action execution
|
||||||
|
# In the actual implementation, this would import and call execute_parallel_actions
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _crop_region_by_class(self,
|
||||||
|
frame: np.ndarray,
|
||||||
|
regions_dict: Dict[str, Any],
|
||||||
|
class_name: str) -> Optional[np.ndarray]:
|
||||||
|
"""Crop a specific region from frame based on detected class."""
|
||||||
|
if class_name not in regions_dict:
|
||||||
|
logger.warning(f"Class '{class_name}' not found in detected regions")
|
||||||
|
return None
|
||||||
|
|
||||||
|
bbox = regions_dict[class_name]["bbox"]
|
||||||
|
x1, y1, x2, y2 = bbox
|
||||||
|
|
||||||
|
# Validate bbox coordinates
|
||||||
|
if x2 <= x1 or y2 <= y1:
|
||||||
|
logger.warning(f"Invalid bbox for class {class_name}: {bbox}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
cropped = frame[y1:y2, x1:x2]
|
||||||
|
if cropped.size == 0:
|
||||||
|
logger.warning(f"Empty crop for class {class_name}")
|
||||||
|
return None
|
||||||
|
return cropped
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error cropping region for class {class_name}: {e}")
|
||||||
|
return None
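    # Example (illustrative coordinates): given
    #   regions_dict = {"car": {"bbox": [100, 50, 400, 300], "confidence": 0.91}}
    # _crop_region_by_class(frame, regions_dict, "car") returns frame[50:300, 100:400];
    # an unknown class or a degenerate bbox (x2 <= x1 or y2 <= y1) yields None.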
|
||||||
|
|
||||||
|
def _prepare_branch_context(self,
|
||||||
|
base_context: PipelineContext,
|
||||||
|
regions_dict: Dict[str, Any],
|
||||||
|
detection_result: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Prepare context for branch execution."""
|
||||||
|
branch_context = base_context.to_dict()
|
||||||
|
branch_context["regions_dict"] = regions_dict
|
||||||
|
|
||||||
|
# Pass session_id from detection_result to branch context for Redis actions
|
||||||
|
if "session_id" in detection_result:
|
||||||
|
branch_context["session_id"] = detection_result["session_id"]
|
||||||
|
logger.debug(f"Added session_id to branch context: {detection_result['session_id']}")
|
||||||
|
elif base_context.backend_session_id:
|
||||||
|
branch_context["session_id"] = base_context.backend_session_id
|
||||||
|
logger.debug(f"Added backend_session_id to branch context: {base_context.backend_session_id}")
|
||||||
|
|
||||||
|
return branch_context
|
||||||
|
|
||||||
|
def _execute_single_branch(self,
|
||||||
|
frame: np.ndarray,
|
||||||
|
branch: Dict[str, Any],
|
||||||
|
branch_context: Dict[str, Any],
|
||||||
|
regions_dict: Dict[str, Any]) -> BranchResult:
|
||||||
|
"""Execute a single branch."""
|
||||||
|
model_id = branch["modelId"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
sub_frame = frame
|
||||||
|
crop_class = branch.get("cropClass")
|
||||||
|
|
||||||
|
logger.info(f"Starting branch: {model_id}, cropClass: {crop_class}")
|
||||||
|
|
||||||
|
# Handle cropping if required
|
||||||
|
if branch.get("crop", False) and crop_class:
|
||||||
|
if crop_class in regions_dict:
|
||||||
|
cropped = self._crop_region_by_class(frame, regions_dict, crop_class)
|
||||||
|
if cropped is not None:
|
||||||
|
sub_frame = cropped # Use cropped image without manual resizing
|
||||||
|
logger.debug(f"Successfully cropped {crop_class} region for {model_id} - model will handle resizing")
|
||||||
|
else:
|
||||||
|
return BranchResult(
|
||||||
|
model_id=model_id,
|
||||||
|
success=False,
|
||||||
|
error=f"Failed to crop {crop_class} region"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return BranchResult(
|
||||||
|
model_id=model_id,
|
||||||
|
success=False,
|
||||||
|
error=f"Crop class {crop_class} not found in detected regions"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute branch pipeline
|
||||||
|
result, _ = self.run_pipeline(sub_frame, branch, True, branch_context)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
branch_result = BranchResult(
|
||||||
|
model_id=model_id,
|
||||||
|
success=True,
|
||||||
|
result=result
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect nested branch results if they exist
|
||||||
|
if "branch_results" in result:
|
||||||
|
branch_result.nested_results = result["branch_results"]
|
||||||
|
|
||||||
|
logger.info(f"Branch {model_id} completed: {result}")
|
||||||
|
return branch_result
|
||||||
|
else:
|
||||||
|
return BranchResult(
|
||||||
|
model_id=model_id,
|
||||||
|
success=False,
|
||||||
|
error="Branch returned no result"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in branch {model_id}: {e}")
|
||||||
|
return BranchResult(
|
||||||
|
model_id=model_id,
|
||||||
|
success=False,
|
||||||
|
error=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_branches_parallel(self,
|
||||||
|
frame: np.ndarray,
|
||||||
|
active_branches: List[Dict[str, Any]],
|
||||||
|
branch_context: Dict[str, Any],
|
||||||
|
regions_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Process branches in parallel."""
|
||||||
|
branch_results = {}
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_branches)) as executor:
|
||||||
|
futures = {}
|
||||||
|
|
||||||
|
for branch in active_branches:
|
||||||
|
future = executor.submit(
|
||||||
|
self._execute_single_branch,
|
||||||
|
frame, branch, branch_context, regions_dict
|
||||||
|
)
|
||||||
|
futures[future] = branch
|
||||||
|
|
||||||
|
# Collect results
|
||||||
|
for future in concurrent.futures.as_completed(futures):
|
||||||
|
branch = futures[future]
|
||||||
|
try:
|
||||||
|
result = future.result()
|
||||||
|
if result.success and result.result:
|
||||||
|
branch_results[result.model_id] = result.result
|
||||||
|
|
||||||
|
# Collect nested branch results
|
||||||
|
for nested_id, nested_result in result.nested_results.items():
|
||||||
|
branch_results[nested_id] = nested_result
|
||||||
|
logger.info(f"Collected nested branch result: {nested_id} = {nested_result}")
|
||||||
|
else:
|
||||||
|
logger.error(f"Branch {result.model_id} failed: {result.error}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Branch {branch['modelId']} failed: {e}")
|
||||||
|
|
||||||
|
return branch_results
|
||||||
|
|
||||||
|
def _process_branches_sequential(self,
|
||||||
|
frame: np.ndarray,
|
||||||
|
active_branches: List[Dict[str, Any]],
|
||||||
|
branch_context: Dict[str, Any],
|
||||||
|
regions_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Process branches sequentially."""
|
||||||
|
branch_results = {}
|
||||||
|
|
||||||
|
for branch in active_branches:
|
||||||
|
result = self._execute_single_branch(frame, branch, branch_context, regions_dict)
|
||||||
|
|
||||||
|
if result.success and result.result:
|
||||||
|
branch_results[result.model_id] = result.result
|
||||||
|
|
||||||
|
# Collect nested branch results
|
||||||
|
for nested_id, nested_result in result.nested_results.items():
|
||||||
|
branch_results[nested_id] = nested_result
|
||||||
|
logger.info(f"Collected nested branch result: {nested_id} = {nested_result}")
|
||||||
|
else:
|
||||||
|
logger.error(f"Branch {result.model_id} failed: {result.error}")
|
||||||
|
|
||||||
|
return branch_results
|
||||||
|
|
||||||
|
def _filter_active_branches(self,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
regions_dict: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
|
"""Filter branches that should be triggered based on detected regions."""
|
||||||
|
active_branches = []
|
||||||
|
|
||||||
|
for br in node["branches"]:
|
||||||
|
trigger_classes = br.get("triggerClasses", [])
|
||||||
|
min_conf = br.get("minConfidence", 0)
|
||||||
|
|
||||||
|
logger.debug(f"Evaluating branch {br['modelId']}: trigger_classes={trigger_classes}, min_conf={min_conf}")
|
||||||
|
|
||||||
|
# Check if any detected class matches branch trigger
|
||||||
|
branch_triggered = False
|
||||||
|
for det_class in regions_dict:
|
||||||
|
det_confidence = regions_dict[det_class]["confidence"]
|
||||||
|
logger.debug(f" Checking detected class '{det_class}' (confidence={det_confidence:.3f}) against triggers {trigger_classes}")
|
||||||
|
|
||||||
|
if (det_class in trigger_classes and det_confidence >= min_conf):
|
||||||
|
active_branches.append(br)
|
||||||
|
branch_triggered = True
|
||||||
|
logger.info(f"Branch {br['modelId']} activated by class '{det_class}' (conf={det_confidence:.3f} >= {min_conf})")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not branch_triggered:
|
||||||
|
logger.debug(f"Branch {br['modelId']} not triggered - no matching classes or insufficient confidence")
|
||||||
|
|
||||||
|
return active_branches
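    # Example (hypothetical branch entry): with
    #   node["branches"] = [{"modelId": "car_brand_cls", "triggerClasses": ["car"],
    #                        "minConfidence": 0.7}]
    # and regions_dict = {"car": {"confidence": 0.85, ...}}, the branch activates
    # because "car" is a trigger class and 0.85 >= 0.7; branches whose trigger
    # classes are missing or below minConfidence are skipped.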
|
||||||
|
|
||||||
|
def _process_branches(self,
|
||||||
|
frame: np.ndarray,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
detection_result: Dict[str, Any],
|
||||||
|
regions_dict: Dict[str, Any],
|
||||||
|
pipeline_context: PipelineContext) -> Dict[str, Any]:
|
||||||
|
"""Process all branches for a node."""
|
||||||
|
if not node.get("branches"):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Filter branches that should be triggered
|
||||||
|
active_branches = self._filter_active_branches(node, regions_dict)
|
||||||
|
|
||||||
|
if not active_branches:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Prepare branch context
|
||||||
|
branch_context = self._prepare_branch_context(pipeline_context, regions_dict, detection_result)
|
||||||
|
|
||||||
|
# Execute branches
|
||||||
|
if node.get("parallel", False) or any(br.get("parallel", False) for br in active_branches):
|
||||||
|
# Run branches in parallel
|
||||||
|
branch_results = self._process_branches_parallel(frame, active_branches, branch_context, regions_dict)
|
||||||
|
else:
|
||||||
|
# Run branches sequentially
|
||||||
|
branch_results = self._process_branches_sequential(frame, active_branches, branch_context, regions_dict)
|
||||||
|
|
||||||
|
return branch_results
|
||||||
|
|
||||||
|
def run_pipeline(self,
|
||||||
|
frame: np.ndarray,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
return_bbox: bool = False,
|
||||||
|
context: Optional[Dict[str, Any]] = None,
|
||||||
|
validated_detection: Optional[Dict[str, Any]] = None) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]], Tuple[None, None]]:
|
||||||
|
"""
|
||||||
|
Run enhanced pipeline that supports:
|
||||||
|
- Multi-class detection (detecting multiple classes simultaneously)
|
||||||
|
- Parallel branch processing
|
||||||
|
- Region-based actions and cropping
|
||||||
|
- Context passing for session/camera information
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: Input frame for processing
|
||||||
|
node: Pipeline node configuration
|
||||||
|
return_bbox: Whether to return bounding box with result
|
||||||
|
context: Optional context information
|
||||||
|
validated_detection: Optional pre-validated detection to use
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Detection result, optionally with bounding box
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Extract context information
|
||||||
|
pipeline_context = self._extract_context(context)
|
||||||
|
model_id = node.get("modelId", "unknown")
|
||||||
|
|
||||||
|
if pipeline_context.backend_session_id:
|
||||||
|
logger.info(f"🔑 PIPELINE USING BACKEND SESSION_ID: {pipeline_context.backend_session_id} for camera {pipeline_context.camera_id}")
|
||||||
|
|
||||||
|
task = getattr(node["model"], "task", None)
|
||||||
|
|
||||||
|
# ─── Classification stage ───────────────────────────────────
|
||||||
|
if task == "classify":
|
||||||
|
return self._handle_classification_task(frame, node, pipeline_context, return_bbox)
|
||||||
|
|
||||||
|
# ─── Session management check ───────────────────────────────────────
|
||||||
|
camera_inactive_result = self._check_camera_active(pipeline_context.camera_id, model_id, return_bbox)
|
||||||
|
if camera_inactive_result is not None:
|
||||||
|
return camera_inactive_result
|
||||||
|
|
||||||
|
# ─── Detection stage ───
|
||||||
|
all_detections, regions_dict, track_validation_result = self._run_detection_stage(
|
||||||
|
frame, node, pipeline_context, validated_detection
|
||||||
|
)
|
||||||
|
|
||||||
|
if not all_detections:
|
||||||
|
logger.debug("No detections from structured detection function - sending 'none' detection")
|
||||||
|
none_detection = {
|
||||||
|
"class": "none",
|
||||||
|
"confidence": 1.0,
|
||||||
|
"bbox": [0, 0, 0, 0],
|
||||||
|
"branch_results": {}
|
||||||
|
}
|
||||||
|
return (none_detection, [0, 0, 0, 0]) if return_bbox else none_detection
|
||||||
|
|
||||||
|
# ─── Track-Based Validation System ────────────────────────
|
||||||
|
tracking_validation_result = self._validate_tracking_requirements(
|
||||||
|
node, track_validation_result, pipeline_context, return_bbox
|
||||||
|
)
|
||||||
|
if tracking_validation_result is not None:
|
||||||
|
return tracking_validation_result
|
||||||
|
|
||||||
|
# ─── Pre-validate pipeline execution ────────────────────────
|
||||||
|
pipeline_valid, missing_branches = validate_pipeline_execution(node, regions_dict)
|
||||||
|
|
||||||
|
if not pipeline_valid:
|
||||||
|
logger.error(f"Pipeline execution validation FAILED - required branches {missing_branches} cannot execute")
|
||||||
|
logger.error("Aborting pipeline: no Redis actions or database records will be created")
|
||||||
|
return (None, None) if return_bbox else None
|
||||||
|
|
||||||
|
# ─── Execute actions with region information ────────────────
|
||||||
|
detection_result = {
|
||||||
|
"detections": all_detections,
|
||||||
|
"regions": regions_dict,
|
||||||
|
**pipeline_context.to_dict()
|
||||||
|
}
|
||||||
|
|
||||||
|
# ─── Database operations ────
|
||||||
|
self._handle_database_operations(node, detection_result, regions_dict, pipeline_context)
|
||||||
|
|
||||||
|
# Execute actions for root node only if it doesn't have branches
|
||||||
|
# Branch nodes with actions will execute them after branch processing
|
||||||
|
if not node.get("branches") or node.get("modelId") == "yolo11n":
|
||||||
|
self._execute_node_actions(node, frame, detection_result, regions_dict)
|
||||||
|
|
||||||
|
# ─── Branch processing ─────────────────────────────
|
||||||
|
branch_results = self._process_branches(frame, node, detection_result, regions_dict, pipeline_context)
|
||||||
|
detection_result["branch_results"] = branch_results
|
||||||
|
|
||||||
|
# ─── Execute Parallel Actions ───────────────────────────────
|
||||||
|
if node.get("parallelActions") and "branch_results" in detection_result:
|
||||||
|
self._execute_parallel_actions(node, frame, detection_result, regions_dict)
|
||||||
|
|
||||||
|
# ─── Auto-enable occupancy mode after successful pipeline completion ─────────────────
|
||||||
|
occupancy_detector(pipeline_context.camera_id, model_id, enable=True)
|
||||||
|
|
||||||
|
logger.info(f"✅ Camera {pipeline_context.camera_id}: Pipeline completed, detection data will be sent to backend")
|
||||||
|
logger.info(f"🛑 Camera {pipeline_context.camera_id}: Model will stop inference for future frames")
|
||||||
|
logger.info(f"📡 Backend sessionId will be handled when received via WebSocket")
|
||||||
|
|
||||||
|
# ─── Execute actions after successful detection AND branch processing ──────────
|
||||||
|
# This ensures detection nodes (like frontal_detection_v1) execute their actions
|
||||||
|
# after completing both detection and branch processing
|
||||||
|
if node.get("actions") and regions_dict and node.get("modelId") != "yolo11n":
|
||||||
|
# Execute actions for branch detection nodes, skip root to avoid duplication
|
||||||
|
logger.debug(f"Executing post-detection actions for branch node {node.get('modelId')}")
|
||||||
|
self._execute_node_actions(node, frame, detection_result, regions_dict)
|
||||||
|
|
||||||
|
# ─── Return detection result ────────────────────────────────
|
||||||
|
primary_detection = max(all_detections, key=lambda x: x["confidence"])
|
||||||
|
primary_bbox = primary_detection["bbox"]
|
||||||
|
|
||||||
|
# Add branch results and session_id to primary detection for compatibility
|
||||||
|
if "branch_results" in detection_result:
|
||||||
|
primary_detection["branch_results"] = detection_result["branch_results"]
|
||||||
|
if "session_id" in detection_result:
|
||||||
|
primary_detection["session_id"] = detection_result["session_id"]
|
||||||
|
|
||||||
|
return (primary_detection, primary_bbox) if return_bbox else primary_detection
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pipeline_id = node.get("modelId", "unknown")
|
||||||
|
raise create_pipeline_error(pipeline_id, "pipeline_execution", e)
|
||||||
|
|
||||||
|
|
||||||
|
# Global pipeline executor instance
|
||||||
|
pipeline_executor = PipelineExecutor()
|
||||||
|
|
||||||
|
|
||||||
|
# ===== CONVENIENCE FUNCTIONS =====
|
||||||
|
# These provide the same interface as the original functions in pympta.py
|
||||||
|
|
||||||
|
def run_pipeline(frame: np.ndarray,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
return_bbox: bool = False,
|
||||||
|
context: Optional[Dict[str, Any]] = None,
|
||||||
|
validated_detection: Optional[Dict[str, Any]] = None) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]], Tuple[None, None]]:
|
||||||
|
"""Run enhanced pipeline using global executor instance."""
|
||||||
|
return pipeline_executor.run_pipeline(frame, node, return_bbox, context, validated_detection)
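# Usage sketch (node comes from the pipeline loader with a loaded model under
# node["model"]; context keys shown are illustrative):
#
#   detection, bbox = run_pipeline(
#       frame,
#       node,
#       return_bbox=True,
#       context={"camera_id": "cam-01", "backend_session_id": None},
#   )
#   if detection and detection.get("class") != "none":
#       print(detection["class"], detection.get("branch_results", {}))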
|
||||||
|
|
||||||
|
|
||||||
|
def crop_region_by_class(frame: np.ndarray, regions_dict: Dict[str, Any], class_name: str) -> Optional[np.ndarray]:
|
||||||
|
"""Crop a specific region from frame based on detected class."""
|
||||||
|
return pipeline_executor._crop_region_by_class(frame, regions_dict, class_name)
345
detector_worker/streams/camera_monitor.py
Normal file
@@ -0,0 +1,345 @@
"""
|
||||||
|
Camera connection state monitoring and management.
|
||||||
|
|
||||||
|
This module provides centralized tracking of camera connection states,
|
||||||
|
error handling, and disconnection notifications.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from ..core.exceptions import CameraConnectionError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CameraConnectionState:
|
||||||
|
"""Represents the connection state of a single camera."""
|
||||||
|
connected: bool = True
|
||||||
|
last_error: Optional[str] = None
|
||||||
|
last_error_time: Optional[float] = None
|
||||||
|
consecutive_failures: int = 0
|
||||||
|
disconnection_notified: bool = False
|
||||||
|
last_successful_frame: Optional[float] = None
|
||||||
|
connection_start_time: Optional[float] = field(default_factory=time.time)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
return {
|
||||||
|
"connected": self.connected,
|
||||||
|
"last_error": self.last_error,
|
||||||
|
"last_error_time": self.last_error_time,
|
||||||
|
"consecutive_failures": self.consecutive_failures,
|
||||||
|
"disconnection_notified": self.disconnection_notified,
|
||||||
|
"last_successful_frame": self.last_successful_frame,
|
||||||
|
"connection_start_time": self.connection_start_time
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CameraMonitor:
|
||||||
|
"""
|
||||||
|
Monitors camera connection states and handles disconnection notifications.
|
||||||
|
|
||||||
|
This class provides a centralized way to track camera connection status,
|
||||||
|
handle error states, and determine when to notify about disconnections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, failure_threshold: int = 3):
|
||||||
|
"""
|
||||||
|
Initialize camera monitor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
failure_threshold: Number of consecutive failures before considering disconnected
|
||||||
|
"""
|
||||||
|
self.failure_threshold = failure_threshold
|
||||||
|
self._camera_states: Dict[str, CameraConnectionState] = {}
|
||||||
|
self._lock = None # Will be initialized when needed for thread safety
|
||||||
|
|
||||||
|
def _ensure_thread_safety(self):
|
||||||
|
"""Initialize thread safety if not already done."""
|
||||||
|
if self._lock is None:
|
||||||
|
import threading
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
def get_camera_state(self, camera_id: str) -> CameraConnectionState:
|
||||||
|
"""
|
||||||
|
Get or create camera connection state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Current connection state for the camera
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if camera_id not in self._camera_states:
|
||||||
|
self._camera_states[camera_id] = CameraConnectionState()
|
||||||
|
logger.debug(f"Initialized connection state for camera {camera_id}")
|
||||||
|
|
||||||
|
return self._camera_states[camera_id]
|
||||||
|
|
||||||
|
def set_camera_connected(self,
|
||||||
|
camera_id: str,
|
||||||
|
connected: bool = True,
|
||||||
|
error_msg: Optional[str] = None) -> None:
|
||||||
|
"""
|
||||||
|
Set camera connection state and track error information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
connected: Whether camera is connected
|
||||||
|
error_msg: Error message if disconnected
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
state = self.get_camera_state(camera_id)
|
||||||
|
|
||||||
|
if connected:
|
||||||
|
# Camera is now connected
|
||||||
|
if not state.connected:
|
||||||
|
logger.info(f"Camera {camera_id} reconnected successfully")
|
||||||
|
|
||||||
|
state.connected = True
|
||||||
|
state.last_successful_frame = current_time
|
||||||
|
state.consecutive_failures = 0
|
||||||
|
state.disconnection_notified = False
|
||||||
|
state.last_error = None
|
||||||
|
state.last_error_time = None
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Camera is now disconnected
|
||||||
|
state.connected = False
|
||||||
|
state.consecutive_failures += 1
|
||||||
|
state.last_error = error_msg
|
||||||
|
state.last_error_time = current_time
|
||||||
|
|
||||||
|
if state.consecutive_failures == 1:
|
||||||
|
logger.warning(f"Camera {camera_id} connection lost: {error_msg}")
|
||||||
|
elif state.consecutive_failures >= self.failure_threshold:
|
||||||
|
logger.error(f"Camera {camera_id} has {state.consecutive_failures} consecutive failures")
|
||||||
|
|
||||||
|
logger.debug(f"Camera {camera_id} state updated - failures: {state.consecutive_failures}")
|
||||||
|
|
||||||
|
def is_camera_connected(self, camera_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if camera is currently connected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if camera is connected, False otherwise
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
state = self._camera_states.get(camera_id)
|
||||||
|
return state.connected if state else True # Default to connected for new cameras
|
||||||
|
|
||||||
|
def should_notify_disconnection(self, camera_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if we should notify backend about disconnection.
|
||||||
|
|
||||||
|
A disconnection notification should be sent when:
|
||||||
|
1. Camera is disconnected
|
||||||
|
2. We haven't already notified about this disconnection
|
||||||
|
3. We have enough consecutive failures
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if disconnection notification should be sent
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
state = self._camera_states.get(camera_id)
|
||||||
|
if not state:
|
||||||
|
return False
|
||||||
|
|
||||||
|
is_disconnected = not state.connected
|
||||||
|
not_yet_notified = not state.disconnection_notified
|
||||||
|
has_enough_failures = state.consecutive_failures >= self.failure_threshold
|
||||||
|
|
||||||
|
should_notify = is_disconnected and not_yet_notified and has_enough_failures
|
||||||
|
|
||||||
|
if should_notify:
|
||||||
|
logger.info(f"Camera {camera_id} qualifies for disconnection notification - "
|
||||||
|
f"failures: {state.consecutive_failures}, error: {state.last_error}")
|
||||||
|
|
||||||
|
return should_notify
|
||||||
|
|
||||||
|
def mark_disconnection_notified(self, camera_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Mark that we've notified backend about this disconnection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if camera_id in self._camera_states:
|
||||||
|
self._camera_states[camera_id].disconnection_notified = True
|
||||||
|
logger.debug(f"Marked disconnection notification sent for camera {camera_id}")
|
||||||
|
|
||||||
|
def get_connection_stats(self, camera_id: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get comprehensive connection statistics for a camera.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with connection statistics
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
state = self._camera_states.get(camera_id)
|
||||||
|
if not state:
|
||||||
|
return {"error": "Camera not found"}
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
stats = state.to_dict()
|
||||||
|
|
||||||
|
# Add computed stats
|
||||||
|
if state.connection_start_time:
|
||||||
|
stats["uptime_seconds"] = current_time - state.connection_start_time
|
||||||
|
|
||||||
|
if state.last_successful_frame:
|
||||||
|
stats["seconds_since_last_frame"] = current_time - state.last_successful_frame
|
||||||
|
|
||||||
|
if state.last_error_time:
|
||||||
|
stats["seconds_since_last_error"] = current_time - state.last_error_time
|
||||||
|
|
||||||
|
return stats
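    # Example shape of the returned stats (values are illustrative):
    #   {"connected": False, "last_error": "timeout", "consecutive_failures": 4,
    #    "disconnection_notified": True, "uptime_seconds": 532.1,
    #    "seconds_since_last_frame": 87.4, "seconds_since_last_error": 3.2, ...}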
|
||||||
|
|
||||||
|
def reset_camera_state(self, camera_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Reset camera state to initial connected state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Unique camera identifier
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if camera_id in self._camera_states:
|
||||||
|
del self._camera_states[camera_id]
|
||||||
|
logger.info(f"Reset connection state for camera {camera_id}")
|
||||||
|
|
||||||
|
def cleanup_inactive_cameras(self, inactive_threshold_seconds: int = 3600) -> int:
|
||||||
|
"""
|
||||||
|
Remove states for cameras inactive for too long.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inactive_threshold_seconds: Seconds of inactivity before cleanup
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of camera states cleaned up
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
cleanup_count = 0
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
camera_ids_to_remove = []
|
||||||
|
|
||||||
|
for camera_id, state in self._camera_states.items():
|
||||||
|
last_activity = max(
|
||||||
|
state.connection_start_time or 0,
|
||||||
|
state.last_successful_frame or 0,
|
||||||
|
state.last_error_time or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
if current_time - last_activity > inactive_threshold_seconds:
|
||||||
|
camera_ids_to_remove.append(camera_id)
|
||||||
|
|
||||||
|
for camera_id in camera_ids_to_remove:
|
||||||
|
del self._camera_states[camera_id]
|
||||||
|
cleanup_count += 1
|
||||||
|
logger.debug(f"Cleaned up inactive camera state for {camera_id}")
|
||||||
|
|
||||||
|
if cleanup_count > 0:
|
||||||
|
logger.info(f"Cleaned up {cleanup_count} inactive camera states")
|
||||||
|
|
||||||
|
return cleanup_count
|
||||||
|
|
||||||
|
def get_all_camera_states(self) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get connection states for all monitored cameras.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping camera IDs to their connection states
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
return {
|
||||||
|
camera_id: state.to_dict()
|
||||||
|
for camera_id, state in self._camera_states.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_disconnected_cameras(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
Get list of currently disconnected camera IDs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of camera IDs that are currently disconnected
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
return [
|
||||||
|
camera_id for camera_id, state in self._camera_states.items()
|
||||||
|
if not state.connected
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Global camera monitor instance
|
||||||
|
camera_monitor = CameraMonitor()
|
||||||
|
|
||||||
|
|
||||||
|
# ===== CONVENIENCE FUNCTIONS =====
|
||||||
|
# These provide the same interface as the original functions in app.py
|
||||||
|
|
||||||
|
def set_camera_connected(camera_id: str, connected: bool = True, error_msg: Optional[str] = None) -> None:
|
||||||
|
"""Set camera connection state and track error information."""
|
||||||
|
camera_monitor.set_camera_connected(camera_id, connected, error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
def is_camera_connected(camera_id: str) -> bool:
|
||||||
|
"""Check if camera is currently connected."""
|
||||||
|
return camera_monitor.is_camera_connected(camera_id)
|
||||||
|
|
||||||
|
|
||||||
|
def should_notify_disconnection(camera_id: str) -> bool:
|
||||||
|
"""Check if we should notify backend about disconnection."""
|
||||||
|
return camera_monitor.should_notify_disconnection(camera_id)
|
||||||
|
|
||||||
|
|
||||||
|
def mark_disconnection_notified(camera_id: str) -> None:
|
||||||
|
"""Mark that we've notified backend about this disconnection."""
|
||||||
|
camera_monitor.mark_disconnection_notified(camera_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_connection_stats(camera_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get comprehensive connection statistics for a camera."""
|
||||||
|
return camera_monitor.get_connection_stats(camera_id)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_camera_state(camera_id: str) -> None:
|
||||||
|
"""Reset camera state to initial connected state."""
|
||||||
|
camera_monitor.reset_camera_state(camera_id)
|
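
For orientation, a minimal caller-side sketch of how a frame-reading loop might drive these convenience functions. It is not part of this commit: notify_backend() and the camera ID are illustrative placeholders for whatever the worker actually sends to the backend.

# Hypothetical usage sketch (not in this commit): a capture loop reports read
# results, and a supervisor decides when the backend must be told about a drop.
from typing import Optional

from detector_worker.streams.camera_monitor import (
    set_camera_connected,
    should_notify_disconnection,
    mark_disconnection_notified,
)


def notify_backend(camera_id: str) -> None:
    # Placeholder for the real backend notification (e.g. a WebSocket message).
    print(f"camera {camera_id} disconnected")


def on_read_result(camera_id: str, frame, error: Optional[str] = None) -> None:
    if frame is not None:
        set_camera_connected(camera_id, True)
    else:
        set_camera_connected(camera_id, False, error)
        if should_notify_disconnection(camera_id):
            notify_backend(camera_id)
            mark_disconnection_notified(camera_id)
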
476
detector_worker/streams/frame_reader.py
Normal file
@@ -0,0 +1,476 @@
"""
|
||||||
|
Frame reading implementations for RTSP and HTTP snapshot streams.
|
||||||
|
|
||||||
|
This module provides thread-safe frame readers for different camera stream types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import time
|
||||||
|
import queue
|
||||||
|
import logging
|
||||||
|
import requests
|
||||||
|
import threading
|
||||||
|
from typing import Optional, Any
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..core.constants import (
|
||||||
|
DEFAULT_RECONNECT_INTERVAL_SEC,
|
||||||
|
DEFAULT_MAX_RETRIES,
|
||||||
|
HTTP_SNAPSHOT_TIMEOUT,
|
||||||
|
SHARED_STREAM_BUFFER_SIZE
|
||||||
|
)
|
||||||
|
from ..core.exceptions import (
|
||||||
|
CameraConnectionError,
|
||||||
|
FrameReadError,
|
||||||
|
create_stream_error
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_snapshot(url: str, timeout: int = HTTP_SNAPSHOT_TIMEOUT) -> Optional[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Fetch a single snapshot from HTTP/HTTPS URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: HTTP/HTTPS URL to fetch snapshot from
|
||||||
|
timeout: Request timeout in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decoded image frame or None if fetch failed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.debug(f"Fetching snapshot from {url}")
|
||||||
|
response = requests.get(url, timeout=timeout, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Check if response has content
|
||||||
|
if not response.content:
|
||||||
|
logger.warning(f"Empty response from snapshot URL: {url}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Decode image from response bytes
|
||||||
|
img_array = np.frombuffer(response.content, dtype=np.uint8)
|
||||||
|
frame = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
|
||||||
|
|
||||||
|
if frame is None:
|
||||||
|
logger.warning(f"Failed to decode image from snapshot URL: {url}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug(f"Successfully fetched snapshot from {url}, shape: {frame.shape}")
|
||||||
|
return frame
|
||||||
|
|
||||||
|
except requests.exceptions.Timeout:
|
||||||
|
logger.warning(f"Timeout fetching snapshot from {url}")
|
||||||
|
return None
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
logger.warning(f"Connection error fetching snapshot from {url}")
|
||||||
|
return None
|
||||||
|
except requests.exceptions.HTTPError as e:
|
||||||
|
logger.warning(f"HTTP error fetching snapshot from {url}: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error fetching snapshot from {url}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
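# Editorial note (sketch, not part of this commit): fetch_snapshot() is deliberately
# exception-free for callers. Every failure path (timeout, connection error, HTTP
# error, empty or undecodable payload) logs a warning and returns None, so callers
# only need a single `if frame is None` check. A hypothetical call, assuming a
# reachable camera endpoint:
#
#     frame = fetch_snapshot("http://camera.local/snapshot.jpg", timeout=5)
#     if frame is not None:
#         print(frame.shape)  # BGR numpy array, e.g. (1080, 1920, 3)
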
class RTSPFrameReader:
    """Thread-safe RTSP frame reader."""

    def __init__(self,
                 camera_id: str,
                 rtsp_url: str,
                 buffer: queue.Queue,
                 stop_event: threading.Event,
                 reconnect_interval: int = DEFAULT_RECONNECT_INTERVAL_SEC,
                 max_retries: int = DEFAULT_MAX_RETRIES,
                 connection_callback=None):
        """
        Initialize RTSP frame reader.

        Args:
            camera_id: Unique camera identifier
            rtsp_url: RTSP stream URL
            buffer: Queue to put frames into
            stop_event: Event to signal thread shutdown
            reconnect_interval: Seconds between reconnection attempts
            max_retries: Maximum retry attempts (-1 for unlimited)
            connection_callback: Callback function for connection state changes
        """
        self.camera_id = camera_id
        self.rtsp_url = rtsp_url
        self.buffer = buffer
        self.stop_event = stop_event
        self.reconnect_interval = reconnect_interval
        self.max_retries = max_retries
        self.connection_callback = connection_callback

        self.cap: Optional[cv2.VideoCapture] = None
        self.retries = 0
        self.frame_count = 0
        self.last_log_time = time.time()

    def _set_connection_state(self, connected: bool, error_msg: Optional[str] = None):
        """Update connection state via callback."""
        if self.connection_callback:
            self.connection_callback(self.camera_id, connected, error_msg)

    def _initialize_capture(self) -> bool:
        """Initialize video capture."""
        try:
            self.cap = cv2.VideoCapture(self.rtsp_url)
            if self.cap.isOpened():
                # Log camera properties
                width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                fps = self.cap.get(cv2.CAP_PROP_FPS)
                logger.info(f"Camera {self.camera_id} opened: {width}x{height}, FPS: {fps}")
                self._set_connection_state(True)
                return True
            else:
                logger.error(f"Camera {self.camera_id} failed to open")
                self._set_connection_state(False, "Failed to open camera")
                return False
        except Exception as e:
            error_msg = f"Failed to initialize capture: {e}"
            logger.error(f"Camera {self.camera_id}: {error_msg}")
            self._set_connection_state(False, error_msg)
            return False

    def _reconnect(self) -> bool:
        """Attempt to reconnect to RTSP stream."""
        if self.cap:
            self.cap.release()
            self.cap = None

        logger.info(f"Attempting to reconnect RTSP stream for camera: {self.camera_id}")
        time.sleep(self.reconnect_interval)

        return self._initialize_capture()

    def _read_frame(self) -> Optional[np.ndarray]:
        """Read a single frame from the stream."""
        if not self.cap or not self.cap.isOpened():
            return None

        try:
            ret, frame = self.cap.read()
            if not ret:
                return None

            # Update statistics
            self.frame_count += 1
            current_time = time.time()

            # Log frame stats every 5 seconds
            if current_time - self.last_log_time > 5:
                elapsed = current_time - self.last_log_time
                logger.info(f"Camera {self.camera_id}: Read {self.frame_count} frames in {elapsed:.1f}s")
                self.frame_count = 0
                self.last_log_time = current_time

            return frame

        except cv2.error as e:
            raise FrameReadError(f"OpenCV error reading frame: {e}")

    def _put_frame(self, frame: np.ndarray):
        """Put frame into buffer, removing old frame if necessary."""
        # Remove old frame if buffer is full
        if not self.buffer.empty():
            try:
                self.buffer.get_nowait()
                logger.debug(f"Removed old frame from buffer for camera {self.camera_id}")
            except queue.Empty:
                pass

        self.buffer.put(frame)
        logger.debug(f"Added frame to buffer for camera {self.camera_id}, buffer size: {self.buffer.qsize()}")

    def run(self):
        """Main frame reading loop."""
        logger.info(f"Starting RTSP frame reader for camera {self.camera_id}")

        try:
            # Initialize capture
            if not self._initialize_capture():
                return

            while not self.stop_event.is_set():
                try:
                    frame = self._read_frame()

                    if frame is None:
                        # Connection lost
                        self.retries += 1
                        error_msg = f"Connection lost, retry {self.retries}/{self.max_retries}"
                        logger.warning(f"Camera {self.camera_id}: {error_msg}")
                        self._set_connection_state(False, error_msg)

                        # Check max retries
                        if self.retries > self.max_retries and self.max_retries != -1:
                            logger.error(f"Camera {self.camera_id}: Max retries reached, stopping")
                            self._set_connection_state(False, "Max retries reached")
                            break

                        # Attempt reconnection
                        if not self._reconnect():
                            continue

                        # Reset retry counter on successful reconnection
                        self.retries = 0
                        continue

                    # Successfully read frame
                    logger.debug(f"Camera {self.camera_id}: Read frame, shape: {frame.shape}")
                    self.retries = 0
                    self._set_connection_state(True)
                    self._put_frame(frame)

                    # Short sleep to avoid CPU overuse
                    time.sleep(0.01)

                except FrameReadError as e:
                    self.retries += 1
                    error_msg = f"Frame read error: {e}"
                    logger.error(f"Camera {self.camera_id}: {error_msg}")
                    self._set_connection_state(False, error_msg)

                    # Check max retries
                    if self.retries > self.max_retries and self.max_retries != -1:
                        logger.error(f"Camera {self.camera_id}: Max retries reached after error")
                        self._set_connection_state(False, "Max retries reached after error")
                        break

                    # Attempt reconnection
                    if not self._reconnect():
                        continue

                except Exception as e:
                    error_msg = f"Unexpected error: {str(e)}"
                    logger.error(f"Camera {self.camera_id}: {error_msg}", exc_info=True)
                    self._set_connection_state(False, error_msg)
                    break

        except Exception as e:
            logger.error(f"Error in RTSP frame reader for camera {self.camera_id}: {str(e)}", exc_info=True)
        finally:
            logger.info(f"RTSP frame reader for camera {self.camera_id} is exiting")
            if self.cap and self.cap.isOpened():
                self.cap.release()


class SnapshotFrameReader:
    """Thread-safe HTTP snapshot frame reader."""

    def __init__(self,
                 camera_id: str,
                 snapshot_url: str,
                 snapshot_interval: int,
                 buffer: queue.Queue,
                 stop_event: threading.Event,
                 max_retries: int = DEFAULT_MAX_RETRIES,
                 connection_callback=None):
        """
        Initialize snapshot frame reader.

        Args:
            camera_id: Unique camera identifier
            snapshot_url: HTTP/HTTPS snapshot URL
            snapshot_interval: Interval between snapshots in milliseconds
            buffer: Queue to put frames into
            stop_event: Event to signal thread shutdown
            max_retries: Maximum retry attempts (-1 for unlimited)
            connection_callback: Callback function for connection state changes
        """
        self.camera_id = camera_id
        self.snapshot_url = snapshot_url
        self.snapshot_interval = snapshot_interval
        self.buffer = buffer
        self.stop_event = stop_event
        self.max_retries = max_retries
        self.connection_callback = connection_callback

        self.retries = 0
        self.consecutive_failures = 0
        self.frame_count = 0
        self.last_log_time = time.time()

    def _set_connection_state(self, connected: bool, error_msg: Optional[str] = None):
        """Update connection state via callback."""
        if self.connection_callback:
            self.connection_callback(self.camera_id, connected, error_msg)

    def _calculate_backoff_delay(self) -> float:
        """Calculate exponential backoff delay based on consecutive failures."""
        interval_seconds = self.snapshot_interval / 1000.0
        backoff_delay = min(30, max(1, min(2 ** min(self.consecutive_failures - 1, 6), interval_seconds * 2)))
        return backoff_delay

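    # Editorial note: with a 5000 ms snapshot_interval (interval_seconds = 5.0, so the
    # cap from interval_seconds * 2 is 10.0 s), _calculate_backoff_delay() yields
    # roughly 1s, 2s, 4s, 8s, 10s, 10s, ... for consecutive_failures = 1, 2, 3, 4, 5, 6+.
    # The outer min(30, ...) only bites when interval_seconds exceeds 15 s.
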
    def _test_connectivity(self):
        """Test connectivity to snapshot URL."""
        if self.consecutive_failures % 5 == 1:  # On the 1st failure, then every 5th after
            try:
                test_response = requests.get(self.snapshot_url, timeout=(2, 5), stream=False)
                logger.info(f"Camera {self.camera_id}: Connectivity test result: {test_response.status_code}")
            except Exception as test_error:
                logger.warning(f"Camera {self.camera_id}: Connectivity test failed: {test_error}")

    def _put_frame(self, frame: np.ndarray):
        """Put frame into buffer, removing old frame if necessary."""
        # Remove old frame if buffer is full
        if not self.buffer.empty():
            try:
                self.buffer.get_nowait()
                logger.debug(f"Removed old snapshot from buffer for camera {self.camera_id}")
            except queue.Empty:
                pass

        self.buffer.put(frame)
        logger.debug(f"Added snapshot to buffer for camera {self.camera_id}, buffer size: {self.buffer.qsize()}")

    def run(self):
        """Main snapshot reading loop."""
        logger.info(f"Starting snapshot reader for camera {self.camera_id} from {self.snapshot_url}")

        interval_seconds = self.snapshot_interval / 1000.0
        logger.info(f"Snapshot interval for camera {self.camera_id}: {interval_seconds}s")

        # Initialize connection state
        self._set_connection_state(True)

        try:
            while not self.stop_event.is_set():
                try:
                    start_time = time.time()
                    frame = fetch_snapshot(self.snapshot_url)

                    if frame is None:
                        # Failed to fetch snapshot
                        self.consecutive_failures += 1
                        self.retries += 1
                        error_msg = f"Failed to fetch snapshot, consecutive failures: {self.consecutive_failures}"
                        logger.warning(f"Camera {self.camera_id}: {error_msg}")
                        self._set_connection_state(False, error_msg)

                        # Test connectivity periodically
                        self._test_connectivity()

                        # Check max retries
                        if self.retries > self.max_retries and self.max_retries != -1:
                            logger.error(f"Camera {self.camera_id}: Max retries reached for snapshot, stopping")
                            self._set_connection_state(False, "Max retries reached for snapshot")
                            break

                        # Exponential backoff
                        backoff_delay = self._calculate_backoff_delay()
                        logger.debug(f"Camera {self.camera_id}: Backing off for {backoff_delay:.1f}s")
                        if self.stop_event.wait(backoff_delay):
                            break  # Exit if stop event set during backoff
                        continue

                    # Successfully fetched snapshot
                    self.consecutive_failures = 0  # Reset on success
                    self.retries = 0
                    self.frame_count += 1
                    current_time = time.time()

                    # Log frame stats every 5 seconds
                    if current_time - self.last_log_time > 5:
                        elapsed = current_time - self.last_log_time
                        logger.info(f"Camera {self.camera_id}: Fetched {self.frame_count} snapshots in {elapsed:.1f}s")
                        self.frame_count = 0
                        self.last_log_time = current_time

                    logger.debug(f"Camera {self.camera_id}: Fetched snapshot, shape: {frame.shape}")
                    self._set_connection_state(True)
                    self._put_frame(frame)

                    # Wait for interval
                    elapsed = time.time() - start_time
                    sleep_time = max(interval_seconds - elapsed, 0)
                    if sleep_time > 0:
                        if self.stop_event.wait(sleep_time):
                            break  # Exit if stop event set during sleep

                except Exception as e:
                    self.consecutive_failures += 1
                    self.retries += 1
                    error_msg = f"Unexpected error: {str(e)}"
                    logger.error(f"Camera {self.camera_id}: {error_msg}", exc_info=True)
                    self._set_connection_state(False, error_msg)

                    # Check max retries
                    if self.retries > self.max_retries and self.max_retries != -1:
                        logger.error(f"Camera {self.camera_id}: Max retries reached after error")
                        self._set_connection_state(False, "Max retries reached after error")
                        break

                    # Exponential backoff for exceptions too
                    backoff_delay = self._calculate_backoff_delay()
                    logger.debug(f"Camera {self.camera_id}: Exception backoff for {backoff_delay:.1f}s")
                    if self.stop_event.wait(backoff_delay):
                        break  # Exit if stop event set during backoff

        except Exception as e:
            logger.error(f"Error in snapshot reader for camera {self.camera_id}: {str(e)}", exc_info=True)
        finally:
            logger.info(f"Snapshot reader for camera {self.camera_id} is exiting")


def create_frame_reader_thread(camera_id: str,
                               rtsp_url: Optional[str] = None,
                               snapshot_url: Optional[str] = None,
                               snapshot_interval: Optional[int] = None,
                               buffer: Optional[queue.Queue] = None,
                               stop_event: Optional[threading.Event] = None,
                               connection_callback=None) -> Optional[threading.Thread]:
    """
    Create appropriate frame reader thread based on stream type.

    Args:
        camera_id: Unique camera identifier
        rtsp_url: RTSP stream URL (for RTSP streams)
        snapshot_url: HTTP snapshot URL (for snapshot streams)
        snapshot_interval: Snapshot interval in milliseconds
        buffer: Frame buffer queue
        stop_event: Thread stop event
        connection_callback: Connection state callback

    Returns:
        Configured thread ready to start, or None if invalid parameters
    """
    if not buffer:
        buffer = queue.Queue(maxsize=SHARED_STREAM_BUFFER_SIZE)
    if not stop_event:
        stop_event = threading.Event()

    if snapshot_url and snapshot_interval:
        # Create snapshot reader
        reader = SnapshotFrameReader(
            camera_id=camera_id,
            snapshot_url=snapshot_url,
            snapshot_interval=snapshot_interval,
            buffer=buffer,
            stop_event=stop_event,
            connection_callback=connection_callback
        )
        thread = threading.Thread(target=reader.run, name=f"snapshot-{camera_id}")

    elif rtsp_url:
        # Create RTSP reader
        reader = RTSPFrameReader(
            camera_id=camera_id,
            rtsp_url=rtsp_url,
            buffer=buffer,
            stop_event=stop_event,
            connection_callback=connection_callback
        )
        thread = threading.Thread(target=reader.run, name=f"rtsp-{camera_id}")

    else:
        logger.error(f"No valid URL provided for camera {camera_id}")
        return None

    thread.daemon = True
    return thread
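
A minimal sketch of how this factory might be wired up on its own, outside the StreamManager introduced below. The camera ID, the RTSP URL, and the one-frame poll are illustrative assumptions, not part of the module.

import queue
import threading

from detector_worker.streams.frame_reader import create_frame_reader_thread

frame_buffer: queue.Queue = queue.Queue(maxsize=1)
stop_event = threading.Event()

# Hypothetical camera; the URL is a placeholder, not a real endpoint.
thread = create_frame_reader_thread(
    camera_id="cam-01",
    rtsp_url="rtsp://example.invalid/stream",
    buffer=frame_buffer,
    stop_event=stop_event,
)
if thread:
    thread.start()
    try:
        frame = frame_buffer.get(timeout=5)  # newest frame; stale ones are dropped by the reader
        print("got frame", frame.shape)
    except queue.Empty:
        print("no frame within 5 s")
    finally:
        stop_event.set()
        thread.join(timeout=5)
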
572
detector_worker/streams/stream_manager.py
Normal file
@@ -0,0 +1,572 @@
"""
|
||||||
|
Stream lifecycle management and coordination.
|
||||||
|
|
||||||
|
This module provides centralized management of camera streams including
|
||||||
|
lifecycle management, resource allocation, and stream coordination.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import queue
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from typing import Dict, List, Any, Optional, Tuple, Set
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from ..core.constants import (
|
||||||
|
DEFAULT_MAX_STREAMS,
|
||||||
|
SHARED_STREAM_BUFFER_SIZE,
|
||||||
|
DEFAULT_RECONNECT_INTERVAL_SEC,
|
||||||
|
DEFAULT_MAX_RETRIES
|
||||||
|
)
|
||||||
|
from ..core.exceptions import StreamError, create_stream_error
|
||||||
|
from ..streams.frame_reader import create_frame_reader_thread
|
||||||
|
from ..streams.camera_monitor import set_camera_connected
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StreamInfo:
|
||||||
|
"""Information about a single camera stream."""
|
||||||
|
camera_id: str
|
||||||
|
stream_url: str
|
||||||
|
stream_type: str # "rtsp" or "snapshot"
|
||||||
|
snapshot_interval: Optional[int] = None
|
||||||
|
buffer: Optional[queue.Queue] = None
|
||||||
|
stop_event: Optional[threading.Event] = None
|
||||||
|
thread: Optional[threading.Thread] = None
|
||||||
|
subscribers: Set[str] = field(default_factory=set)
|
||||||
|
created_at: float = field(default_factory=time.time)
|
||||||
|
last_frame_time: Optional[float] = None
|
||||||
|
frame_count: int = 0
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
return {
|
||||||
|
"camera_id": self.camera_id,
|
||||||
|
"stream_url": self.stream_url,
|
||||||
|
"stream_type": self.stream_type,
|
||||||
|
"snapshot_interval": self.snapshot_interval,
|
||||||
|
"subscriber_count": len(self.subscribers),
|
||||||
|
"subscribers": list(self.subscribers),
|
||||||
|
"created_at": self.created_at,
|
||||||
|
"last_frame_time": self.last_frame_time,
|
||||||
|
"frame_count": self.frame_count,
|
||||||
|
"is_active": self.thread is not None and self.thread.is_alive()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StreamSubscription:
|
||||||
|
"""Information about a stream subscription."""
|
||||||
|
subscription_id: str
|
||||||
|
camera_id: str
|
||||||
|
subscriber_id: str
|
||||||
|
created_at: float = field(default_factory=time.time)
|
||||||
|
last_access: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
return {
|
||||||
|
"subscription_id": self.subscription_id,
|
||||||
|
"camera_id": self.camera_id,
|
||||||
|
"subscriber_id": self.subscriber_id,
|
||||||
|
"created_at": self.created_at,
|
||||||
|
"last_access": self.last_access
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class StreamManager:
|
||||||
|
"""
|
||||||
|
Manages camera stream lifecycle and resource allocation.
|
||||||
|
|
||||||
|
This class provides centralized management of camera streams including:
|
||||||
|
- Stream lifecycle management (start/stop/restart)
|
||||||
|
- Resource allocation and sharing
|
||||||
|
- Subscriber management
|
||||||
|
- Connection state monitoring
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_streams: int = DEFAULT_MAX_STREAMS):
|
||||||
|
"""
|
||||||
|
Initialize stream manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_streams: Maximum number of concurrent streams
|
||||||
|
"""
|
||||||
|
self.max_streams = max_streams
|
||||||
|
self._streams: Dict[str, StreamInfo] = {}
|
||||||
|
self._subscriptions: Dict[str, StreamSubscription] = {}
|
||||||
|
self._lock = None
|
||||||
|
|
||||||
|
def _ensure_thread_safety(self):
|
||||||
|
"""Initialize thread safety if not already done."""
|
||||||
|
if self._lock is None:
|
||||||
|
import threading
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
def _connection_state_callback(self, camera_id: str, connected: bool, error_msg: Optional[str] = None):
|
||||||
|
"""Callback for connection state changes."""
|
||||||
|
set_camera_connected(camera_id, connected, error_msg)
|
||||||
|
|
||||||
|
def _create_stream_info(self,
|
||||||
|
camera_id: str,
|
||||||
|
rtsp_url: Optional[str] = None,
|
||||||
|
snapshot_url: Optional[str] = None,
|
||||||
|
snapshot_interval: Optional[int] = None) -> StreamInfo:
|
||||||
|
"""Create StreamInfo object based on stream type."""
|
||||||
|
if snapshot_url and snapshot_interval:
|
||||||
|
return StreamInfo(
|
||||||
|
camera_id=camera_id,
|
||||||
|
stream_url=snapshot_url,
|
||||||
|
stream_type="snapshot",
|
||||||
|
snapshot_interval=snapshot_interval,
|
||||||
|
buffer=queue.Queue(maxsize=SHARED_STREAM_BUFFER_SIZE),
|
||||||
|
stop_event=threading.Event()
|
||||||
|
)
|
||||||
|
elif rtsp_url:
|
||||||
|
return StreamInfo(
|
||||||
|
camera_id=camera_id,
|
||||||
|
stream_url=rtsp_url,
|
||||||
|
stream_type="rtsp",
|
||||||
|
buffer=queue.Queue(maxsize=SHARED_STREAM_BUFFER_SIZE),
|
||||||
|
stop_event=threading.Event()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Must provide either RTSP URL or snapshot URL with interval")
|
||||||
|
|
||||||
|
def create_subscription(self,
|
||||||
|
subscription_id: str,
|
||||||
|
camera_id: str,
|
||||||
|
subscriber_id: str,
|
||||||
|
rtsp_url: Optional[str] = None,
|
||||||
|
snapshot_url: Optional[str] = None,
|
||||||
|
snapshot_interval: Optional[int] = None) -> bool:
|
||||||
|
"""
|
||||||
|
Create a stream subscription.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subscription_id: Unique subscription identifier
|
||||||
|
camera_id: Camera identifier
|
||||||
|
subscriber_id: Subscriber identifier
|
||||||
|
rtsp_url: RTSP stream URL (for RTSP streams)
|
||||||
|
snapshot_url: HTTP snapshot URL (for snapshot streams)
|
||||||
|
snapshot_interval: Snapshot interval in milliseconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if subscription was created successfully
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
# Check if subscription already exists
|
||||||
|
if subscription_id in self._subscriptions:
|
||||||
|
logger.warning(f"Subscription {subscription_id} already exists")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check stream limit
|
||||||
|
if len(self._streams) >= self.max_streams and camera_id not in self._streams:
|
||||||
|
logger.error(f"Maximum streams ({self.max_streams}) reached, cannot create new stream for camera {camera_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Create or get existing stream
|
||||||
|
if camera_id not in self._streams:
|
||||||
|
stream_info = self._create_stream_info(
|
||||||
|
camera_id, rtsp_url, snapshot_url, snapshot_interval
|
||||||
|
)
|
||||||
|
self._streams[camera_id] = stream_info
|
||||||
|
|
||||||
|
# Create and start frame reader thread
|
||||||
|
thread = create_frame_reader_thread(
|
||||||
|
camera_id=camera_id,
|
||||||
|
rtsp_url=rtsp_url,
|
||||||
|
snapshot_url=snapshot_url,
|
||||||
|
snapshot_interval=snapshot_interval,
|
||||||
|
buffer=stream_info.buffer,
|
||||||
|
stop_event=stream_info.stop_event,
|
||||||
|
connection_callback=self._connection_state_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
if thread:
|
||||||
|
stream_info.thread = thread
|
||||||
|
thread.start()
|
||||||
|
logger.info(f"Created new {stream_info.stream_type} stream for camera {camera_id}")
|
||||||
|
else:
|
||||||
|
# Clean up failed stream
|
||||||
|
del self._streams[camera_id]
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Add subscriber to stream
|
||||||
|
stream_info = self._streams[camera_id]
|
||||||
|
stream_info.subscribers.add(subscription_id)
|
||||||
|
|
||||||
|
# Create subscription record
|
||||||
|
subscription = StreamSubscription(
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
camera_id=camera_id,
|
||||||
|
subscriber_id=subscriber_id
|
||||||
|
)
|
||||||
|
self._subscriptions[subscription_id] = subscription
|
||||||
|
|
||||||
|
logger.info(f"Created subscription {subscription_id} for camera {camera_id}, subscribers: {len(stream_info.subscribers)}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating subscription {subscription_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def remove_subscription(self, subscription_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Remove a stream subscription.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subscription_id: Unique subscription identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if subscription was removed successfully
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if subscription_id not in self._subscriptions:
|
||||||
|
logger.warning(f"Subscription {subscription_id} not found")
|
||||||
|
return False
|
||||||
|
|
||||||
|
subscription = self._subscriptions[subscription_id]
|
||||||
|
camera_id = subscription.camera_id
|
||||||
|
|
||||||
|
# Remove subscription
|
||||||
|
del self._subscriptions[subscription_id]
|
||||||
|
|
||||||
|
# Remove subscriber from stream if stream exists
|
||||||
|
if camera_id in self._streams:
|
||||||
|
stream_info = self._streams[camera_id]
|
||||||
|
stream_info.subscribers.discard(subscription_id)
|
||||||
|
|
||||||
|
logger.info(f"Removed subscription {subscription_id} for camera {camera_id}, remaining subscribers: {len(stream_info.subscribers)}")
|
||||||
|
|
||||||
|
# Stop stream if no more subscribers
|
||||||
|
if not stream_info.subscribers:
|
||||||
|
self._stop_stream(camera_id)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _stop_stream(self, camera_id: str) -> None:
|
||||||
|
"""Stop a stream and clean up resources."""
|
||||||
|
if camera_id not in self._streams:
|
||||||
|
return
|
||||||
|
|
||||||
|
stream_info = self._streams[camera_id]
|
||||||
|
|
||||||
|
# Signal thread to stop
|
||||||
|
if stream_info.stop_event:
|
||||||
|
stream_info.stop_event.set()
|
||||||
|
|
||||||
|
# Wait for thread to finish
|
||||||
|
if stream_info.thread and stream_info.thread.is_alive():
|
||||||
|
stream_info.thread.join(timeout=5)
|
||||||
|
if stream_info.thread.is_alive():
|
||||||
|
logger.warning(f"Stream thread for camera {camera_id} did not stop gracefully")
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
del self._streams[camera_id]
|
||||||
|
logger.info(f"Stopped {stream_info.stream_type} stream for camera {camera_id}")
|
||||||
|
|
||||||
|
def get_frame(self, subscription_id: str, timeout: float = 0.1) -> Optional[Any]:
|
||||||
|
"""
|
||||||
|
Get the latest frame for a subscription.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subscription_id: Unique subscription identifier
|
||||||
|
timeout: Timeout for frame retrieval in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Latest frame or None if not available
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if subscription_id not in self._subscriptions:
|
||||||
|
return None
|
||||||
|
|
||||||
|
subscription = self._subscriptions[subscription_id]
|
||||||
|
camera_id = subscription.camera_id
|
||||||
|
|
||||||
|
if camera_id not in self._streams:
|
||||||
|
return None
|
||||||
|
|
||||||
|
stream_info = self._streams[camera_id]
|
||||||
|
subscription.last_access = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
frame = stream_info.buffer.get(timeout=timeout)
|
||||||
|
stream_info.last_frame_time = time.time()
|
||||||
|
stream_info.frame_count += 1
|
||||||
|
return frame
|
||||||
|
except queue.Empty:
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting frame for subscription {subscription_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def is_stream_active(self, camera_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a stream is active.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Camera identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if stream is active
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if camera_id not in self._streams:
|
||||||
|
return False
|
||||||
|
|
||||||
|
stream_info = self._streams[camera_id]
|
||||||
|
return stream_info.thread is not None and stream_info.thread.is_alive()
|
||||||
|
|
||||||
|
def get_stream_stats(self, camera_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get statistics for a stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Camera identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Stream statistics or None if stream not found
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if camera_id not in self._streams:
|
||||||
|
return None
|
||||||
|
|
||||||
|
stream_info = self._streams[camera_id]
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
stats = stream_info.to_dict()
|
||||||
|
stats["uptime_seconds"] = current_time - stream_info.created_at
|
||||||
|
|
||||||
|
if stream_info.last_frame_time:
|
||||||
|
stats["seconds_since_last_frame"] = current_time - stream_info.last_frame_time
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def get_subscription_info(self, subscription_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get information about a subscription.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subscription_id: Unique subscription identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Subscription information or None if not found
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if subscription_id not in self._subscriptions:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self._subscriptions[subscription_id].to_dict()
|
||||||
|
|
||||||
|
def get_all_streams(self) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get information about all active streams.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping camera IDs to stream information
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
return {
|
||||||
|
camera_id: stream_info.to_dict()
|
||||||
|
for camera_id, stream_info in self._streams.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_all_subscriptions(self) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get information about all active subscriptions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping subscription IDs to subscription information
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
return {
|
||||||
|
sub_id: subscription.to_dict()
|
||||||
|
for sub_id, subscription in self._subscriptions.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def cleanup_inactive_streams(self, inactive_threshold_seconds: int = 3600) -> int:
|
||||||
|
"""
|
||||||
|
Clean up streams that have been inactive for too long.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inactive_threshold_seconds: Seconds of inactivity before cleanup
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of streams cleaned up
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
cleanup_count = 0
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
streams_to_remove = []
|
||||||
|
|
||||||
|
for camera_id, stream_info in self._streams.items():
|
||||||
|
# Check if stream has subscribers
|
||||||
|
if stream_info.subscribers:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if stream has been inactive
|
||||||
|
last_activity = max(
|
||||||
|
stream_info.created_at,
|
||||||
|
stream_info.last_frame_time or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
if current_time - last_activity > inactive_threshold_seconds:
|
||||||
|
streams_to_remove.append(camera_id)
|
||||||
|
|
||||||
|
for camera_id in streams_to_remove:
|
||||||
|
self._stop_stream(camera_id)
|
||||||
|
cleanup_count += 1
|
||||||
|
logger.info(f"Cleaned up inactive stream for camera {camera_id}")
|
||||||
|
|
||||||
|
if cleanup_count > 0:
|
||||||
|
logger.info(f"Cleaned up {cleanup_count} inactive streams")
|
||||||
|
|
||||||
|
return cleanup_count
|
||||||
|
|
||||||
|
def restart_stream(self, camera_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Restart a stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Camera identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if stream was restarted successfully
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if camera_id not in self._streams:
|
||||||
|
logger.warning(f"Cannot restart stream for camera {camera_id}: stream not found")
|
||||||
|
return False
|
||||||
|
|
||||||
|
stream_info = self._streams[camera_id]
|
||||||
|
subscribers = stream_info.subscribers.copy()
|
||||||
|
stream_url = stream_info.stream_url
|
||||||
|
stream_type = stream_info.stream_type
|
||||||
|
snapshot_interval = stream_info.snapshot_interval
|
||||||
|
|
||||||
|
# Stop current stream
|
||||||
|
self._stop_stream(camera_id)
|
||||||
|
|
||||||
|
# Recreate stream
|
||||||
|
try:
|
||||||
|
new_stream_info = self._create_stream_info(
|
||||||
|
camera_id,
|
||||||
|
rtsp_url=stream_url if stream_type == "rtsp" else None,
|
||||||
|
snapshot_url=stream_url if stream_type == "snapshot" else None,
|
||||||
|
snapshot_interval=snapshot_interval
|
||||||
|
)
|
||||||
|
new_stream_info.subscribers = subscribers
|
||||||
|
self._streams[camera_id] = new_stream_info
|
||||||
|
|
||||||
|
# Create and start new frame reader thread
|
||||||
|
thread = create_frame_reader_thread(
|
||||||
|
camera_id=camera_id,
|
||||||
|
rtsp_url=stream_url if stream_type == "rtsp" else None,
|
||||||
|
snapshot_url=stream_url if stream_type == "snapshot" else None,
|
||||||
|
snapshot_interval=snapshot_interval,
|
||||||
|
buffer=new_stream_info.buffer,
|
||||||
|
stop_event=new_stream_info.stop_event,
|
||||||
|
connection_callback=self._connection_state_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
if thread:
|
||||||
|
new_stream_info.thread = thread
|
||||||
|
thread.start()
|
||||||
|
logger.info(f"Restarted {stream_type} stream for camera {camera_id}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
# Clean up failed restart
|
||||||
|
del self._streams[camera_id]
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error restarting stream for camera {camera_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def shutdown_all(self) -> None:
|
||||||
|
"""Shutdown all streams and clean up resources."""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
logger.info("Shutting down all streams...")
|
||||||
|
|
||||||
|
# Stop all streams
|
||||||
|
camera_ids = list(self._streams.keys())
|
||||||
|
for camera_id in camera_ids:
|
||||||
|
self._stop_stream(camera_id)
|
||||||
|
|
||||||
|
# Clear all subscriptions
|
||||||
|
self._subscriptions.clear()
|
||||||
|
|
||||||
|
logger.info("All streams shut down successfully")
|
||||||
|
|
||||||
|
|
||||||
|
# Global stream manager instance
|
||||||
|
stream_manager = StreamManager()
|
||||||
|
|
||||||
|
|
||||||
|
# ===== CONVENIENCE FUNCTIONS =====
|
||||||
|
# These provide a simplified interface for common operations
|
||||||
|
|
||||||
|
def create_stream_subscription(subscription_id: str,
|
||||||
|
camera_id: str,
|
||||||
|
subscriber_id: str,
|
||||||
|
rtsp_url: Optional[str] = None,
|
||||||
|
snapshot_url: Optional[str] = None,
|
||||||
|
snapshot_interval: Optional[int] = None) -> bool:
|
||||||
|
"""Create a stream subscription using global stream manager."""
|
||||||
|
return stream_manager.create_subscription(
|
||||||
|
subscription_id, camera_id, subscriber_id, rtsp_url, snapshot_url, snapshot_interval
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_stream_subscription(subscription_id: str) -> bool:
|
||||||
|
"""Remove a stream subscription using global stream manager."""
|
||||||
|
return stream_manager.remove_subscription(subscription_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_stream_frame(subscription_id: str, timeout: float = 0.1) -> Optional[Any]:
|
||||||
|
"""Get the latest frame for a subscription using global stream manager."""
|
||||||
|
return stream_manager.get_frame(subscription_id, timeout)
|
||||||
|
|
||||||
|
|
||||||
|
def is_stream_active(camera_id: str) -> bool:
|
||||||
|
"""Check if a stream is active using global stream manager."""
|
||||||
|
return stream_manager.is_stream_active(camera_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_stream_statistics() -> Dict[str, Any]:
|
||||||
|
"""Get comprehensive stream statistics."""
|
||||||
|
return {
|
||||||
|
"streams": stream_manager.get_all_streams(),
|
||||||
|
"subscriptions": stream_manager.get_all_subscriptions(),
|
||||||
|
"total_streams": len(stream_manager._streams),
|
||||||
|
"total_subscriptions": len(stream_manager._subscriptions),
|
||||||
|
"max_streams": stream_manager.max_streams
|
||||||
|
}
|
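
To close the loop, a minimal sketch of the subscription lifecycle through the module-level convenience functions. The subscription and camera identifiers and the snapshot endpoint are placeholder assumptions; the interval is in milliseconds, matching the snapshot_interval parameter above.

from detector_worker.streams.stream_manager import (
    create_stream_subscription,
    get_stream_frame,
    remove_stream_subscription,
    get_stream_statistics,
)

# Hypothetical identifiers and endpoint.
ok = create_stream_subscription(
    subscription_id="display-001;cam-01",
    camera_id="cam-01",
    subscriber_id="display-001",
    snapshot_url="http://camera.example.invalid/snapshot.jpg",
    snapshot_interval=5000,  # milliseconds
)
if ok:
    frame = get_stream_frame("display-001;cam-01", timeout=2.0)
    if frame is not None:
        print("frame", frame.shape)
    print("active streams:", get_stream_statistics()["total_streams"])
    remove_stream_subscription("display-001;cam-01")  # last subscriber also stops the stream
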