Refactor: PHASE 2: Core Module Extraction

This commit is contained in:
ziesorx 2025-09-12 14:45:11 +07:00
parent 96bedae80a
commit 4e9ae6bcc4
7 changed files with 3684 additions and 0 deletions

View file

@ -0,0 +1,483 @@
"""
Detection stability validation and lightweight detection functionality.
This module provides validation functionality for detection pipelines including
pipeline execution pre-validation and lightweight detection for validation phases.
"""
import logging
from typing import Dict, List, Any, Optional, Tuple, Union
from dataclasses import dataclass, field
import numpy as np
from ..core.constants import (
LIGHTWEIGHT_DETECTION_MIN_CONFIDENCE,
LIGHTWEIGHT_DETECTION_MIN_BBOX_AREA_RATIO,
DEFAULT_MIN_CONFIDENCE
)
from ..core.exceptions import ValidationError, create_detection_error
from ..detection.detection_result import LightweightDetectionResult, BoundingBox
logger = logging.getLogger(__name__)
@dataclass
class PipelineValidationResult:
"""Result of pipeline execution validation."""
is_valid: bool
missing_branches: List[str] = field(default_factory=list)
required_branches: List[str] = field(default_factory=list)
available_classes: List[str] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"is_valid": self.is_valid,
"missing_branches": self.missing_branches.copy(),
"required_branches": self.required_branches.copy(),
"available_classes": self.available_classes.copy()
}
@dataclass
class LightweightDetectionResult:
"""Result from lightweight detection."""
car_detected: bool
best_detection: Optional[Dict[str, Any]] = None
validation_passed: bool = False
confidence: Optional[float] = None
bbox_area_ratio: Optional[float] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
result = {
"car_detected": self.car_detected,
"validation_passed": self.validation_passed
}
if self.best_detection:
result["best_detection"] = self.best_detection.copy()
if self.confidence is not None:
result["confidence"] = self.confidence
if self.bbox_area_ratio is not None:
result["bbox_area_ratio"] = self.bbox_area_ratio
return result
class StabilityValidator:
"""
Validates detection stability and pipeline execution requirements.
This class provides functionality for:
- Pipeline execution pre-validation
- Lightweight detection for validation phases
- Detection threshold validation
"""
def __init__(self):
"""Initialize stability validator."""
pass
def validate_pipeline_execution(self,
node: Dict[str, Any],
regions_dict: Dict[str, Any]) -> PipelineValidationResult:
"""
Pre-validate that all required branches will execute successfully before
committing to Redis actions and database records.
Args:
node: Pipeline node configuration
regions_dict: Dictionary mapping class names to detection regions
Returns:
PipelineValidationResult with validation status and details
"""
# Get all branches that parallel actions are waiting for
required_branches = set()
for action in node.get("parallelActions", []):
if action.get("type") == "postgresql_update_combined":
wait_for_branches = action.get("waitForBranches", [])
required_branches.update(wait_for_branches)
if not required_branches:
# No parallel actions requiring specific branches
logger.debug("No parallel actions with waitForBranches - validation passes")
return PipelineValidationResult(
is_valid=True,
required_branches=[],
available_classes=list(regions_dict.keys())
)
logger.debug(f"Pre-validation: checking if required branches {list(required_branches)} will execute")
# Check each required branch
missing_branches = []
for branch in node.get("branches", []):
branch_id = branch["modelId"]
if branch_id not in required_branches:
continue # This branch is not required by parallel actions
# Check if this branch would be triggered
trigger_classes = branch.get("triggerClasses", [])
min_conf = branch.get("minConfidence", 0)
branch_triggered = False
for det_class in regions_dict:
det_confidence = regions_dict[det_class]["confidence"]
if (det_class in trigger_classes and det_confidence >= min_conf):
branch_triggered = True
logger.debug(f"Pre-validation: branch {branch_id} WILL be triggered by {det_class} (conf={det_confidence:.3f} >= {min_conf})")
break
if not branch_triggered:
missing_branches.append(branch_id)
logger.warning(f"Pre-validation: branch {branch_id} will NOT be triggered - no matching classes or insufficient confidence")
logger.debug(f" Required: {trigger_classes} with min_conf={min_conf}")
logger.debug(f" Available: {[(cls, regions_dict[cls]['confidence']) for cls in regions_dict]}")
is_valid = len(missing_branches) == 0
if missing_branches:
logger.error(f"Pipeline pre-validation FAILED: required branches {missing_branches} will not execute")
else:
logger.info(f"Pipeline pre-validation PASSED: all required branches {list(required_branches)} will execute")
return PipelineValidationResult(
is_valid=is_valid,
missing_branches=missing_branches,
required_branches=list(required_branches),
available_classes=list(regions_dict.keys())
)
def _apply_detection_filters(self,
box: Any,
model: Any,
min_confidence: float,
trigger_classes: Optional[List[str]] = None,
trigger_class_indices: Optional[List[int]] = None) -> Optional[Dict[str, Any]]:
"""Apply confidence and class filtering to a detection box."""
try:
# Extract detection info
xyxy = box.xyxy[0].cpu().numpy()
conf = box.conf[0].cpu().numpy()
cls_id = int(box.cls[0].cpu().numpy())
class_name = model.names[cls_id]
# Apply confidence threshold
if conf < min_confidence:
return None
# Apply trigger class filtering if specified
if trigger_class_indices and cls_id not in trigger_class_indices:
return None
if trigger_classes and class_name not in trigger_classes:
return None
return {
"class": class_name,
"confidence": float(conf),
"bbox": [int(x) for x in xyxy],
"class_id": cls_id,
"xyxy": xyxy
}
except Exception as e:
logger.error(f"Error processing detection box: {e}")
return None
def run_lightweight_detection_with_validation(self,
frame: np.ndarray,
node: Dict[str, Any],
min_confidence: float = LIGHTWEIGHT_DETECTION_MIN_CONFIDENCE,
min_bbox_area_ratio: float = LIGHTWEIGHT_DETECTION_MIN_BBOX_AREA_RATIO) -> Dict[str, Any]:
"""
Run lightweight detection with validation rules for session ID triggering.
Returns detection info only if it passes validation thresholds.
Args:
frame: Input frame for detection
node: Pipeline node configuration
min_confidence: Minimum confidence threshold for validation
min_bbox_area_ratio: Minimum bounding box area ratio for validation
Returns:
Dictionary with validation results and detection info
"""
logger.debug(f"Running lightweight detection with validation: {node['modelId']} (conf>={min_confidence}, bbox_area>={min_bbox_area_ratio})")
try:
# Run basic detection only - no branches, no actions
model = node["model"]
trigger_classes = node.get("triggerClasses", [])
trigger_class_indices = node.get("triggerClassIndices")
# Run YOLO inference
res = model(frame, verbose=False)
best_detection = None
frame_height, frame_width = frame.shape[:2]
frame_area = frame_height * frame_width
for r in res:
boxes = r.boxes
if boxes is None or len(boxes) == 0:
continue
for box in boxes:
detection = self._apply_detection_filters(
box, model, min_confidence, trigger_classes, trigger_class_indices
)
if not detection:
continue
# Calculate bbox area ratio
x1, y1, x2, y2 = detection["xyxy"]
bbox_area = (x2 - x1) * (y2 - y1)
bbox_area_ratio = bbox_area / frame_area if frame_area > 0 else 0
# Apply bbox area threshold
if bbox_area_ratio < min_bbox_area_ratio:
logger.debug(f"Detection filtered out: bbox_area_ratio={bbox_area_ratio:.3f} < {min_bbox_area_ratio}")
continue
# Validation passed
detection["bbox_area_ratio"] = float(bbox_area_ratio)
detection["validation_passed"] = True
if not best_detection or detection["confidence"] > best_detection["confidence"]:
best_detection = detection
if best_detection:
logger.debug(f"Validation PASSED: {best_detection['class']} (conf: {best_detection['confidence']:.3f}, area: {best_detection['bbox_area_ratio']:.3f})")
return best_detection
else:
logger.debug(f"Validation FAILED: No detection meets criteria (conf>={min_confidence}, area>={min_bbox_area_ratio})")
return {"validation_passed": False}
except Exception as e:
logger.error(f"Error in lightweight detection with validation: {str(e)}", exc_info=True)
return {"validation_passed": False}
def run_lightweight_detection(self,
frame: np.ndarray,
node: Dict[str, Any]) -> LightweightDetectionResult:
"""
Run lightweight detection for car presence validation only.
Returns basic detection info without running branches or external actions.
Args:
frame: Input frame for detection
node: Pipeline node configuration
Returns:
LightweightDetectionResult with car detection status
"""
logger.debug(f"Running lightweight detection: {node['modelId']}")
try:
# Run basic detection only - no branches, no actions
model = node["model"]
min_confidence = node.get("minConfidence", DEFAULT_MIN_CONFIDENCE)
trigger_classes = node.get("triggerClasses", [])
trigger_class_indices = node.get("triggerClassIndices")
# Run YOLO inference
res = model(frame, verbose=False)
car_detected = False
best_detection = None
for r in res:
boxes = r.boxes
if boxes is None or len(boxes) == 0:
continue
for box in boxes:
detection = self._apply_detection_filters(
box, model, min_confidence, trigger_classes, trigger_class_indices
)
if not detection:
continue
# Car detected
car_detected = True
if not best_detection or detection["confidence"] > best_detection["confidence"]:
best_detection = detection
logger.debug(f"Lightweight detection result: car_detected={car_detected}")
if best_detection:
logger.debug(f"Best detection: {best_detection['class']} (conf: {best_detection['confidence']:.3f})")
return LightweightDetectionResult(
car_detected=car_detected,
best_detection=best_detection,
confidence=best_detection["confidence"] if best_detection else None
)
except Exception as e:
logger.error(f"Error in lightweight detection: {str(e)}", exc_info=True)
return LightweightDetectionResult(car_detected=False, best_detection=None)
def validate_detection_thresholds(self,
detection: Dict[str, Any],
frame_shape: Tuple[int, int, int],
min_confidence: float,
min_bbox_area_ratio: float) -> bool:
"""
Validate detection against confidence and bbox area thresholds.
Args:
detection: Detection dictionary with confidence and bbox
frame_shape: Frame dimensions (height, width, channels)
min_confidence: Minimum confidence threshold
min_bbox_area_ratio: Minimum bbox area ratio threshold
Returns:
True if detection passes all validation thresholds
"""
try:
# Check confidence
confidence = detection.get("confidence", 0.0)
if confidence < min_confidence:
logger.debug(f"Detection failed confidence threshold: {confidence:.3f} < {min_confidence}")
return False
# Check bbox area ratio
bbox = detection.get("bbox", [])
if len(bbox) >= 4:
x1, y1, x2, y2 = bbox[:4]
bbox_area = (x2 - x1) * (y2 - y1)
frame_height, frame_width = frame_shape[:2]
frame_area = frame_height * frame_width
bbox_area_ratio = bbox_area / frame_area if frame_area > 0 else 0
if bbox_area_ratio < min_bbox_area_ratio:
logger.debug(f"Detection failed bbox area threshold: {bbox_area_ratio:.3f} < {min_bbox_area_ratio}")
return False
logger.debug(f"Detection passed validation thresholds: conf={confidence:.3f}, area_ratio={bbox_area_ratio:.3f}")
return True
except Exception as e:
logger.error(f"Error validating detection thresholds: {e}")
return False
def check_branch_execution_requirements(self,
branch: Dict[str, Any],
regions_dict: Dict[str, Any]) -> bool:
"""
Check if a branch will be executed based on trigger classes and confidence.
Args:
branch: Branch configuration
regions_dict: Available detection regions
Returns:
True if branch will be executed
"""
trigger_classes = branch.get("triggerClasses", [])
min_conf = branch.get("minConfidence", 0)
for det_class in regions_dict:
det_confidence = regions_dict[det_class]["confidence"]
if det_class in trigger_classes and det_confidence >= min_conf:
return True
return False
def get_validation_summary(self,
node: Dict[str, Any],
regions_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Get comprehensive validation summary for a pipeline node.
Args:
node: Pipeline node configuration
regions_dict: Available detection regions
Returns:
Dictionary with validation summary information
"""
summary = {
"model_id": node.get("modelId", "unknown"),
"available_classes": list(regions_dict.keys()),
"branch_count": len(node.get("branches", [])),
"parallel_action_count": len(node.get("parallelActions", [])),
"branches_that_will_execute": [],
"branches_that_will_not_execute": []
}
# Check each branch
for branch in node.get("branches", []):
branch_id = branch["modelId"]
will_execute = self.check_branch_execution_requirements(branch, regions_dict)
if will_execute:
summary["branches_that_will_execute"].append(branch_id)
else:
summary["branches_that_will_not_execute"].append(branch_id)
# Pipeline validation
pipeline_validation = self.validate_pipeline_execution(node, regions_dict)
summary["pipeline_validation"] = pipeline_validation.to_dict()
return summary
# Global stability validator instance
stability_validator = StabilityValidator()
# ===== CONVENIENCE FUNCTIONS =====
# These provide the same interface as the original functions in pympta.py
def validate_pipeline_execution(node: Dict[str, Any], regions_dict: Dict[str, Any]) -> Tuple[bool, List[str]]:
"""
Pre-validate that all required branches will execute successfully.
Returns:
- (True, []) if pipeline can execute completely
- (False, missing_branches) if some required branches won't execute
"""
result = stability_validator.validate_pipeline_execution(node, regions_dict)
return result.is_valid, result.missing_branches
def run_lightweight_detection_with_validation(frame: np.ndarray,
node: Dict[str, Any],
min_confidence: float = LIGHTWEIGHT_DETECTION_MIN_CONFIDENCE,
min_bbox_area_ratio: float = LIGHTWEIGHT_DETECTION_MIN_BBOX_AREA_RATIO) -> Dict[str, Any]:
"""Run lightweight detection with validation rules for session ID triggering."""
return stability_validator.run_lightweight_detection_with_validation(
frame, node, min_confidence, min_bbox_area_ratio
)
def run_lightweight_detection(frame: np.ndarray, node: Dict[str, Any]) -> Dict[str, Any]:
"""Run lightweight detection for car presence validation only."""
result = stability_validator.run_lightweight_detection(frame, node)
return result.to_dict()
def validate_detection_thresholds(detection: Dict[str, Any],
frame_shape: Tuple[int, int, int],
min_confidence: float,
min_bbox_area_ratio: float) -> bool:
"""Validate detection against confidence and bbox area thresholds."""
return stability_validator.validate_detection_thresholds(
detection, frame_shape, min_confidence, min_bbox_area_ratio
)
def get_validation_summary(node: Dict[str, Any], regions_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Get comprehensive validation summary for a pipeline node."""
return stability_validator.get_validation_summary(node, regions_dict)

View file

@ -0,0 +1,481 @@
"""
Object tracking management and state handling.
This module provides tracking state management, session handling,
and occupancy detection functionality for the detection pipeline.
"""
import time
import logging
from typing import Dict, List, Any, Optional, Tuple, Set
from dataclasses import dataclass, field
from datetime import datetime
from ..core.constants import SESSION_TIMEOUT_SECONDS
from ..core.exceptions import TrackingError, create_detection_error
logger = logging.getLogger(__name__)
@dataclass
class SessionState:
"""Session state for tracking management."""
active: bool = True
waiting_for_backend_session: bool = False
wait_start_time: float = 0.0
reset_tracker_on_resume: bool = False
occupancy_mode: bool = False
occupancy_enabled_at: Optional[float] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"active": self.active,
"waiting_for_backend_session": self.waiting_for_backend_session,
"wait_start_time": self.wait_start_time,
"reset_tracker_on_resume": self.reset_tracker_on_resume,
"occupancy_mode": self.occupancy_mode,
"occupancy_enabled_at": self.occupancy_enabled_at
}
def from_dict(self, data: Dict[str, Any]) -> None:
"""Update from dictionary data."""
self.active = data.get("active", True)
self.waiting_for_backend_session = data.get("waiting_for_backend_session", False)
self.wait_start_time = data.get("wait_start_time", 0.0)
self.reset_tracker_on_resume = data.get("reset_tracker_on_resume", False)
self.occupancy_mode = data.get("occupancy_mode", False)
self.occupancy_enabled_at = data.get("occupancy_enabled_at")
@dataclass
class TrackingData:
"""Complete tracking data for a camera and model."""
track_stability_counters: Dict[int, int] = field(default_factory=dict)
stable_tracks: Set[int] = field(default_factory=set)
session_state: SessionState = field(default_factory=SessionState)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"track_stability_counters": dict(self.track_stability_counters),
"stable_tracks": list(self.stable_tracks),
"session_state": self.session_state.to_dict()
}
class TrackingManager:
"""
Manages object tracking state and session handling across cameras and models.
This class provides centralized tracking state management including:
- Track stability counters
- Session state management
- Backend session timeout handling
- Occupancy detection mode
"""
def __init__(self, session_timeout_seconds: int = SESSION_TIMEOUT_SECONDS):
"""
Initialize tracking manager.
Args:
session_timeout_seconds: Timeout for backend session waiting
"""
self.session_timeout_seconds = session_timeout_seconds
self._camera_tracking_data: Dict[str, Dict[str, TrackingData]] = {}
self._lock = None
def _ensure_thread_safety(self):
"""Initialize thread safety if not already done."""
if self._lock is None:
import threading
self._lock = threading.RLock()
def get_camera_tracking_data(self, camera_id: str, model_id: str) -> TrackingData:
"""
Get or create tracking data for a specific camera and model.
Args:
camera_id: Unique camera identifier
model_id: Model identifier
Returns:
Tracking data for the camera and model
"""
self._ensure_thread_safety()
with self._lock:
if camera_id not in self._camera_tracking_data:
self._camera_tracking_data[camera_id] = {}
if model_id not in self._camera_tracking_data[camera_id]:
logger.warning(f"🔄 Camera {camera_id}: Creating NEW tracking data for {model_id} - this will reset any cooldown!")
self._camera_tracking_data[camera_id][model_id] = TrackingData()
return self._camera_tracking_data[camera_id][model_id]
def check_stable_tracks(self,
camera_id: str,
model_id: str,
regions_dict: Dict[str, Any]) -> Tuple[bool, List[Dict[str, Any]]]:
"""
Check if any stable tracks match the detected classes for a specific camera.
Args:
camera_id: Unique camera identifier
model_id: Model identifier
regions_dict: Dictionary mapping class names to detection regions
Returns:
Tuple of (has_stable_tracks, stable_detections)
"""
self._ensure_thread_safety()
with self._lock:
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
stable_tracks = tracking_data.stable_tracks
if not stable_tracks:
return False, []
# Check for track-based stability
stable_detections = []
for class_name, region_data in regions_dict.items():
detection = region_data.get("detection", {})
track_id = detection.get("id")
if track_id in stable_tracks:
stable_detections.append(detection)
has_stable = len(stable_detections) > 0
logger.debug(f"Camera {camera_id}: Stable track check - stable_tracks: {list(stable_tracks)}, found: {has_stable}")
return has_stable, stable_detections
def reset_tracking_state(self,
camera_id: str,
model_id: str,
reason: str = "session ended") -> None:
"""
Reset tracking state after session completion or timeout.
Args:
camera_id: Unique camera identifier
model_id: Model identifier
reason: Reason for reset (for logging)
"""
self._ensure_thread_safety()
with self._lock:
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
session_state = tracking_data.session_state
# Clear all tracking data for fresh start
tracking_data.track_stability_counters.clear()
tracking_data.stable_tracks.clear()
session_state.active = True
session_state.waiting_for_backend_session = False
session_state.wait_start_time = 0.0
session_state.reset_tracker_on_resume = True
logger.info(f"Camera {camera_id}: 🔄 Reset tracking state - {reason}")
logger.info(f"Camera {camera_id}: 🧹 Cleared stability counters and stable tracks for fresh session")
def is_camera_active(self, camera_id: str, model_id: str) -> bool:
"""
Check if camera should be processing detections.
Args:
camera_id: Unique camera identifier
model_id: Model identifier
Returns:
True if camera should process detections, False otherwise
"""
self._ensure_thread_safety()
with self._lock:
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
session_state = tracking_data.session_state
# Check if waiting for backend sessionId has timed out
if session_state.waiting_for_backend_session:
current_time = time.time()
wait_start_time = session_state.wait_start_time
elapsed_time = current_time - wait_start_time
if elapsed_time >= self.session_timeout_seconds:
logger.warning(f"Camera {camera_id}: Backend sessionId timeout ({self.session_timeout_seconds}s) - resetting tracking")
self.reset_tracking_state(camera_id, model_id, "backend sessionId timeout")
return True
else:
remaining_time = self.session_timeout_seconds - elapsed_time
logger.debug(f"Camera {camera_id}: Still waiting for backend sessionId ({remaining_time:.1f}s remaining)")
return False
return session_state.active
def set_waiting_for_backend_session(self,
camera_id: str,
model_id: str,
waiting: bool = True) -> None:
"""
Set whether camera is waiting for backend session ID.
Args:
camera_id: Unique camera identifier
model_id: Model identifier
waiting: Whether to wait for backend session
"""
self._ensure_thread_safety()
with self._lock:
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
session_state = tracking_data.session_state
session_state.waiting_for_backend_session = waiting
if waiting:
session_state.wait_start_time = time.time()
logger.debug(f"Camera {camera_id}: Started waiting for backend sessionId")
else:
session_state.wait_start_time = 0.0
logger.debug(f"Camera {camera_id}: Stopped waiting for backend sessionId")
def cleanup_camera_tracking(self, camera_id: str) -> None:
"""
Clean up tracking data when a camera is disconnected.
Args:
camera_id: Unique camera identifier
"""
self._ensure_thread_safety()
with self._lock:
if camera_id in self._camera_tracking_data:
del self._camera_tracking_data[camera_id]
logger.info(f"Cleaned up tracking data for camera {camera_id}")
def set_occupancy_mode(self,
camera_id: str,
model_id: str,
enable: bool = True) -> bool:
"""
Enable or disable occupancy detection mode.
Occupancy mode stops model inference after pipeline completion
while backend session handling continues in background.
Args:
camera_id: Unique camera identifier
model_id: Model identifier
enable: True to enable occupancy mode, False to disable
Returns:
Current occupancy mode state
"""
self._ensure_thread_safety()
with self._lock:
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
session_state = tracking_data.session_state
if enable:
session_state.occupancy_mode = True
session_state.occupancy_enabled_at = time.time()
logger.debug(f"Camera {camera_id}: Occupancy mode ENABLED - model will stop after pipeline completion")
else:
session_state.occupancy_mode = False
session_state.occupancy_enabled_at = None
logger.debug(f"Camera {camera_id}: Occupancy mode DISABLED - model will continue running")
return session_state.occupancy_mode
def is_occupancy_mode_enabled(self, camera_id: str, model_id: str) -> bool:
"""
Check if occupancy mode is enabled for a camera.
Args:
camera_id: Unique camera identifier
model_id: Model identifier
Returns:
True if occupancy mode is enabled
"""
self._ensure_thread_safety()
with self._lock:
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
return tracking_data.session_state.occupancy_mode
def get_occupancy_duration(self, camera_id: str, model_id: str) -> Optional[float]:
"""
Get duration since occupancy mode was enabled.
Args:
camera_id: Unique camera identifier
model_id: Model identifier
Returns:
Duration in seconds since occupancy mode was enabled, or None if not enabled
"""
self._ensure_thread_safety()
with self._lock:
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
session_state = tracking_data.session_state
if not session_state.occupancy_mode or not session_state.occupancy_enabled_at:
return None
return time.time() - session_state.occupancy_enabled_at
def get_session_state(self, camera_id: str, model_id: str) -> SessionState:
"""
Get session state for a camera and model.
Args:
camera_id: Unique camera identifier
model_id: Model identifier
Returns:
Session state object
"""
self._ensure_thread_safety()
with self._lock:
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
return tracking_data.session_state
def update_session_state(self,
camera_id: str,
model_id: str,
**kwargs) -> None:
"""
Update session state properties.
Args:
camera_id: Unique camera identifier
model_id: Model identifier
**kwargs: Session state properties to update
"""
self._ensure_thread_safety()
with self._lock:
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
session_state = tracking_data.session_state
for key, value in kwargs.items():
if hasattr(session_state, key):
setattr(session_state, key, value)
logger.debug(f"Camera {camera_id}: Updated session state {key} = {value}")
else:
logger.warning(f"Camera {camera_id}: Unknown session state property: {key}")
def get_tracking_statistics(self, camera_id: str, model_id: str) -> Dict[str, Any]:
"""
Get comprehensive tracking statistics for a camera and model.
Args:
camera_id: Unique camera identifier
model_id: Model identifier
Returns:
Dictionary with tracking statistics
"""
self._ensure_thread_safety()
with self._lock:
tracking_data = self.get_camera_tracking_data(camera_id, model_id)
current_time = time.time()
stats = tracking_data.to_dict()
# Add computed statistics
stats["total_tracked_objects"] = len(tracking_data.track_stability_counters)
stats["stable_track_count"] = len(tracking_data.stable_tracks)
session_state = tracking_data.session_state
if session_state.waiting_for_backend_session and session_state.wait_start_time > 0:
stats["backend_wait_duration"] = current_time - session_state.wait_start_time
stats["backend_wait_remaining"] = max(0, self.session_timeout_seconds - stats["backend_wait_duration"])
if session_state.occupancy_mode and session_state.occupancy_enabled_at:
stats["occupancy_duration"] = current_time - session_state.occupancy_enabled_at
return stats
def get_all_camera_stats(self) -> Dict[str, Dict[str, Dict[str, Any]]]:
"""
Get tracking statistics for all monitored cameras.
Returns:
Nested dictionary: {camera_id: {model_id: stats}}
"""
self._ensure_thread_safety()
with self._lock:
all_stats = {}
for camera_id, models in self._camera_tracking_data.items():
all_stats[camera_id] = {}
for model_id in models.keys():
all_stats[camera_id][model_id] = self.get_tracking_statistics(camera_id, model_id)
return all_stats
# Global tracking manager instance
tracking_manager = TrackingManager()
# ===== CONVENIENCE FUNCTIONS =====
# These provide the same interface as the original functions in pympta.py
def check_stable_tracks(camera_id: str, model_id: str, regions_dict: Dict[str, Any]) -> Tuple[bool, List[Dict[str, Any]]]:
"""Check if any stable tracks match the detected classes for a specific camera."""
return tracking_manager.check_stable_tracks(camera_id, model_id, regions_dict)
def reset_tracking_state(camera_id: str, model_id: str, reason: str = "session ended") -> None:
"""Reset tracking state after session completion or timeout."""
tracking_manager.reset_tracking_state(camera_id, model_id, reason)
def is_camera_active(camera_id: str, model_id: str) -> bool:
"""Check if camera should be processing detections."""
return tracking_manager.is_camera_active(camera_id, model_id)
def cleanup_camera_stability(camera_id: str) -> None:
"""Clean up tracking data when a camera is disconnected."""
tracking_manager.cleanup_camera_tracking(camera_id)
def occupancy_detector(camera_id: str, model_id: str, enable: bool = True) -> bool:
"""
Enable or disable occupancy detection mode.
When enabled:
- Model stops inference after completing full pipeline
- Backend sessionId handling continues in background
Note: This is a temporary function that will be changed in the future.
"""
return tracking_manager.set_occupancy_mode(camera_id, model_id, enable)
def get_camera_tracking_data(camera_id: str, model_id: str) -> Dict[str, Any]:
"""Get tracking data for a specific camera and model."""
tracking_data = tracking_manager.get_camera_tracking_data(camera_id, model_id)
return tracking_data.to_dict()
def set_waiting_for_backend_session(camera_id: str, model_id: str, waiting: bool = True) -> None:
"""Set whether camera is waiting for backend session ID."""
tracking_manager.set_waiting_for_backend_session(camera_id, model_id, waiting)
def get_tracking_statistics(camera_id: str, model_id: str) -> Dict[str, Any]:
"""Get comprehensive tracking statistics for a camera and model."""
return tracking_manager.get_tracking_statistics(camera_id, model_id)

View file

@ -0,0 +1,633 @@
"""
YOLO detection with BoT-SORT tracking integration.
This module provides the main detection functionality combining YOLO object detection
with BoT-SORT tracking for stable track validation.
"""
import cv2
import logging
from typing import Dict, List, Any, Optional, Tuple, Set
from dataclasses import dataclass, field
from datetime import datetime
from ..core.constants import DEFAULT_STABILITY_THRESHOLD, DEFAULT_MIN_CONFIDENCE
from ..core.exceptions import DetectionError, create_detection_error
from ..detection.detection_result import DetectionResult, BoundingBox, DetectionSession
logger = logging.getLogger(__name__)
@dataclass
class TrackingConfig:
"""Configuration for object tracking."""
enabled: bool = True
method: str = "botsort"
reid_config_path: str = "botsort.yaml"
stability_threshold: int = DEFAULT_STABILITY_THRESHOLD
use_tracking_for_model: bool = True # False for frontal detection models
@dataclass
class StabilityData:
"""Track stability data for a camera and model."""
track_stability_counters: Dict[int, int] = field(default_factory=dict)
stable_tracks: Set[int] = field(default_factory=set)
session_state: Dict[str, Any] = field(default_factory=lambda: {
"active": True,
"waiting_for_backend_session": False,
"wait_start_time": 0.0,
"reset_tracker_on_resume": False
})
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"track_stability_counters": dict(self.track_stability_counters),
"stable_tracks": list(self.stable_tracks),
"session_state": self.session_state.copy()
}
@dataclass
class ValidationResult:
"""Result of track stability validation."""
validation_complete: bool
stable_tracks: List[int] = field(default_factory=list)
current_tracks: List[int] = field(default_factory=list)
newly_stable_tracks: List[int] = field(default_factory=list)
send_none_detection: bool = False
branch_node: bool = False
bypass_validation: bool = False
awaiting_stable_tracks: bool = False
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"validation_complete": self.validation_complete,
"stable_tracks": self.stable_tracks.copy(),
"current_tracks": self.current_tracks.copy(),
"newly_stable_tracks": self.newly_stable_tracks.copy(),
"send_none_detection": self.send_none_detection,
"branch_node": self.branch_node,
"bypass_validation": self.bypass_validation,
"awaiting_stable_tracks": self.awaiting_stable_tracks
}
class StabilityTracker:
"""Manages track stability validation for cameras and models."""
def __init__(self):
"""Initialize stability tracker."""
self._camera_stability_tracking: Dict[str, Dict[str, StabilityData]] = {}
self._lock = None
def _ensure_thread_safety(self):
"""Initialize thread safety if not already done."""
if self._lock is None:
import threading
self._lock = threading.RLock()
def get_camera_stability_data(self, camera_id: str, model_id: str) -> StabilityData:
"""
Get or create stability tracking data for a specific camera and model.
Args:
camera_id: Unique camera identifier
model_id: Model identifier
Returns:
Stability data for the camera and model
"""
self._ensure_thread_safety()
with self._lock:
if camera_id not in self._camera_stability_tracking:
self._camera_stability_tracking[camera_id] = {}
if model_id not in self._camera_stability_tracking[camera_id]:
logger.warning(f"🔄 Camera {camera_id}: Creating NEW stability data for {model_id} - this will reset any cooldown!")
self._camera_stability_tracking[camera_id][model_id] = StabilityData()
return self._camera_stability_tracking[camera_id][model_id]
def reset_camera_stability_tracking(self, camera_id: str, model_id: str) -> None:
"""
Reset all stability tracking data for a specific camera and model.
Args:
camera_id: Unique camera identifier
model_id: Model identifier
"""
self._ensure_thread_safety()
with self._lock:
if camera_id in self._camera_stability_tracking and model_id in self._camera_stability_tracking[camera_id]:
stability_data = self._camera_stability_tracking[camera_id][model_id]
# Clear all tracking data
old_counters = dict(stability_data.track_stability_counters)
old_stable = list(stability_data.stable_tracks)
stability_data.track_stability_counters.clear()
stability_data.stable_tracks.clear()
# IMPORTANT: Set flag to reset YOLO tracker on next detection run
stability_data.session_state["reset_tracker_on_resume"] = True
logger.info(f"🧹 Camera {camera_id}: CLEARED stability tracking - old_counters={old_counters}, old_stable={old_stable}")
logger.info(f"🔄 Camera {camera_id}: YOLO tracker will be reset on next detection - fresh track IDs will start from 1")
def update_single_track_stability(self,
detection: Optional[Dict[str, Any]],
camera_id: str,
model_id: str,
stability_threshold: int,
current_mode: str,
is_branch_node: bool = False) -> ValidationResult:
"""
Update track stability validation for a single highest confidence detection.
Args:
detection: Detection data with track_id, or None if no detection
camera_id: Unique camera identifier
model_id: Model identifier
stability_threshold: Number of consecutive frames needed for stability
current_mode: Current pipeline mode
is_branch_node: Whether this is a branch node (skip validation)
Returns:
Validation result with track status
"""
self._ensure_thread_safety()
with self._lock:
# Branch nodes should not do validation - only main pipeline should
if is_branch_node:
logger.debug(f"⏭️ Camera {camera_id}: Skipping validation for branch node {model_id} - validation only done at main pipeline level")
return ValidationResult(
validation_complete=False,
branch_node=True,
stable_tracks=[],
current_tracks=[]
)
# Check current mode - validation counters should increment in both validation_detecting and full_pipeline modes
is_validation_mode = current_mode in ["validation_detecting", "full_pipeline"]
# Get camera-specific stability data
stability_data = self.get_camera_stability_data(camera_id, model_id)
track_counters = stability_data.track_stability_counters
stable_tracks = stability_data.stable_tracks
current_track_id = detection.get("id") if detection else None
# ═══ MODE-AWARE TRACK VALIDATION ═══
logger.debug(f"📋 Camera {camera_id}: === TRACK VALIDATION ANALYSIS ===")
logger.debug(f"📋 Camera {camera_id}: Current mode: {current_mode} (validation_mode={is_validation_mode})")
logger.debug(f"📋 Camera {camera_id}: Current track_id: {current_track_id} (assigned by YOLO tracking - not sequential)")
logger.debug(f"📋 Camera {camera_id}: Existing counters: {dict(track_counters)}")
logger.debug(f"📋 Camera {camera_id}: Stable tracks: {list(stable_tracks)}")
# IMPORTANT: Only modify validation counters during validation_detecting mode
if not is_validation_mode:
logger.debug(f"🚫 Camera {camera_id}: NOT in validation mode - skipping counter modifications")
return ValidationResult(
validation_complete=False,
stable_tracks=list(stable_tracks),
current_tracks=[current_track_id] if current_track_id is not None else []
)
if current_track_id is not None:
# Check if this is a different track than we were tracking
previous_track_ids = list(track_counters.keys())
# VALIDATION MODE: Reset counter if different track OR if track was previously stable
should_reset = (
len(previous_track_ids) == 0 or # No previous tracking
current_track_id not in previous_track_ids or # Different track ID
current_track_id in stable_tracks # Track was stable - start fresh validation
)
logger.debug(f"📋 Camera {camera_id}: Previous track_ids: {previous_track_ids}")
logger.debug(f"📋 Camera {camera_id}: Track {current_track_id} was stable: {current_track_id in stable_tracks}")
logger.debug(f"📋 Camera {camera_id}: Should reset counters: {should_reset}")
if should_reset:
# Clear all previous tracking - fresh validation needed
if previous_track_ids:
for old_track_id in previous_track_ids:
old_count = track_counters.pop(old_track_id, 0)
stable_tracks.discard(old_track_id)
logger.info(f"🔄 Camera {camera_id}: VALIDATION RESET - track {old_track_id} counter from {old_count} to 0 (reason: {'stable_track_restart' if current_track_id == old_track_id else 'different_track'})")
# Start fresh validation for this track
old_count = track_counters.get(current_track_id, 0) # Store old count for logging
track_counters[current_track_id] = 1
current_count = 1
logger.info(f"🆕 Camera {camera_id}: FRESH VALIDATION - Track {current_track_id} starting at 1/{stability_threshold}")
else:
# Continue validation for same track
old_count = track_counters.get(current_track_id, 0)
track_counters[current_track_id] = old_count + 1
current_count = track_counters[current_track_id]
logger.debug(f"🔢 Camera {camera_id}: Track {current_track_id} counter: {old_count}{current_count}")
logger.info(f"🔍 Camera {camera_id}: Track ID {current_track_id} validation {current_count}/{stability_threshold}")
# Check if track has reached stability threshold
logger.debug(f"📊 Camera {camera_id}: Checking stability: {current_count} >= {stability_threshold}? {current_count >= stability_threshold}")
logger.debug(f"📊 Camera {camera_id}: Already stable: {current_track_id in stable_tracks}")
if current_count >= stability_threshold and current_track_id not in stable_tracks:
stable_tracks.add(current_track_id)
logger.info(f"✅ Camera {camera_id}: Track ID {current_track_id} STABLE after {current_count} consecutive frames")
logger.info(f"🎯 Camera {camera_id}: TRACK VALIDATION COMPLETE")
logger.debug(f"🎯 Camera {camera_id}: Stable tracks now: {list(stable_tracks)}")
return ValidationResult(
validation_complete=True,
send_none_detection=True,
stable_tracks=[current_track_id],
newly_stable_tracks=[current_track_id],
current_tracks=[current_track_id]
)
elif current_count >= stability_threshold:
logger.debug(f"📊 Camera {camera_id}: Track {current_track_id} already stable - not re-adding")
else:
# No car detected - ALWAYS clear all tracking and reset counters
logger.debug(f"🚫 Camera {camera_id}: NO CAR DETECTED - clearing all tracking")
if track_counters or stable_tracks:
logger.debug(f"🚫 Camera {camera_id}: Existing state before reset: counters={dict(track_counters)}, stable={list(stable_tracks)}")
for track_id in list(track_counters.keys()):
old_count = track_counters.pop(track_id, 0)
logger.info(f"🔄 Camera {camera_id}: No car detected - RESET track {track_id} counter from {old_count} to 0")
track_counters.clear() # Ensure complete reset
stable_tracks.clear() # Clear all stable tracks
logger.info(f"✅ Camera {camera_id}: RESET TO VALIDATION PHASE - All counters and stable tracks cleared")
else:
logger.debug(f"🚫 Camera {camera_id}: No existing counters to clear")
logger.debug(f"Camera {camera_id}: VALIDATION - no car detected (all counters reset)")
# Final return - validation not complete
result = ValidationResult(
validation_complete=False,
stable_tracks=list(stable_tracks),
current_tracks=[current_track_id] if current_track_id is not None else []
)
logger.debug(f"📋 Camera {camera_id}: Track stability result: {result.to_dict()}")
logger.debug(f"📋 Camera {camera_id}: Final counters: {dict(track_counters)}")
logger.debug(f"📋 Camera {camera_id}: Final stable tracks: {list(stable_tracks)}")
return result
class YOLODetector:
"""
YOLO object detector with BoT-SORT tracking integration.
Provides structured detection functionality with track stability validation
and multi-class detection support.
"""
def __init__(self, stability_tracker: Optional[StabilityTracker] = None):
"""
Initialize YOLO detector.
Args:
stability_tracker: Optional stability tracker instance
"""
self.stability_tracker = stability_tracker or StabilityTracker()
def _extract_tracking_config(self, node: Dict[str, Any]) -> TrackingConfig:
"""Extract tracking configuration from pipeline node."""
tracking_config = node.get("tracking", {})
model_id = node.get("modelId", "")
# Don't use tracking for frontal detection models
use_tracking = not ("frontal" in model_id.lower() or "detection" in model_id.lower())
return TrackingConfig(
enabled=tracking_config.get("enabled", True),
method=tracking_config.get("method", "botsort"),
reid_config_path=tracking_config.get("reidConfig", tracking_config.get("reidConfigPath", "botsort.yaml")),
stability_threshold=tracking_config.get("stabilityThreshold", node.get("stabilityThreshold", DEFAULT_STABILITY_THRESHOLD)),
use_tracking_for_model=use_tracking
)
def _determine_confidence_threshold(self, node: Dict[str, Any]) -> float:
"""Determine confidence threshold based on model type."""
model_id = node.get("modelId", "")
# Use frontalMinConfidence for frontal detection models
if "frontal" in model_id.lower() and "frontalMinConfidence" in node:
min_confidence = node.get("frontalMinConfidence", DEFAULT_MIN_CONFIDENCE)
logger.debug(f"Using frontalMinConfidence={min_confidence} for {model_id}")
else:
min_confidence = node.get("minConfidence", DEFAULT_MIN_CONFIDENCE)
return min_confidence
def _reset_yolo_tracker_if_needed(self, node: Dict[str, Any], camera_id: str, model_id: str) -> None:
"""Reset YOLO tracker if flagged for reset."""
stability_data = self.stability_tracker.get_camera_stability_data(camera_id, model_id)
session_state = stability_data.session_state
if session_state.get("reset_tracker_on_resume", False):
# Reset YOLO tracker to get fresh track IDs
if hasattr(node["model"], 'trackers') and node["model"].trackers:
node["model"].trackers.clear() # Clear tracker state
logger.info(f"Camera {camera_id}: 🔄 Reset YOLO tracker - new cars will get fresh track IDs")
session_state["reset_tracker_on_resume"] = False # Clear the flag
def _run_yolo_inference(self,
frame,
node: Dict[str, Any],
tracking_config: TrackingConfig) -> Any:
"""Run YOLO inference with or without tracking."""
# Prepare class filtering
trigger_class_indices = node.get("triggerClassIndices")
class_filter = {"classes": trigger_class_indices} if trigger_class_indices else {}
model_id = node.get("modelId", "")
use_tracking = tracking_config.enabled and tracking_config.use_tracking_for_model
logger.debug(f"Running detection for {model_id} - tracking: {use_tracking}, stability_threshold: {tracking_config.stability_threshold}, classes: {node.get('triggerClasses', 'all')}")
if use_tracking:
# Use tracking for main detection models (yolo11m, etc.)
logger.debug(f"Using tracking for {model_id}")
res = node["model"].track(
frame,
stream=False,
persist=True,
**class_filter
)[0]
else:
# Use detection only for frontal detection and other detection-only models
logger.debug(f"Using prediction only for {model_id}")
res = node["model"].predict(
frame,
stream=False,
**class_filter
)[0]
return res
def _process_detections(self,
res: Any,
node: Dict[str, Any],
camera_id: str,
min_confidence: float) -> List[Dict[str, Any]]:
"""Process YOLO detection results into candidate detections."""
candidate_detections = []
if res.boxes is None or len(res.boxes) == 0:
logger.debug(f"🚫 Camera {camera_id}: YOLO returned no detections")
return candidate_detections
logger.debug(f"🔍 Camera {camera_id}: YOLO detected {len(res.boxes)} raw objects - processing with tracking...")
logger.debug(f"🔍 Camera {camera_id}: === DETECTION ANALYSIS ===")
for i, box in enumerate(res.boxes):
# Extract detection data
conf = float(box.cpu().conf[0])
cls_id = int(box.cpu().cls[0])
class_name = node["model"].names[cls_id]
# Extract bounding box
xy = box.cpu().xyxy[0]
x1, y1, x2, y2 = map(int, xy)
bbox = (x1, y1, x2, y2)
# Extract tracking ID if available
track_id = None
if hasattr(box, "id") and box.id is not None:
track_id = int(box.id.item())
logger.debug(f"🔍 Camera {camera_id}: Detection {i+1}: class='{class_name}' conf={conf:.3f} track_id={track_id} bbox={bbox}")
# Apply confidence filtering
if conf < min_confidence:
logger.debug(f"❌ Camera {camera_id}: Detection {i+1} REJECTED - confidence {conf:.3f} < {min_confidence}")
continue
# Create detection object
detection = {
"class": class_name,
"confidence": conf,
"id": track_id,
"bbox": bbox,
"class_id": cls_id
}
candidate_detections.append(detection)
logger.debug(f"✅ Camera {camera_id}: Detection {i+1} ACCEPTED as candidate: {class_name} (conf={conf:.3f}, track_id={track_id})")
return candidate_detections
def _select_best_detection(self,
candidate_detections: List[Dict[str, Any]],
camera_id: str) -> Optional[Dict[str, Any]]:
"""Select the highest confidence detection from candidates."""
if not candidate_detections:
logger.debug(f"🚫 Camera {camera_id}: No valid candidates after filtering - no car will be tracked")
return None
logger.debug(f"🏆 Camera {camera_id}: === SELECTING HIGHEST CONFIDENCE CAR ===")
for i, detection in enumerate(candidate_detections):
logger.debug(f"🏆 Camera {camera_id}: Candidate {i+1}: {detection['class']} conf={detection['confidence']:.3f} track_id={detection['id']}")
# Find the single highest confidence detection across all detected classes
best_detection = max(candidate_detections, key=lambda x: x["confidence"])
logger.info(f"🎯 Camera {camera_id}: SELECTED WINNER: {best_detection['class']} (conf={best_detection['confidence']:.3f}, track_id={best_detection['id']}, bbox={best_detection['bbox']})")
# Show which cars were NOT selected
for detection in candidate_detections:
if detection != best_detection:
logger.debug(f"🚫 Camera {camera_id}: NOT SELECTED: {detection['class']} (conf={detection['confidence']:.3f}, track_id={detection['id']}) - lower confidence")
return best_detection
def _apply_class_mapping(self, detection: Dict[str, Any], node: Dict[str, Any]) -> Dict[str, Any]:
"""Apply class mapping if configured."""
original_class = detection["class"]
class_mapping = node.get("classMapping", {})
if original_class in class_mapping:
mapped_class = class_mapping[original_class]
logger.info(f"Class mapping applied: {original_class}{mapped_class}")
# Update the detection object with mapped class
detection["class"] = mapped_class
detection["original_class"] = original_class # Keep original for reference
return detection
def _validate_multi_class(self,
regions_dict: Dict[str, Any],
node: Dict[str, Any]) -> bool:
"""Validate multi-class detection requirements."""
if not (node.get("multiClass", False) and node.get("expectedClasses")):
return True
expected_classes = node["expectedClasses"]
detected_classes = list(regions_dict.keys())
logger.debug(f"Multi-class validation: expected={expected_classes}, detected={detected_classes}")
# Check for required classes (flexible - at least one must match)
matching_classes = [cls for cls in expected_classes if cls in detected_classes]
if not matching_classes:
logger.warning(f"Multi-class validation failed: no expected classes detected")
return False
logger.info(f"Multi-class validation passed: {matching_classes} detected")
return True
def run_detection_with_tracking(self,
frame,
node: Dict[str, Any],
context: Optional[Dict[str, Any]] = None) -> Tuple[List[Dict[str, Any]], Dict[str, Any], Dict[str, Any]]:
"""
Run YOLO detection with BoT-SORT tracking and stability validation.
Args:
frame: Input frame/image
node: Pipeline node configuration with model and settings
context: Optional context information (camera info, session data, etc.)
Returns:
tuple: (all_detections, regions_dict, track_validation_result) where:
- all_detections: List of all detection objects
- regions_dict: Dict mapping class names to highest confidence detections
- track_validation_result: Dict with validation status and stable tracks
"""
try:
camera_id = context.get("camera_id", "unknown") if context else "unknown"
model_id = node.get("modelId", "unknown")
current_mode = context.get("current_mode", "unknown") if context else "unknown"
# Extract configuration
tracking_config = self._extract_tracking_config(node)
min_confidence = self._determine_confidence_threshold(node)
# Check if we need to reset tracker after cooldown
self._reset_yolo_tracker_if_needed(node, camera_id, model_id)
# Run YOLO inference
res = self._run_yolo_inference(frame, node, tracking_config)
# Process detection results
candidate_detections = self._process_detections(res, node, camera_id, min_confidence)
# Select best detection
best_detection = self._select_best_detection(candidate_detections, camera_id)
# Update track stability validation
is_branch_node = node.get("cropClass") is not None or node.get("parallel") is True
track_validation_result = self.stability_tracker.update_single_track_stability(
detection=best_detection,
camera_id=camera_id,
model_id=model_id,
stability_threshold=tracking_config.stability_threshold,
current_mode=current_mode,
is_branch_node=is_branch_node
)
# Handle no detection case
if best_detection is None:
# Store validation state in context for pipeline decisions
if context is not None:
context["track_validation_result"] = track_validation_result.to_dict()
return [], {}, track_validation_result.to_dict()
# Apply class mapping
best_detection = self._apply_class_mapping(best_detection, node)
# Create regions dictionary
mapped_class = best_detection["class"]
track_id = best_detection["id"]
all_detections = [best_detection]
regions_dict = {
mapped_class: {
"bbox": best_detection["bbox"],
"confidence": best_detection["confidence"],
"detection": best_detection,
"track_id": track_id
}
}
# Multi-class validation
if not self._validate_multi_class(regions_dict, node):
return [], {}, track_validation_result.to_dict()
logger.info(f"✅ Camera {camera_id}: DETECTION COMPLETE - tracking single car: track_id={track_id}, conf={best_detection['confidence']:.3f}")
logger.debug(f"📊 Camera {camera_id}: Detection summary: {len(res.boxes)} raw → {len(candidate_detections)} candidates → 1 selected")
# Store validation state in context for pipeline decisions
if context is not None:
context["track_validation_result"] = track_validation_result.to_dict()
return all_detections, regions_dict, track_validation_result.to_dict()
except Exception as e:
camera_id = context.get("camera_id", "unknown") if context else "unknown"
model_id = node.get("modelId", "unknown")
raise create_detection_error(camera_id, model_id, "detection_with_tracking", e)
# Global instances for backward compatibility
_global_stability_tracker = StabilityTracker()
_global_yolo_detector = YOLODetector(_global_stability_tracker)
# ===== CONVENIENCE FUNCTIONS =====
# These provide the same interface as the original functions in pympta.py
def run_detection_with_tracking(frame, node: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Tuple[List[Dict[str, Any]], Dict[str, Any], Dict[str, Any]]:
"""Run YOLO detection with BoT-SORT tracking using global detector instance."""
return _global_yolo_detector.run_detection_with_tracking(frame, node, context)
def get_camera_stability_data(camera_id: str, model_id: str) -> Dict[str, Any]:
"""Get stability tracking data for a specific camera and model."""
stability_data = _global_stability_tracker.get_camera_stability_data(camera_id, model_id)
return stability_data.to_dict()
def reset_camera_stability_tracking(camera_id: str, model_id: str) -> None:
"""Reset all stability tracking data for a specific camera and model."""
_global_stability_tracker.reset_camera_stability_tracking(camera_id, model_id)
def update_single_track_stability(node: Dict[str, Any],
detection: Optional[Dict[str, Any]],
camera_id: str,
frame_shape=None,
stability_threshold: int = DEFAULT_STABILITY_THRESHOLD,
context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Update track stability validation for a single highest confidence detection."""
model_id = node.get("modelId", "unknown")
current_mode = context.get("current_mode", "unknown") if context else "unknown"
is_branch_node = node.get("cropClass") is not None or node.get("parallel") is True
result = _global_stability_tracker.update_single_track_stability(
detection=detection,
camera_id=camera_id,
model_id=model_id,
stability_threshold=stability_threshold,
current_mode=current_mode,
is_branch_node=is_branch_node
)
return result.to_dict()
def reset_tracking_state(camera_id: str, model_id: str, reason: str = "reset_requested") -> None:
"""Reset tracking state for a camera and model."""
logger.info(f"🔄 Camera {camera_id}: Resetting tracking state for {model_id} - reason: {reason}")
reset_camera_stability_tracking(camera_id, model_id)

View file

@ -0,0 +1,694 @@
"""
Pipeline execution engine for computer vision detection workflows.
This module provides the main pipeline execution functionality including:
- Multi-class detection coordination
- Branch processing (parallel and sequential)
- Action execution and database operations
- Session state management
"""
import logging
import concurrent.futures
from typing import Dict, List, Any, Optional, Tuple, Union
from dataclasses import dataclass, field
from datetime import datetime
import numpy as np
from ..core.constants import (
DEFAULT_THREAD_POOL_SIZE,
CLASSIFICATION_TIMEOUT_SECONDS,
PIPELINE_EXECUTION_TIMEOUT
)
from ..core.exceptions import PipelineError, create_pipeline_error
from ..detection.detection_result import DetectionResult, DetectionSession
from ..detection.yolo_detector import run_detection_with_tracking
from ..detection.stability_validator import validate_pipeline_execution
from ..detection.tracking_manager import is_camera_active, occupancy_detector
logger = logging.getLogger(__name__)
@dataclass
class PipelineContext:
"""Context information passed through pipeline execution."""
camera_id: str = "unknown"
backend_session_id: Optional[str] = None
display_id: Optional[str] = None
current_mode: str = "unknown"
regions_dict: Optional[Dict[str, Any]] = None
session_id: Optional[str] = None
timestamp: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"camera_id": self.camera_id,
"backend_session_id": self.backend_session_id,
"display_id": self.display_id,
"current_mode": self.current_mode,
"regions_dict": self.regions_dict,
"session_id": self.session_id,
"timestamp": self.timestamp
}
@dataclass
class BranchResult:
"""Result from branch execution."""
model_id: str
success: bool
result: Optional[Dict[str, Any]] = None
error: Optional[str] = None
nested_results: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
data = {
"model_id": self.model_id,
"success": self.success
}
if self.result:
data["result"] = self.result
if self.error:
data["error"] = self.error
if self.nested_results:
data["nested_results"] = self.nested_results
return data
@dataclass
class PipelineResult:
"""Result from pipeline execution."""
success: bool
primary_detection: Optional[Dict[str, Any]] = None
primary_bbox: Optional[List[int]] = None
branch_results: Dict[str, Any] = field(default_factory=dict)
session_id: Optional[str] = None
awaiting_session_id: bool = False
awaiting_stable_tracks: bool = False
def to_tuple(self, return_bbox: bool = False) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]], Tuple[None, None]]:
"""Convert to return format expected by original function."""
if not self.success or not self.primary_detection:
return (None, None) if return_bbox else None
if return_bbox:
return (self.primary_detection, self.primary_bbox or [0, 0, 0, 0])
else:
return self.primary_detection
class PipelineExecutor:
"""
Main pipeline execution engine for computer vision detection workflows.
This class handles the complete pipeline including detection, tracking,
branch processing, action execution, and database operations.
"""
def __init__(self, thread_pool_size: int = DEFAULT_THREAD_POOL_SIZE):
"""
Initialize pipeline executor.
Args:
thread_pool_size: Maximum number of threads for parallel processing
"""
self.thread_pool_size = thread_pool_size
def _extract_context(self, context: Optional[Dict[str, Any]]) -> PipelineContext:
"""Extract pipeline context from input dictionary."""
if not context:
return PipelineContext()
return PipelineContext(
camera_id=context.get("camera_id", "unknown"),
backend_session_id=context.get("backend_session_id"),
display_id=context.get("display_id"),
current_mode=context.get("current_mode", "unknown"),
regions_dict=context.get("regions_dict"),
session_id=context.get("session_id"),
timestamp=context.get("timestamp")
)
def _handle_classification_task(self,
frame: np.ndarray,
node: Dict[str, Any],
pipeline_context: PipelineContext,
return_bbox: bool) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]], Tuple[None, None]]:
"""Handle classification-only pipeline nodes."""
try:
results = node["model"].predict(frame, stream=False)
if not results:
return (None, None) if return_bbox else None
r = results[0]
probs = r.probs
if probs is None:
return (None, None) if return_bbox else None
top1_idx = int(probs.top1)
top1_conf = float(probs.top1conf)
class_name = node["model"].names[top1_idx]
det = {
"class": class_name,
"confidence": top1_conf,
"id": None,
class_name: class_name # Add class name as key for backward compatibility
}
# Add specific field mappings for database operations based on model type
model_id = node.get("modelId", "").lower()
if "brand" in model_id or "brand_cls" in model_id:
det["brand"] = class_name
elif "bodytype" in model_id or "body" in model_id:
det["body_type"] = class_name
elif "color" in model_id:
det["color"] = class_name
# Execute actions for classification nodes
self._execute_node_actions(node, frame, det, pipeline_context.regions_dict)
return (det, None) if return_bbox else det
except Exception as e:
logger.error(f"Error in classification task for {node.get('modelId')}: {e}")
return (None, None) if return_bbox else None
def _check_camera_active(self, camera_id: str, model_id: str, return_bbox: bool) -> Optional[Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]]]]:
"""Check if camera is active for processing."""
if not is_camera_active(camera_id, model_id):
logger.debug(f"⏰ Camera {camera_id}: Waiting for backend sessionId, sending 'none' detection")
none_detection = {
"class": "none",
"confidence": 1.0,
"bbox": [0, 0, 0, 0],
"branch_results": {}
}
return (none_detection, [0, 0, 0, 0]) if return_bbox else none_detection
return None
def _run_detection_stage(self,
frame: np.ndarray,
node: Dict[str, Any],
pipeline_context: PipelineContext,
validated_detection: Optional[Dict[str, Any]] = None) -> Tuple[List[Dict[str, Any]], Dict[str, Any], Dict[str, Any]]:
"""Run the detection stage of the pipeline."""
if validated_detection:
track_id = validated_detection.get('track_id')
logger.info(f"🔄 PIPELINE: Using validated detection from validation phase - track_id={track_id}")
# Convert validated detection back to all_detections format for branch processing
all_detections = [validated_detection]
# Create regions_dict based on validated detection class with proper structure
class_name = validated_detection.get("class", "car")
regions_dict = {
class_name: {
"confidence": validated_detection.get("confidence"),
"bbox": validated_detection.get("bbox", [0, 0, 0, 0]),
"detection": validated_detection
}
}
# Bypass track validation completely - force pipeline execution
track_validation_result = {
"validation_complete": True,
"stable_tracks": ["cached"], # Use dummy stable track to force pipeline execution
"current_tracks": ["cached"],
"bypass_validation": True
}
else:
# Normal detection stage - Using structured detection function
all_detections, regions_dict, track_validation_result = run_detection_with_tracking(
frame, node, pipeline_context.to_dict()
)
return all_detections, regions_dict, track_validation_result
def _validate_tracking_requirements(self,
node: Dict[str, Any],
track_validation_result: Dict[str, Any],
pipeline_context: PipelineContext,
return_bbox: bool) -> Optional[Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]]]]:
"""Validate tracking requirements for pipeline execution."""
tracking_config = node.get("tracking", {})
stability_threshold = tracking_config.get("stabilityThreshold", node.get("stabilityThreshold", 1))
if stability_threshold <= 1 or not tracking_config.get("enabled", True):
return None # No tracking requirements
# Check if this is a branch node - branches should execute regardless of main validation state
is_branch_node = node.get("cropClass") is not None or node.get("parallel") is True
if is_branch_node:
logger.debug(f"🔍 Camera {pipeline_context.camera_id}: Branch node {node.get('modelId')} executing during track validation phase")
return None
# Main pipeline node during track validation - check for stable tracks
stable_tracks = track_validation_result.get("stable_tracks", [])
if not stable_tracks:
logger.debug(f"🔒 Camera {pipeline_context.camera_id}: Main pipeline requires stable tracks - none found, skipping pipeline execution")
none_detection = {
"class": "none",
"confidence": 1.0,
"bbox": [0, 0, 0, 0],
"awaiting_stable_tracks": True
}
return (none_detection, [0, 0, 0, 0]) if return_bbox else none_detection
logger.info(f"🎯 Camera {pipeline_context.camera_id}: STABLE TRACKS DETECTED - proceeding with full pipeline (tracks: {stable_tracks})")
return None
def _handle_database_operations(self,
node: Dict[str, Any],
detection_result: Dict[str, Any],
regions_dict: Dict[str, Any],
pipeline_context: PipelineContext) -> None:
"""Handle database operations if database manager is available."""
if not (node.get("db_manager") and regions_dict):
return
detected_classes = list(regions_dict.keys())
logger.debug(f"Valid detections found: {detected_classes}")
if pipeline_context.backend_session_id:
# Backend sessionId is available, proceed with database operations
display_id = pipeline_context.display_id or "unknown"
timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
inserted_session_id = node["db_manager"].insert_initial_detection(
display_id=display_id,
captured_timestamp=timestamp,
session_id=pipeline_context.backend_session_id
)
if inserted_session_id:
detection_result["session_id"] = inserted_session_id
detection_result["timestamp"] = timestamp
logger.info(f"💾 DATABASE RECORD CREATED with backend session_id: {inserted_session_id}")
logger.debug(f"Database record: display_id={display_id}, timestamp={timestamp}")
else:
logger.error(f"Failed to create database record with backend session_id: {pipeline_context.backend_session_id}")
else:
logger.info(f"📡 Camera {pipeline_context.camera_id}: Full pipeline completed, detection data will be sent to backend. Database operations will occur when sessionId is received.")
# Store detection info for later database operations when sessionId arrives
detection_result["awaiting_session_id"] = True
detection_result["timestamp"] = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
def _execute_node_actions(self,
node: Dict[str, Any],
frame: np.ndarray,
detection_result: Dict[str, Any],
regions_dict: Optional[Dict[str, Any]]) -> None:
"""Execute actions for a node."""
# This is a placeholder for action execution
# In the actual implementation, this would import and call execute_actions
pass
def _execute_parallel_actions(self,
node: Dict[str, Any],
frame: np.ndarray,
detection_result: Dict[str, Any],
regions_dict: Dict[str, Any]) -> None:
"""Execute parallel actions after branch completion."""
# This is a placeholder for parallel action execution
# In the actual implementation, this would import and call execute_parallel_actions
pass
def _crop_region_by_class(self,
frame: np.ndarray,
regions_dict: Dict[str, Any],
class_name: str) -> Optional[np.ndarray]:
"""Crop a specific region from frame based on detected class."""
if class_name not in regions_dict:
logger.warning(f"Class '{class_name}' not found in detected regions")
return None
bbox = regions_dict[class_name]["bbox"]
x1, y1, x2, y2 = bbox
# Validate bbox coordinates
if x2 <= x1 or y2 <= y1:
logger.warning(f"Invalid bbox for class {class_name}: {bbox}")
return None
try:
cropped = frame[y1:y2, x1:x2]
if cropped.size == 0:
logger.warning(f"Empty crop for class {class_name}")
return None
return cropped
except Exception as e:
logger.error(f"Error cropping region for class {class_name}: {e}")
return None
def _prepare_branch_context(self,
base_context: PipelineContext,
regions_dict: Dict[str, Any],
detection_result: Dict[str, Any]) -> Dict[str, Any]:
"""Prepare context for branch execution."""
branch_context = base_context.to_dict()
branch_context["regions_dict"] = regions_dict
# Pass session_id from detection_result to branch context for Redis actions
if "session_id" in detection_result:
branch_context["session_id"] = detection_result["session_id"]
logger.debug(f"Added session_id to branch context: {detection_result['session_id']}")
elif base_context.backend_session_id:
branch_context["session_id"] = base_context.backend_session_id
logger.debug(f"Added backend_session_id to branch context: {base_context.backend_session_id}")
return branch_context
def _execute_single_branch(self,
frame: np.ndarray,
branch: Dict[str, Any],
branch_context: Dict[str, Any],
regions_dict: Dict[str, Any]) -> BranchResult:
"""Execute a single branch."""
model_id = branch["modelId"]
try:
sub_frame = frame
crop_class = branch.get("cropClass")
logger.info(f"Starting branch: {model_id}, cropClass: {crop_class}")
# Handle cropping if required
if branch.get("crop", False) and crop_class:
if crop_class in regions_dict:
cropped = self._crop_region_by_class(frame, regions_dict, crop_class)
if cropped is not None:
sub_frame = cropped # Use cropped image without manual resizing
logger.debug(f"Successfully cropped {crop_class} region for {model_id} - model will handle resizing")
else:
return BranchResult(
model_id=model_id,
success=False,
error=f"Failed to crop {crop_class} region"
)
else:
return BranchResult(
model_id=model_id,
success=False,
error=f"Crop class {crop_class} not found in detected regions"
)
# Execute branch pipeline
result, _ = self.run_pipeline(sub_frame, branch, True, branch_context)
if result:
branch_result = BranchResult(
model_id=model_id,
success=True,
result=result
)
# Collect nested branch results if they exist
if "branch_results" in result:
branch_result.nested_results = result["branch_results"]
logger.info(f"Branch {model_id} completed: {result}")
return branch_result
else:
return BranchResult(
model_id=model_id,
success=False,
error="Branch returned no result"
)
except Exception as e:
logger.error(f"Error in branch {model_id}: {e}")
return BranchResult(
model_id=model_id,
success=False,
error=str(e)
)
def _process_branches_parallel(self,
frame: np.ndarray,
active_branches: List[Dict[str, Any]],
branch_context: Dict[str, Any],
regions_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Process branches in parallel."""
branch_results = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_branches)) as executor:
futures = {}
for branch in active_branches:
future = executor.submit(
self._execute_single_branch,
frame, branch, branch_context, regions_dict
)
futures[future] = branch
# Collect results
for future in concurrent.futures.as_completed(futures):
branch = futures[future]
try:
result = future.result()
if result.success and result.result:
branch_results[result.model_id] = result.result
# Collect nested branch results
for nested_id, nested_result in result.nested_results.items():
branch_results[nested_id] = nested_result
logger.info(f"Collected nested branch result: {nested_id} = {nested_result}")
else:
logger.error(f"Branch {result.model_id} failed: {result.error}")
except Exception as e:
logger.error(f"Branch {branch['modelId']} failed: {e}")
return branch_results
def _process_branches_sequential(self,
frame: np.ndarray,
active_branches: List[Dict[str, Any]],
branch_context: Dict[str, Any],
regions_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Process branches sequentially."""
branch_results = {}
for branch in active_branches:
result = self._execute_single_branch(frame, branch, branch_context, regions_dict)
if result.success and result.result:
branch_results[result.model_id] = result.result
# Collect nested branch results
for nested_id, nested_result in result.nested_results.items():
branch_results[nested_id] = nested_result
logger.info(f"Collected nested branch result: {nested_id} = {nested_result}")
else:
logger.error(f"Branch {result.model_id} failed: {result.error}")
return branch_results
def _filter_active_branches(self,
node: Dict[str, Any],
regions_dict: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Filter branches that should be triggered based on detected regions."""
active_branches = []
for br in node["branches"]:
trigger_classes = br.get("triggerClasses", [])
min_conf = br.get("minConfidence", 0)
logger.debug(f"Evaluating branch {br['modelId']}: trigger_classes={trigger_classes}, min_conf={min_conf}")
# Check if any detected class matches branch trigger
branch_triggered = False
for det_class in regions_dict:
det_confidence = regions_dict[det_class]["confidence"]
logger.debug(f" Checking detected class '{det_class}' (confidence={det_confidence:.3f}) against triggers {trigger_classes}")
if (det_class in trigger_classes and det_confidence >= min_conf):
active_branches.append(br)
branch_triggered = True
logger.info(f"Branch {br['modelId']} activated by class '{det_class}' (conf={det_confidence:.3f} >= {min_conf})")
break
if not branch_triggered:
logger.debug(f"Branch {br['modelId']} not triggered - no matching classes or insufficient confidence")
return active_branches
def _process_branches(self,
frame: np.ndarray,
node: Dict[str, Any],
detection_result: Dict[str, Any],
regions_dict: Dict[str, Any],
pipeline_context: PipelineContext) -> Dict[str, Any]:
"""Process all branches for a node."""
if not node.get("branches"):
return {}
# Filter branches that should be triggered
active_branches = self._filter_active_branches(node, regions_dict)
if not active_branches:
return {}
# Prepare branch context
branch_context = self._prepare_branch_context(pipeline_context, regions_dict, detection_result)
# Execute branches
if node.get("parallel", False) or any(br.get("parallel", False) for br in active_branches):
# Run branches in parallel
branch_results = self._process_branches_parallel(frame, active_branches, branch_context, regions_dict)
else:
# Run branches sequentially
branch_results = self._process_branches_sequential(frame, active_branches, branch_context, regions_dict)
return branch_results
def run_pipeline(self,
frame: np.ndarray,
node: Dict[str, Any],
return_bbox: bool = False,
context: Optional[Dict[str, Any]] = None,
validated_detection: Optional[Dict[str, Any]] = None) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]], Tuple[None, None]]:
"""
Run enhanced pipeline that supports:
- Multi-class detection (detecting multiple classes simultaneously)
- Parallel branch processing
- Region-based actions and cropping
- Context passing for session/camera information
Args:
frame: Input frame for processing
node: Pipeline node configuration
return_bbox: Whether to return bounding box with result
context: Optional context information
validated_detection: Optional pre-validated detection to use
Returns:
Detection result, optionally with bounding box
"""
try:
# Extract context information
pipeline_context = self._extract_context(context)
model_id = node.get("modelId", "unknown")
if pipeline_context.backend_session_id:
logger.info(f"🔑 PIPELINE USING BACKEND SESSION_ID: {pipeline_context.backend_session_id} for camera {pipeline_context.camera_id}")
task = getattr(node["model"], "task", None)
# ─── Classification stage ───────────────────────────────────
if task == "classify":
return self._handle_classification_task(frame, node, pipeline_context, return_bbox)
# ─── Session management check ───────────────────────────────────────
camera_inactive_result = self._check_camera_active(pipeline_context.camera_id, model_id, return_bbox)
if camera_inactive_result is not None:
return camera_inactive_result
# ─── Detection stage ───
all_detections, regions_dict, track_validation_result = self._run_detection_stage(
frame, node, pipeline_context, validated_detection
)
if not all_detections:
logger.debug("No detections from structured detection function - sending 'none' detection")
none_detection = {
"class": "none",
"confidence": 1.0,
"bbox": [0, 0, 0, 0],
"branch_results": {}
}
return (none_detection, [0, 0, 0, 0]) if return_bbox else none_detection
# ─── Track-Based Validation System ────────────────────────
tracking_validation_result = self._validate_tracking_requirements(
node, track_validation_result, pipeline_context, return_bbox
)
if tracking_validation_result is not None:
return tracking_validation_result
# ─── Pre-validate pipeline execution ────────────────────────
pipeline_valid, missing_branches = validate_pipeline_execution(node, regions_dict)
if not pipeline_valid:
logger.error(f"Pipeline execution validation FAILED - required branches {missing_branches} cannot execute")
logger.error("Aborting pipeline: no Redis actions or database records will be created")
return (None, None) if return_bbox else None
# ─── Execute actions with region information ────────────────
detection_result = {
"detections": all_detections,
"regions": regions_dict,
**pipeline_context.to_dict()
}
# ─── Database operations ────
self._handle_database_operations(node, detection_result, regions_dict, pipeline_context)
# Execute actions for root node only if it doesn't have branches
# Branch nodes with actions will execute them after branch processing
if not node.get("branches") or node.get("modelId") == "yolo11n":
self._execute_node_actions(node, frame, detection_result, regions_dict)
# ─── Branch processing ─────────────────────────────
branch_results = self._process_branches(frame, node, detection_result, regions_dict, pipeline_context)
detection_result["branch_results"] = branch_results
# ─── Execute Parallel Actions ───────────────────────────────
if node.get("parallelActions") and "branch_results" in detection_result:
self._execute_parallel_actions(node, frame, detection_result, regions_dict)
# ─── Auto-enable occupancy mode after successful pipeline completion ─────────────────
occupancy_detector(pipeline_context.camera_id, model_id, enable=True)
logger.info(f"✅ Camera {pipeline_context.camera_id}: Pipeline completed, detection data will be sent to backend")
logger.info(f"🛑 Camera {pipeline_context.camera_id}: Model will stop inference for future frames")
logger.info(f"📡 Backend sessionId will be handled when received via WebSocket")
# ─── Execute actions after successful detection AND branch processing ──────────
# This ensures detection nodes (like frontal_detection_v1) execute their actions
# after completing both detection and branch processing
if node.get("actions") and regions_dict and node.get("modelId") != "yolo11n":
# Execute actions for branch detection nodes, skip root to avoid duplication
logger.debug(f"Executing post-detection actions for branch node {node.get('modelId')}")
self._execute_node_actions(node, frame, detection_result, regions_dict)
# ─── Return detection result ────────────────────────────────
primary_detection = max(all_detections, key=lambda x: x["confidence"])
primary_bbox = primary_detection["bbox"]
# Add branch results and session_id to primary detection for compatibility
if "branch_results" in detection_result:
primary_detection["branch_results"] = detection_result["branch_results"]
if "session_id" in detection_result:
primary_detection["session_id"] = detection_result["session_id"]
return (primary_detection, primary_bbox) if return_bbox else primary_detection
except Exception as e:
pipeline_id = node.get("modelId", "unknown")
raise create_pipeline_error(pipeline_id, "pipeline_execution", e)
# Global pipeline executor instance
pipeline_executor = PipelineExecutor()
# ===== CONVENIENCE FUNCTIONS =====
# These provide the same interface as the original functions in pympta.py
def run_pipeline(frame: np.ndarray,
node: Dict[str, Any],
return_bbox: bool = False,
context: Optional[Dict[str, Any]] = None,
validated_detection: Optional[Dict[str, Any]] = None) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]], Tuple[None, None]]:
"""Run enhanced pipeline using global executor instance."""
return pipeline_executor.run_pipeline(frame, node, return_bbox, context, validated_detection)
def crop_region_by_class(frame: np.ndarray, regions_dict: Dict[str, Any], class_name: str) -> Optional[np.ndarray]:
"""Crop a specific region from frame based on detected class."""
return pipeline_executor._crop_region_by_class(frame, regions_dict, class_name)

View file

@ -0,0 +1,345 @@
"""
Camera connection state monitoring and management.
This module provides centralized tracking of camera connection states,
error handling, and disconnection notifications.
"""
import time
import logging
from typing import Dict, Any, Optional
from dataclasses import dataclass, field
from datetime import datetime
from ..core.exceptions import CameraConnectionError
logger = logging.getLogger(__name__)
@dataclass
class CameraConnectionState:
"""Represents the connection state of a single camera."""
connected: bool = True
last_error: Optional[str] = None
last_error_time: Optional[float] = None
consecutive_failures: int = 0
disconnection_notified: bool = False
last_successful_frame: Optional[float] = None
connection_start_time: Optional[float] = field(default_factory=time.time)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"connected": self.connected,
"last_error": self.last_error,
"last_error_time": self.last_error_time,
"consecutive_failures": self.consecutive_failures,
"disconnection_notified": self.disconnection_notified,
"last_successful_frame": self.last_successful_frame,
"connection_start_time": self.connection_start_time
}
class CameraMonitor:
"""
Monitors camera connection states and handles disconnection notifications.
This class provides a centralized way to track camera connection status,
handle error states, and determine when to notify about disconnections.
"""
def __init__(self, failure_threshold: int = 3):
"""
Initialize camera monitor.
Args:
failure_threshold: Number of consecutive failures before considering disconnected
"""
self.failure_threshold = failure_threshold
self._camera_states: Dict[str, CameraConnectionState] = {}
self._lock = None # Will be initialized when needed for thread safety
def _ensure_thread_safety(self):
"""Initialize thread safety if not already done."""
if self._lock is None:
import threading
self._lock = threading.RLock()
def get_camera_state(self, camera_id: str) -> CameraConnectionState:
"""
Get or create camera connection state.
Args:
camera_id: Unique camera identifier
Returns:
Current connection state for the camera
"""
self._ensure_thread_safety()
with self._lock:
if camera_id not in self._camera_states:
self._camera_states[camera_id] = CameraConnectionState()
logger.debug(f"Initialized connection state for camera {camera_id}")
return self._camera_states[camera_id]
def set_camera_connected(self,
camera_id: str,
connected: bool = True,
error_msg: Optional[str] = None) -> None:
"""
Set camera connection state and track error information.
Args:
camera_id: Unique camera identifier
connected: Whether camera is connected
error_msg: Error message if disconnected
"""
self._ensure_thread_safety()
current_time = time.time()
with self._lock:
state = self.get_camera_state(camera_id)
if connected:
# Camera is now connected
if not state.connected:
logger.info(f"Camera {camera_id} reconnected successfully")
state.connected = True
state.last_successful_frame = current_time
state.consecutive_failures = 0
state.disconnection_notified = False
state.last_error = None
state.last_error_time = None
else:
# Camera is now disconnected
state.connected = False
state.consecutive_failures += 1
state.last_error = error_msg
state.last_error_time = current_time
if state.consecutive_failures == 1:
logger.warning(f"Camera {camera_id} connection lost: {error_msg}")
elif state.consecutive_failures >= self.failure_threshold:
logger.error(f"Camera {camera_id} has {state.consecutive_failures} consecutive failures")
logger.debug(f"Camera {camera_id} state updated - failures: {state.consecutive_failures}")
def is_camera_connected(self, camera_id: str) -> bool:
"""
Check if camera is currently connected.
Args:
camera_id: Unique camera identifier
Returns:
True if camera is connected, False otherwise
"""
self._ensure_thread_safety()
with self._lock:
state = self._camera_states.get(camera_id)
return state.connected if state else True # Default to connected for new cameras
def should_notify_disconnection(self, camera_id: str) -> bool:
"""
Check if we should notify backend about disconnection.
A disconnection notification should be sent when:
1. Camera is disconnected
2. We haven't already notified about this disconnection
3. We have enough consecutive failures
Args:
camera_id: Unique camera identifier
Returns:
True if disconnection notification should be sent
"""
self._ensure_thread_safety()
with self._lock:
state = self._camera_states.get(camera_id)
if not state:
return False
is_disconnected = not state.connected
not_yet_notified = not state.disconnection_notified
has_enough_failures = state.consecutive_failures >= self.failure_threshold
should_notify = is_disconnected and not_yet_notified and has_enough_failures
if should_notify:
logger.info(f"Camera {camera_id} qualifies for disconnection notification - "
f"failures: {state.consecutive_failures}, error: {state.last_error}")
return should_notify
def mark_disconnection_notified(self, camera_id: str) -> None:
"""
Mark that we've notified backend about this disconnection.
Args:
camera_id: Unique camera identifier
"""
self._ensure_thread_safety()
with self._lock:
if camera_id in self._camera_states:
self._camera_states[camera_id].disconnection_notified = True
logger.debug(f"Marked disconnection notification sent for camera {camera_id}")
def get_connection_stats(self, camera_id: str) -> Dict[str, Any]:
"""
Get comprehensive connection statistics for a camera.
Args:
camera_id: Unique camera identifier
Returns:
Dictionary with connection statistics
"""
self._ensure_thread_safety()
with self._lock:
state = self._camera_states.get(camera_id)
if not state:
return {"error": "Camera not found"}
current_time = time.time()
stats = state.to_dict()
# Add computed stats
if state.connection_start_time:
stats["uptime_seconds"] = current_time - state.connection_start_time
if state.last_successful_frame:
stats["seconds_since_last_frame"] = current_time - state.last_successful_frame
if state.last_error_time:
stats["seconds_since_last_error"] = current_time - state.last_error_time
return stats
def reset_camera_state(self, camera_id: str) -> None:
"""
Reset camera state to initial connected state.
Args:
camera_id: Unique camera identifier
"""
self._ensure_thread_safety()
with self._lock:
if camera_id in self._camera_states:
del self._camera_states[camera_id]
logger.info(f"Reset connection state for camera {camera_id}")
def cleanup_inactive_cameras(self, inactive_threshold_seconds: int = 3600) -> int:
"""
Remove states for cameras inactive for too long.
Args:
inactive_threshold_seconds: Seconds of inactivity before cleanup
Returns:
Number of camera states cleaned up
"""
self._ensure_thread_safety()
current_time = time.time()
cleanup_count = 0
with self._lock:
camera_ids_to_remove = []
for camera_id, state in self._camera_states.items():
last_activity = max(
state.connection_start_time or 0,
state.last_successful_frame or 0,
state.last_error_time or 0
)
if current_time - last_activity > inactive_threshold_seconds:
camera_ids_to_remove.append(camera_id)
for camera_id in camera_ids_to_remove:
del self._camera_states[camera_id]
cleanup_count += 1
logger.debug(f"Cleaned up inactive camera state for {camera_id}")
if cleanup_count > 0:
logger.info(f"Cleaned up {cleanup_count} inactive camera states")
return cleanup_count
def get_all_camera_states(self) -> Dict[str, Dict[str, Any]]:
"""
Get connection states for all monitored cameras.
Returns:
Dictionary mapping camera IDs to their connection states
"""
self._ensure_thread_safety()
with self._lock:
return {
camera_id: state.to_dict()
for camera_id, state in self._camera_states.items()
}
def get_disconnected_cameras(self) -> list[str]:
"""
Get list of currently disconnected camera IDs.
Returns:
List of camera IDs that are currently disconnected
"""
self._ensure_thread_safety()
with self._lock:
return [
camera_id for camera_id, state in self._camera_states.items()
if not state.connected
]
# Global camera monitor instance
camera_monitor = CameraMonitor()
# ===== CONVENIENCE FUNCTIONS =====
# These provide the same interface as the original functions in app.py
def set_camera_connected(camera_id: str, connected: bool = True, error_msg: Optional[str] = None) -> None:
"""Set camera connection state and track error information."""
camera_monitor.set_camera_connected(camera_id, connected, error_msg)
def is_camera_connected(camera_id: str) -> bool:
"""Check if camera is currently connected."""
return camera_monitor.is_camera_connected(camera_id)
def should_notify_disconnection(camera_id: str) -> bool:
"""Check if we should notify backend about disconnection."""
return camera_monitor.should_notify_disconnection(camera_id)
def mark_disconnection_notified(camera_id: str) -> None:
"""Mark that we've notified backend about this disconnection."""
camera_monitor.mark_disconnection_notified(camera_id)
def get_connection_stats(camera_id: str) -> Dict[str, Any]:
"""Get comprehensive connection statistics for a camera."""
return camera_monitor.get_connection_stats(camera_id)
def reset_camera_state(camera_id: str) -> None:
"""Reset camera state to initial connected state."""
camera_monitor.reset_camera_state(camera_id)

View file

@ -0,0 +1,476 @@
"""
Frame reading implementations for RTSP and HTTP snapshot streams.
This module provides thread-safe frame readers for different camera stream types.
"""
import cv2
import time
import queue
import logging
import requests
import threading
from typing import Optional, Any
import numpy as np
from ..core.constants import (
DEFAULT_RECONNECT_INTERVAL_SEC,
DEFAULT_MAX_RETRIES,
HTTP_SNAPSHOT_TIMEOUT,
SHARED_STREAM_BUFFER_SIZE
)
from ..core.exceptions import (
CameraConnectionError,
FrameReadError,
create_stream_error
)
logger = logging.getLogger(__name__)
def fetch_snapshot(url: str, timeout: int = HTTP_SNAPSHOT_TIMEOUT) -> Optional[np.ndarray]:
"""
Fetch a single snapshot from HTTP/HTTPS URL.
Args:
url: HTTP/HTTPS URL to fetch snapshot from
timeout: Request timeout in seconds
Returns:
Decoded image frame or None if fetch failed
"""
try:
logger.debug(f"Fetching snapshot from {url}")
response = requests.get(url, timeout=timeout, stream=True)
response.raise_for_status()
# Check if response has content
if not response.content:
logger.warning(f"Empty response from snapshot URL: {url}")
return None
# Decode image from response bytes
img_array = np.frombuffer(response.content, dtype=np.uint8)
frame = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
if frame is None:
logger.warning(f"Failed to decode image from snapshot URL: {url}")
return None
logger.debug(f"Successfully fetched snapshot from {url}, shape: {frame.shape}")
return frame
except requests.exceptions.Timeout:
logger.warning(f"Timeout fetching snapshot from {url}")
return None
except requests.exceptions.ConnectionError:
logger.warning(f"Connection error fetching snapshot from {url}")
return None
except requests.exceptions.HTTPError as e:
logger.warning(f"HTTP error fetching snapshot from {url}: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error fetching snapshot from {url}: {e}")
return None
class RTSPFrameReader:
"""Thread-safe RTSP frame reader."""
def __init__(self,
camera_id: str,
rtsp_url: str,
buffer: queue.Queue,
stop_event: threading.Event,
reconnect_interval: int = DEFAULT_RECONNECT_INTERVAL_SEC,
max_retries: int = DEFAULT_MAX_RETRIES,
connection_callback=None):
"""
Initialize RTSP frame reader.
Args:
camera_id: Unique camera identifier
rtsp_url: RTSP stream URL
buffer: Queue to put frames into
stop_event: Event to signal thread shutdown
reconnect_interval: Seconds between reconnection attempts
max_retries: Maximum retry attempts (-1 for unlimited)
connection_callback: Callback function for connection state changes
"""
self.camera_id = camera_id
self.rtsp_url = rtsp_url
self.buffer = buffer
self.stop_event = stop_event
self.reconnect_interval = reconnect_interval
self.max_retries = max_retries
self.connection_callback = connection_callback
self.cap: Optional[cv2.VideoCapture] = None
self.retries = 0
self.frame_count = 0
self.last_log_time = time.time()
def _set_connection_state(self, connected: bool, error_msg: Optional[str] = None):
"""Update connection state via callback."""
if self.connection_callback:
self.connection_callback(self.camera_id, connected, error_msg)
def _initialize_capture(self) -> bool:
"""Initialize video capture."""
try:
self.cap = cv2.VideoCapture(self.rtsp_url)
if self.cap.isOpened():
# Log camera properties
width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = self.cap.get(cv2.CAP_PROP_FPS)
logger.info(f"Camera {self.camera_id} opened: {width}x{height}, FPS: {fps}")
self._set_connection_state(True)
return True
else:
logger.error(f"Camera {self.camera_id} failed to open")
self._set_connection_state(False, "Failed to open camera")
return False
except Exception as e:
error_msg = f"Failed to initialize capture: {e}"
logger.error(f"Camera {self.camera_id}: {error_msg}")
self._set_connection_state(False, error_msg)
return False
def _reconnect(self) -> bool:
"""Attempt to reconnect to RTSP stream."""
if self.cap:
self.cap.release()
self.cap = None
logger.info(f"Attempting to reconnect RTSP stream for camera: {self.camera_id}")
time.sleep(self.reconnect_interval)
return self._initialize_capture()
def _read_frame(self) -> Optional[np.ndarray]:
"""Read a single frame from the stream."""
if not self.cap or not self.cap.isOpened():
return None
try:
ret, frame = self.cap.read()
if not ret:
return None
# Update statistics
self.frame_count += 1
current_time = time.time()
# Log frame stats every 5 seconds
if current_time - self.last_log_time > 5:
elapsed = current_time - self.last_log_time
logger.info(f"Camera {self.camera_id}: Read {self.frame_count} frames in {elapsed:.1f}s")
self.frame_count = 0
self.last_log_time = current_time
return frame
except cv2.error as e:
raise FrameReadError(f"OpenCV error reading frame: {e}")
def _put_frame(self, frame: np.ndarray):
"""Put frame into buffer, removing old frame if necessary."""
# Remove old frame if buffer is full
if not self.buffer.empty():
try:
self.buffer.get_nowait()
logger.debug(f"Removed old frame from buffer for camera {self.camera_id}")
except queue.Empty:
pass
self.buffer.put(frame)
logger.debug(f"Added frame to buffer for camera {self.camera_id}, buffer size: {self.buffer.qsize()}")
def run(self):
"""Main frame reading loop."""
logger.info(f"Starting RTSP frame reader for camera {self.camera_id}")
try:
# Initialize capture
if not self._initialize_capture():
return
while not self.stop_event.is_set():
try:
frame = self._read_frame()
if frame is None:
# Connection lost
self.retries += 1
error_msg = f"Connection lost, retry {self.retries}/{self.max_retries}"
logger.warning(f"Camera {self.camera_id}: {error_msg}")
self._set_connection_state(False, error_msg)
# Check max retries
if self.retries > self.max_retries and self.max_retries != -1:
logger.error(f"Camera {self.camera_id}: Max retries reached, stopping")
self._set_connection_state(False, "Max retries reached")
break
# Attempt reconnection
if not self._reconnect():
continue
# Reset retry counter on successful reconnection
self.retries = 0
continue
# Successfully read frame
logger.debug(f"Camera {self.camera_id}: Read frame, shape: {frame.shape}")
self.retries = 0
self._set_connection_state(True)
self._put_frame(frame)
# Short sleep to avoid CPU overuse
time.sleep(0.01)
except FrameReadError as e:
self.retries += 1
error_msg = f"Frame read error: {e}"
logger.error(f"Camera {self.camera_id}: {error_msg}")
self._set_connection_state(False, error_msg)
# Check max retries
if self.retries > self.max_retries and self.max_retries != -1:
logger.error(f"Camera {self.camera_id}: Max retries reached after error")
self._set_connection_state(False, "Max retries reached after error")
break
# Attempt reconnection
if not self._reconnect():
continue
except Exception as e:
error_msg = f"Unexpected error: {str(e)}"
logger.error(f"Camera {self.camera_id}: {error_msg}", exc_info=True)
self._set_connection_state(False, error_msg)
break
except Exception as e:
logger.error(f"Error in RTSP frame reader for camera {self.camera_id}: {str(e)}", exc_info=True)
finally:
logger.info(f"RTSP frame reader for camera {self.camera_id} is exiting")
if self.cap and self.cap.isOpened():
self.cap.release()
class SnapshotFrameReader:
"""Thread-safe HTTP snapshot frame reader."""
def __init__(self,
camera_id: str,
snapshot_url: str,
snapshot_interval: int,
buffer: queue.Queue,
stop_event: threading.Event,
max_retries: int = DEFAULT_MAX_RETRIES,
connection_callback=None):
"""
Initialize snapshot frame reader.
Args:
camera_id: Unique camera identifier
snapshot_url: HTTP/HTTPS snapshot URL
snapshot_interval: Interval between snapshots in milliseconds
buffer: Queue to put frames into
stop_event: Event to signal thread shutdown
max_retries: Maximum retry attempts (-1 for unlimited)
connection_callback: Callback function for connection state changes
"""
self.camera_id = camera_id
self.snapshot_url = snapshot_url
self.snapshot_interval = snapshot_interval
self.buffer = buffer
self.stop_event = stop_event
self.max_retries = max_retries
self.connection_callback = connection_callback
self.retries = 0
self.consecutive_failures = 0
self.frame_count = 0
self.last_log_time = time.time()
def _set_connection_state(self, connected: bool, error_msg: Optional[str] = None):
"""Update connection state via callback."""
if self.connection_callback:
self.connection_callback(self.camera_id, connected, error_msg)
def _calculate_backoff_delay(self) -> float:
"""Calculate exponential backoff delay based on consecutive failures."""
interval_seconds = self.snapshot_interval / 1000.0
backoff_delay = min(30, max(1, min(2 ** min(self.consecutive_failures - 1, 6), interval_seconds * 2)))
return backoff_delay
def _test_connectivity(self):
"""Test connectivity to snapshot URL."""
if self.consecutive_failures % 5 == 1: # Every 5th failure
try:
test_response = requests.get(self.snapshot_url, timeout=(2, 5), stream=False)
logger.info(f"Camera {self.camera_id}: Connectivity test result: {test_response.status_code}")
except Exception as test_error:
logger.warning(f"Camera {self.camera_id}: Connectivity test failed: {test_error}")
def _put_frame(self, frame: np.ndarray):
"""Put frame into buffer, removing old frame if necessary."""
# Remove old frame if buffer is full
if not self.buffer.empty():
try:
self.buffer.get_nowait()
logger.debug(f"Removed old snapshot from buffer for camera {self.camera_id}")
except queue.Empty:
pass
self.buffer.put(frame)
logger.debug(f"Added snapshot to buffer for camera {self.camera_id}, buffer size: {self.buffer.qsize()}")
def run(self):
"""Main snapshot reading loop."""
logger.info(f"Starting snapshot reader for camera {self.camera_id} from {self.snapshot_url}")
interval_seconds = self.snapshot_interval / 1000.0
logger.info(f"Snapshot interval for camera {self.camera_id}: {interval_seconds}s")
# Initialize connection state
self._set_connection_state(True)
try:
while not self.stop_event.is_set():
try:
start_time = time.time()
frame = fetch_snapshot(self.snapshot_url)
if frame is None:
# Failed to fetch snapshot
self.consecutive_failures += 1
self.retries += 1
error_msg = f"Failed to fetch snapshot, consecutive failures: {self.consecutive_failures}"
logger.warning(f"Camera {self.camera_id}: {error_msg}")
self._set_connection_state(False, error_msg)
# Test connectivity periodically
self._test_connectivity()
# Check max retries
if self.retries > self.max_retries and self.max_retries != -1:
logger.error(f"Camera {self.camera_id}: Max retries reached for snapshot, stopping")
self._set_connection_state(False, "Max retries reached for snapshot")
break
# Exponential backoff
backoff_delay = self._calculate_backoff_delay()
logger.debug(f"Camera {self.camera_id}: Backing off for {backoff_delay:.1f}s")
if self.stop_event.wait(backoff_delay):
break # Exit if stop event set during backoff
continue
# Successfully fetched snapshot
self.consecutive_failures = 0 # Reset on success
self.retries = 0
self.frame_count += 1
current_time = time.time()
# Log frame stats every 5 seconds
if current_time - self.last_log_time > 5:
elapsed = current_time - self.last_log_time
logger.info(f"Camera {self.camera_id}: Fetched {self.frame_count} snapshots in {elapsed:.1f}s")
self.frame_count = 0
self.last_log_time = current_time
logger.debug(f"Camera {self.camera_id}: Fetched snapshot, shape: {frame.shape}")
self._set_connection_state(True)
self._put_frame(frame)
# Wait for interval
elapsed = time.time() - start_time
sleep_time = max(interval_seconds - elapsed, 0)
if sleep_time > 0:
if self.stop_event.wait(sleep_time):
break # Exit if stop event set during sleep
except Exception as e:
self.consecutive_failures += 1
self.retries += 1
error_msg = f"Unexpected error: {str(e)}"
logger.error(f"Camera {self.camera_id}: {error_msg}", exc_info=True)
self._set_connection_state(False, error_msg)
# Check max retries
if self.retries > self.max_retries and self.max_retries != -1:
logger.error(f"Camera {self.camera_id}: Max retries reached after error")
self._set_connection_state(False, "Max retries reached after error")
break
# Exponential backoff for exceptions too
backoff_delay = self._calculate_backoff_delay()
logger.debug(f"Camera {self.camera_id}: Exception backoff for {backoff_delay:.1f}s")
if self.stop_event.wait(backoff_delay):
break # Exit if stop event set during backoff
except Exception as e:
logger.error(f"Error in snapshot reader for camera {self.camera_id}: {str(e)}", exc_info=True)
finally:
logger.info(f"Snapshot reader for camera {self.camera_id} is exiting")
def create_frame_reader_thread(camera_id: str,
rtsp_url: Optional[str] = None,
snapshot_url: Optional[str] = None,
snapshot_interval: Optional[int] = None,
buffer: Optional[queue.Queue] = None,
stop_event: Optional[threading.Event] = None,
connection_callback=None) -> Optional[threading.Thread]:
"""
Create appropriate frame reader thread based on stream type.
Args:
camera_id: Unique camera identifier
rtsp_url: RTSP stream URL (for RTSP streams)
snapshot_url: HTTP snapshot URL (for snapshot streams)
snapshot_interval: Snapshot interval in milliseconds
buffer: Frame buffer queue
stop_event: Thread stop event
connection_callback: Connection state callback
Returns:
Configured thread ready to start, or None if invalid parameters
"""
if not buffer:
buffer = queue.Queue(maxsize=SHARED_STREAM_BUFFER_SIZE)
if not stop_event:
stop_event = threading.Event()
if snapshot_url and snapshot_interval:
# Create snapshot reader
reader = SnapshotFrameReader(
camera_id=camera_id,
snapshot_url=snapshot_url,
snapshot_interval=snapshot_interval,
buffer=buffer,
stop_event=stop_event,
connection_callback=connection_callback
)
thread = threading.Thread(target=reader.run, name=f"snapshot-{camera_id}")
elif rtsp_url:
# Create RTSP reader
reader = RTSPFrameReader(
camera_id=camera_id,
rtsp_url=rtsp_url,
buffer=buffer,
stop_event=stop_event,
connection_callback=connection_callback
)
thread = threading.Thread(target=reader.run, name=f"rtsp-{camera_id}")
else:
logger.error(f"No valid URL provided for camera {camera_id}")
return None
thread.daemon = True
return thread

View file

@ -0,0 +1,572 @@
"""
Stream lifecycle management and coordination.
This module provides centralized management of camera streams including
lifecycle management, resource allocation, and stream coordination.
"""
import time
import queue
import logging
import threading
from typing import Dict, List, Any, Optional, Tuple, Set
from dataclasses import dataclass, field
from datetime import datetime
from ..core.constants import (
DEFAULT_MAX_STREAMS,
SHARED_STREAM_BUFFER_SIZE,
DEFAULT_RECONNECT_INTERVAL_SEC,
DEFAULT_MAX_RETRIES
)
from ..core.exceptions import StreamError, create_stream_error
from ..streams.frame_reader import create_frame_reader_thread
from ..streams.camera_monitor import set_camera_connected
logger = logging.getLogger(__name__)
@dataclass
class StreamInfo:
"""Information about a single camera stream."""
camera_id: str
stream_url: str
stream_type: str # "rtsp" or "snapshot"
snapshot_interval: Optional[int] = None
buffer: Optional[queue.Queue] = None
stop_event: Optional[threading.Event] = None
thread: Optional[threading.Thread] = None
subscribers: Set[str] = field(default_factory=set)
created_at: float = field(default_factory=time.time)
last_frame_time: Optional[float] = None
frame_count: int = 0
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"camera_id": self.camera_id,
"stream_url": self.stream_url,
"stream_type": self.stream_type,
"snapshot_interval": self.snapshot_interval,
"subscriber_count": len(self.subscribers),
"subscribers": list(self.subscribers),
"created_at": self.created_at,
"last_frame_time": self.last_frame_time,
"frame_count": self.frame_count,
"is_active": self.thread is not None and self.thread.is_alive()
}
@dataclass
class StreamSubscription:
"""Information about a stream subscription."""
subscription_id: str
camera_id: str
subscriber_id: str
created_at: float = field(default_factory=time.time)
last_access: float = field(default_factory=time.time)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"subscription_id": self.subscription_id,
"camera_id": self.camera_id,
"subscriber_id": self.subscriber_id,
"created_at": self.created_at,
"last_access": self.last_access
}
class StreamManager:
"""
Manages camera stream lifecycle and resource allocation.
This class provides centralized management of camera streams including:
- Stream lifecycle management (start/stop/restart)
- Resource allocation and sharing
- Subscriber management
- Connection state monitoring
"""
def __init__(self, max_streams: int = DEFAULT_MAX_STREAMS):
"""
Initialize stream manager.
Args:
max_streams: Maximum number of concurrent streams
"""
self.max_streams = max_streams
self._streams: Dict[str, StreamInfo] = {}
self._subscriptions: Dict[str, StreamSubscription] = {}
self._lock = None
def _ensure_thread_safety(self):
"""Initialize thread safety if not already done."""
if self._lock is None:
import threading
self._lock = threading.RLock()
def _connection_state_callback(self, camera_id: str, connected: bool, error_msg: Optional[str] = None):
"""Callback for connection state changes."""
set_camera_connected(camera_id, connected, error_msg)
def _create_stream_info(self,
camera_id: str,
rtsp_url: Optional[str] = None,
snapshot_url: Optional[str] = None,
snapshot_interval: Optional[int] = None) -> StreamInfo:
"""Create StreamInfo object based on stream type."""
if snapshot_url and snapshot_interval:
return StreamInfo(
camera_id=camera_id,
stream_url=snapshot_url,
stream_type="snapshot",
snapshot_interval=snapshot_interval,
buffer=queue.Queue(maxsize=SHARED_STREAM_BUFFER_SIZE),
stop_event=threading.Event()
)
elif rtsp_url:
return StreamInfo(
camera_id=camera_id,
stream_url=rtsp_url,
stream_type="rtsp",
buffer=queue.Queue(maxsize=SHARED_STREAM_BUFFER_SIZE),
stop_event=threading.Event()
)
else:
raise ValueError("Must provide either RTSP URL or snapshot URL with interval")
def create_subscription(self,
subscription_id: str,
camera_id: str,
subscriber_id: str,
rtsp_url: Optional[str] = None,
snapshot_url: Optional[str] = None,
snapshot_interval: Optional[int] = None) -> bool:
"""
Create a stream subscription.
Args:
subscription_id: Unique subscription identifier
camera_id: Camera identifier
subscriber_id: Subscriber identifier
rtsp_url: RTSP stream URL (for RTSP streams)
snapshot_url: HTTP snapshot URL (for snapshot streams)
snapshot_interval: Snapshot interval in milliseconds
Returns:
True if subscription was created successfully
"""
self._ensure_thread_safety()
with self._lock:
try:
# Check if subscription already exists
if subscription_id in self._subscriptions:
logger.warning(f"Subscription {subscription_id} already exists")
return False
# Check stream limit
if len(self._streams) >= self.max_streams and camera_id not in self._streams:
logger.error(f"Maximum streams ({self.max_streams}) reached, cannot create new stream for camera {camera_id}")
return False
# Create or get existing stream
if camera_id not in self._streams:
stream_info = self._create_stream_info(
camera_id, rtsp_url, snapshot_url, snapshot_interval
)
self._streams[camera_id] = stream_info
# Create and start frame reader thread
thread = create_frame_reader_thread(
camera_id=camera_id,
rtsp_url=rtsp_url,
snapshot_url=snapshot_url,
snapshot_interval=snapshot_interval,
buffer=stream_info.buffer,
stop_event=stream_info.stop_event,
connection_callback=self._connection_state_callback
)
if thread:
stream_info.thread = thread
thread.start()
logger.info(f"Created new {stream_info.stream_type} stream for camera {camera_id}")
else:
# Clean up failed stream
del self._streams[camera_id]
return False
# Add subscriber to stream
stream_info = self._streams[camera_id]
stream_info.subscribers.add(subscription_id)
# Create subscription record
subscription = StreamSubscription(
subscription_id=subscription_id,
camera_id=camera_id,
subscriber_id=subscriber_id
)
self._subscriptions[subscription_id] = subscription
logger.info(f"Created subscription {subscription_id} for camera {camera_id}, subscribers: {len(stream_info.subscribers)}")
return True
except Exception as e:
logger.error(f"Error creating subscription {subscription_id}: {e}")
return False
def remove_subscription(self, subscription_id: str) -> bool:
"""
Remove a stream subscription.
Args:
subscription_id: Unique subscription identifier
Returns:
True if subscription was removed successfully
"""
self._ensure_thread_safety()
with self._lock:
if subscription_id not in self._subscriptions:
logger.warning(f"Subscription {subscription_id} not found")
return False
subscription = self._subscriptions[subscription_id]
camera_id = subscription.camera_id
# Remove subscription
del self._subscriptions[subscription_id]
# Remove subscriber from stream if stream exists
if camera_id in self._streams:
stream_info = self._streams[camera_id]
stream_info.subscribers.discard(subscription_id)
logger.info(f"Removed subscription {subscription_id} for camera {camera_id}, remaining subscribers: {len(stream_info.subscribers)}")
# Stop stream if no more subscribers
if not stream_info.subscribers:
self._stop_stream(camera_id)
return True
def _stop_stream(self, camera_id: str) -> None:
"""Stop a stream and clean up resources."""
if camera_id not in self._streams:
return
stream_info = self._streams[camera_id]
# Signal thread to stop
if stream_info.stop_event:
stream_info.stop_event.set()
# Wait for thread to finish
if stream_info.thread and stream_info.thread.is_alive():
stream_info.thread.join(timeout=5)
if stream_info.thread.is_alive():
logger.warning(f"Stream thread for camera {camera_id} did not stop gracefully")
# Clean up
del self._streams[camera_id]
logger.info(f"Stopped {stream_info.stream_type} stream for camera {camera_id}")
def get_frame(self, subscription_id: str, timeout: float = 0.1) -> Optional[Any]:
"""
Get the latest frame for a subscription.
Args:
subscription_id: Unique subscription identifier
timeout: Timeout for frame retrieval in seconds
Returns:
Latest frame or None if not available
"""
self._ensure_thread_safety()
with self._lock:
if subscription_id not in self._subscriptions:
return None
subscription = self._subscriptions[subscription_id]
camera_id = subscription.camera_id
if camera_id not in self._streams:
return None
stream_info = self._streams[camera_id]
subscription.last_access = time.time()
try:
frame = stream_info.buffer.get(timeout=timeout)
stream_info.last_frame_time = time.time()
stream_info.frame_count += 1
return frame
except queue.Empty:
return None
except Exception as e:
logger.error(f"Error getting frame for subscription {subscription_id}: {e}")
return None
def is_stream_active(self, camera_id: str) -> bool:
"""
Check if a stream is active.
Args:
camera_id: Camera identifier
Returns:
True if stream is active
"""
self._ensure_thread_safety()
with self._lock:
if camera_id not in self._streams:
return False
stream_info = self._streams[camera_id]
return stream_info.thread is not None and stream_info.thread.is_alive()
def get_stream_stats(self, camera_id: str) -> Optional[Dict[str, Any]]:
"""
Get statistics for a stream.
Args:
camera_id: Camera identifier
Returns:
Stream statistics or None if stream not found
"""
self._ensure_thread_safety()
with self._lock:
if camera_id not in self._streams:
return None
stream_info = self._streams[camera_id]
current_time = time.time()
stats = stream_info.to_dict()
stats["uptime_seconds"] = current_time - stream_info.created_at
if stream_info.last_frame_time:
stats["seconds_since_last_frame"] = current_time - stream_info.last_frame_time
return stats
def get_subscription_info(self, subscription_id: str) -> Optional[Dict[str, Any]]:
"""
Get information about a subscription.
Args:
subscription_id: Unique subscription identifier
Returns:
Subscription information or None if not found
"""
self._ensure_thread_safety()
with self._lock:
if subscription_id not in self._subscriptions:
return None
return self._subscriptions[subscription_id].to_dict()
def get_all_streams(self) -> Dict[str, Dict[str, Any]]:
"""
Get information about all active streams.
Returns:
Dictionary mapping camera IDs to stream information
"""
self._ensure_thread_safety()
with self._lock:
return {
camera_id: stream_info.to_dict()
for camera_id, stream_info in self._streams.items()
}
def get_all_subscriptions(self) -> Dict[str, Dict[str, Any]]:
"""
Get information about all active subscriptions.
Returns:
Dictionary mapping subscription IDs to subscription information
"""
self._ensure_thread_safety()
with self._lock:
return {
sub_id: subscription.to_dict()
for sub_id, subscription in self._subscriptions.items()
}
def cleanup_inactive_streams(self, inactive_threshold_seconds: int = 3600) -> int:
"""
Clean up streams that have been inactive for too long.
Args:
inactive_threshold_seconds: Seconds of inactivity before cleanup
Returns:
Number of streams cleaned up
"""
self._ensure_thread_safety()
current_time = time.time()
cleanup_count = 0
with self._lock:
streams_to_remove = []
for camera_id, stream_info in self._streams.items():
# Check if stream has subscribers
if stream_info.subscribers:
continue
# Check if stream has been inactive
last_activity = max(
stream_info.created_at,
stream_info.last_frame_time or 0
)
if current_time - last_activity > inactive_threshold_seconds:
streams_to_remove.append(camera_id)
for camera_id in streams_to_remove:
self._stop_stream(camera_id)
cleanup_count += 1
logger.info(f"Cleaned up inactive stream for camera {camera_id}")
if cleanup_count > 0:
logger.info(f"Cleaned up {cleanup_count} inactive streams")
return cleanup_count
def restart_stream(self, camera_id: str) -> bool:
"""
Restart a stream.
Args:
camera_id: Camera identifier
Returns:
True if stream was restarted successfully
"""
self._ensure_thread_safety()
with self._lock:
if camera_id not in self._streams:
logger.warning(f"Cannot restart stream for camera {camera_id}: stream not found")
return False
stream_info = self._streams[camera_id]
subscribers = stream_info.subscribers.copy()
stream_url = stream_info.stream_url
stream_type = stream_info.stream_type
snapshot_interval = stream_info.snapshot_interval
# Stop current stream
self._stop_stream(camera_id)
# Recreate stream
try:
new_stream_info = self._create_stream_info(
camera_id,
rtsp_url=stream_url if stream_type == "rtsp" else None,
snapshot_url=stream_url if stream_type == "snapshot" else None,
snapshot_interval=snapshot_interval
)
new_stream_info.subscribers = subscribers
self._streams[camera_id] = new_stream_info
# Create and start new frame reader thread
thread = create_frame_reader_thread(
camera_id=camera_id,
rtsp_url=stream_url if stream_type == "rtsp" else None,
snapshot_url=stream_url if stream_type == "snapshot" else None,
snapshot_interval=snapshot_interval,
buffer=new_stream_info.buffer,
stop_event=new_stream_info.stop_event,
connection_callback=self._connection_state_callback
)
if thread:
new_stream_info.thread = thread
thread.start()
logger.info(f"Restarted {stream_type} stream for camera {camera_id}")
return True
else:
# Clean up failed restart
del self._streams[camera_id]
return False
except Exception as e:
logger.error(f"Error restarting stream for camera {camera_id}: {e}")
return False
def shutdown_all(self) -> None:
"""Shutdown all streams and clean up resources."""
self._ensure_thread_safety()
with self._lock:
logger.info("Shutting down all streams...")
# Stop all streams
camera_ids = list(self._streams.keys())
for camera_id in camera_ids:
self._stop_stream(camera_id)
# Clear all subscriptions
self._subscriptions.clear()
logger.info("All streams shut down successfully")
# Global stream manager instance
stream_manager = StreamManager()
# ===== CONVENIENCE FUNCTIONS =====
# These provide a simplified interface for common operations
def create_stream_subscription(subscription_id: str,
camera_id: str,
subscriber_id: str,
rtsp_url: Optional[str] = None,
snapshot_url: Optional[str] = None,
snapshot_interval: Optional[int] = None) -> bool:
"""Create a stream subscription using global stream manager."""
return stream_manager.create_subscription(
subscription_id, camera_id, subscriber_id, rtsp_url, snapshot_url, snapshot_interval
)
def remove_stream_subscription(subscription_id: str) -> bool:
"""Remove a stream subscription using global stream manager."""
return stream_manager.remove_subscription(subscription_id)
def get_stream_frame(subscription_id: str, timeout: float = 0.1) -> Optional[Any]:
"""Get the latest frame for a subscription using global stream manager."""
return stream_manager.get_frame(subscription_id, timeout)
def is_stream_active(camera_id: str) -> bool:
"""Check if a stream is active using global stream manager."""
return stream_manager.is_stream_active(camera_id)
def get_stream_statistics() -> Dict[str, Any]:
"""Get comprehensive stream statistics."""
return {
"streams": stream_manager.get_all_streams(),
"subscriptions": stream_manager.get_all_subscriptions(),
"total_streams": len(stream_manager._streams),
"total_subscriptions": len(stream_manager._subscriptions),
"max_streams": stream_manager.max_streams
}