Refactor: PHASE 2: Core Module Extraction
This commit is contained in: parent 96bedae80a, commit 4e9ae6bcc4
7 changed files with 3684 additions and 0 deletions
483 detector_worker/detection/stability_validator.py (new file)
@@ -0,0 +1,483 @@
"""
Detection stability validation and lightweight detection functionality.

This module provides validation functionality for detection pipelines including
pipeline execution pre-validation and lightweight detection for validation phases.
"""

import logging
from typing import Dict, List, Any, Optional, Tuple, Union
from dataclasses import dataclass, field
import numpy as np

from ..core.constants import (
    LIGHTWEIGHT_DETECTION_MIN_CONFIDENCE,
    LIGHTWEIGHT_DETECTION_MIN_BBOX_AREA_RATIO,
    DEFAULT_MIN_CONFIDENCE
)
from ..core.exceptions import ValidationError, create_detection_error
from ..detection.detection_result import LightweightDetectionResult, BoundingBox

logger = logging.getLogger(__name__)


@dataclass
class PipelineValidationResult:
    """Result of pipeline execution validation."""
    is_valid: bool
    missing_branches: List[str] = field(default_factory=list)
    required_branches: List[str] = field(default_factory=list)
    available_classes: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format."""
        return {
            "is_valid": self.is_valid,
            "missing_branches": self.missing_branches.copy(),
            "required_branches": self.required_branches.copy(),
            "available_classes": self.available_classes.copy()
        }


@dataclass
class LightweightDetectionResult:
    """Result from lightweight detection."""
    car_detected: bool
    best_detection: Optional[Dict[str, Any]] = None
    validation_passed: bool = False
    confidence: Optional[float] = None
    bbox_area_ratio: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format."""
        result = {
            "car_detected": self.car_detected,
            "validation_passed": self.validation_passed
        }

        if self.best_detection:
            result["best_detection"] = self.best_detection.copy()
        if self.confidence is not None:
            result["confidence"] = self.confidence
        if self.bbox_area_ratio is not None:
            result["bbox_area_ratio"] = self.bbox_area_ratio

        return result


class StabilityValidator:
    """
    Validates detection stability and pipeline execution requirements.

    This class provides functionality for:
    - Pipeline execution pre-validation
    - Lightweight detection for validation phases
    - Detection threshold validation
    """

    def __init__(self):
        """Initialize stability validator."""
        pass

    def validate_pipeline_execution(self,
                                    node: Dict[str, Any],
                                    regions_dict: Dict[str, Any]) -> PipelineValidationResult:
        """
        Pre-validate that all required branches will execute successfully before
        committing to Redis actions and database records.

        Args:
            node: Pipeline node configuration
            regions_dict: Dictionary mapping class names to detection regions

        Returns:
            PipelineValidationResult with validation status and details
        """
        # Get all branches that parallel actions are waiting for
        required_branches = set()

        for action in node.get("parallelActions", []):
            if action.get("type") == "postgresql_update_combined":
                wait_for_branches = action.get("waitForBranches", [])
                required_branches.update(wait_for_branches)

        if not required_branches:
            # No parallel actions requiring specific branches
            logger.debug("No parallel actions with waitForBranches - validation passes")
            return PipelineValidationResult(
                is_valid=True,
                required_branches=[],
                available_classes=list(regions_dict.keys())
            )

        logger.debug(f"Pre-validation: checking if required branches {list(required_branches)} will execute")

        # Check each required branch
        missing_branches = []

        for branch in node.get("branches", []):
            branch_id = branch["modelId"]

            if branch_id not in required_branches:
                continue  # This branch is not required by parallel actions

            # Check if this branch would be triggered
            trigger_classes = branch.get("triggerClasses", [])
            min_conf = branch.get("minConfidence", 0)

            branch_triggered = False
            for det_class in regions_dict:
                det_confidence = regions_dict[det_class]["confidence"]

                if (det_class in trigger_classes and det_confidence >= min_conf):
                    branch_triggered = True
                    logger.debug(f"Pre-validation: branch {branch_id} WILL be triggered by {det_class} (conf={det_confidence:.3f} >= {min_conf})")
                    break

            if not branch_triggered:
                missing_branches.append(branch_id)
                logger.warning(f"Pre-validation: branch {branch_id} will NOT be triggered - no matching classes or insufficient confidence")
                logger.debug(f"  Required: {trigger_classes} with min_conf={min_conf}")
                logger.debug(f"  Available: {[(cls, regions_dict[cls]['confidence']) for cls in regions_dict]}")

        is_valid = len(missing_branches) == 0

        if missing_branches:
            logger.error(f"Pipeline pre-validation FAILED: required branches {missing_branches} will not execute")
        else:
            logger.info(f"Pipeline pre-validation PASSED: all required branches {list(required_branches)} will execute")

        return PipelineValidationResult(
            is_valid=is_valid,
            missing_branches=missing_branches,
            required_branches=list(required_branches),
            available_classes=list(regions_dict.keys())
        )

    def _apply_detection_filters(self,
                                 box: Any,
                                 model: Any,
                                 min_confidence: float,
                                 trigger_classes: Optional[List[str]] = None,
                                 trigger_class_indices: Optional[List[int]] = None) -> Optional[Dict[str, Any]]:
        """Apply confidence and class filtering to a detection box."""
        try:
            # Extract detection info
            xyxy = box.xyxy[0].cpu().numpy()
            conf = box.conf[0].cpu().numpy()
            cls_id = int(box.cls[0].cpu().numpy())
            class_name = model.names[cls_id]

            # Apply confidence threshold
            if conf < min_confidence:
                return None

            # Apply trigger class filtering if specified
            if trigger_class_indices and cls_id not in trigger_class_indices:
                return None
            if trigger_classes and class_name not in trigger_classes:
                return None

            return {
                "class": class_name,
                "confidence": float(conf),
                "bbox": [int(x) for x in xyxy],
                "class_id": cls_id,
                "xyxy": xyxy
            }

        except Exception as e:
            logger.error(f"Error processing detection box: {e}")
            return None

    def run_lightweight_detection_with_validation(self,
                                                  frame: np.ndarray,
                                                  node: Dict[str, Any],
                                                  min_confidence: float = LIGHTWEIGHT_DETECTION_MIN_CONFIDENCE,
                                                  min_bbox_area_ratio: float = LIGHTWEIGHT_DETECTION_MIN_BBOX_AREA_RATIO) -> Dict[str, Any]:
        """
        Run lightweight detection with validation rules for session ID triggering.
        Returns detection info only if it passes validation thresholds.

        Args:
            frame: Input frame for detection
            node: Pipeline node configuration
            min_confidence: Minimum confidence threshold for validation
            min_bbox_area_ratio: Minimum bounding box area ratio for validation

        Returns:
            Dictionary with validation results and detection info
        """
        logger.debug(f"Running lightweight detection with validation: {node['modelId']} (conf>={min_confidence}, bbox_area>={min_bbox_area_ratio})")

        try:
            # Run basic detection only - no branches, no actions
            model = node["model"]
            trigger_classes = node.get("triggerClasses", [])
            trigger_class_indices = node.get("triggerClassIndices")

            # Run YOLO inference
            res = model(frame, verbose=False)

            best_detection = None
            frame_height, frame_width = frame.shape[:2]
            frame_area = frame_height * frame_width

            for r in res:
                boxes = r.boxes
                if boxes is None or len(boxes) == 0:
                    continue

                for box in boxes:
                    detection = self._apply_detection_filters(
                        box, model, min_confidence, trigger_classes, trigger_class_indices
                    )

                    if not detection:
                        continue

                    # Calculate bbox area ratio
                    x1, y1, x2, y2 = detection["xyxy"]
                    bbox_area = (x2 - x1) * (y2 - y1)
                    bbox_area_ratio = bbox_area / frame_area if frame_area > 0 else 0

                    # Apply bbox area threshold
                    if bbox_area_ratio < min_bbox_area_ratio:
                        logger.debug(f"Detection filtered out: bbox_area_ratio={bbox_area_ratio:.3f} < {min_bbox_area_ratio}")
                        continue

                    # Validation passed
                    detection["bbox_area_ratio"] = float(bbox_area_ratio)
                    detection["validation_passed"] = True

                    if not best_detection or detection["confidence"] > best_detection["confidence"]:
                        best_detection = detection

            if best_detection:
                logger.debug(f"Validation PASSED: {best_detection['class']} (conf: {best_detection['confidence']:.3f}, area: {best_detection['bbox_area_ratio']:.3f})")
                return best_detection
            else:
                logger.debug(f"Validation FAILED: No detection meets criteria (conf>={min_confidence}, area>={min_bbox_area_ratio})")
                return {"validation_passed": False}

        except Exception as e:
            logger.error(f"Error in lightweight detection with validation: {str(e)}", exc_info=True)
            return {"validation_passed": False}

    def run_lightweight_detection(self,
                                  frame: np.ndarray,
                                  node: Dict[str, Any]) -> LightweightDetectionResult:
        """
        Run lightweight detection for car presence validation only.
        Returns basic detection info without running branches or external actions.

        Args:
            frame: Input frame for detection
            node: Pipeline node configuration

        Returns:
            LightweightDetectionResult with car detection status
        """
        logger.debug(f"Running lightweight detection: {node['modelId']}")

        try:
            # Run basic detection only - no branches, no actions
            model = node["model"]
            min_confidence = node.get("minConfidence", DEFAULT_MIN_CONFIDENCE)
            trigger_classes = node.get("triggerClasses", [])
            trigger_class_indices = node.get("triggerClassIndices")

            # Run YOLO inference
            res = model(frame, verbose=False)

            car_detected = False
            best_detection = None

            for r in res:
                boxes = r.boxes
                if boxes is None or len(boxes) == 0:
                    continue

                for box in boxes:
                    detection = self._apply_detection_filters(
                        box, model, min_confidence, trigger_classes, trigger_class_indices
                    )

                    if not detection:
                        continue

                    # Car detected
                    car_detected = True
                    if not best_detection or detection["confidence"] > best_detection["confidence"]:
                        best_detection = detection

            logger.debug(f"Lightweight detection result: car_detected={car_detected}")
            if best_detection:
                logger.debug(f"Best detection: {best_detection['class']} (conf: {best_detection['confidence']:.3f})")

            return LightweightDetectionResult(
                car_detected=car_detected,
                best_detection=best_detection,
                confidence=best_detection["confidence"] if best_detection else None
            )

        except Exception as e:
            logger.error(f"Error in lightweight detection: {str(e)}", exc_info=True)
            return LightweightDetectionResult(car_detected=False, best_detection=None)

    def validate_detection_thresholds(self,
                                      detection: Dict[str, Any],
                                      frame_shape: Tuple[int, int, int],
                                      min_confidence: float,
                                      min_bbox_area_ratio: float) -> bool:
        """
        Validate detection against confidence and bbox area thresholds.

        Args:
            detection: Detection dictionary with confidence and bbox
            frame_shape: Frame dimensions (height, width, channels)
            min_confidence: Minimum confidence threshold
            min_bbox_area_ratio: Minimum bbox area ratio threshold

        Returns:
            True if detection passes all validation thresholds
        """
        try:
            # Check confidence
            confidence = detection.get("confidence", 0.0)
            if confidence < min_confidence:
                logger.debug(f"Detection failed confidence threshold: {confidence:.3f} < {min_confidence}")
                return False

            # Check bbox area ratio
            bbox = detection.get("bbox", [])
            if len(bbox) >= 4:
                x1, y1, x2, y2 = bbox[:4]
                bbox_area = (x2 - x1) * (y2 - y1)

                frame_height, frame_width = frame_shape[:2]
                frame_area = frame_height * frame_width
                bbox_area_ratio = bbox_area / frame_area if frame_area > 0 else 0

                if bbox_area_ratio < min_bbox_area_ratio:
                    logger.debug(f"Detection failed bbox area threshold: {bbox_area_ratio:.3f} < {min_bbox_area_ratio}")
                    return False

            logger.debug(f"Detection passed validation thresholds: conf={confidence:.3f}, area_ratio={bbox_area_ratio:.3f}")
            return True

        except Exception as e:
            logger.error(f"Error validating detection thresholds: {e}")
            return False

    def check_branch_execution_requirements(self,
                                            branch: Dict[str, Any],
                                            regions_dict: Dict[str, Any]) -> bool:
        """
        Check if a branch will be executed based on trigger classes and confidence.

        Args:
            branch: Branch configuration
            regions_dict: Available detection regions

        Returns:
            True if branch will be executed
        """
        trigger_classes = branch.get("triggerClasses", [])
        min_conf = branch.get("minConfidence", 0)

        for det_class in regions_dict:
            det_confidence = regions_dict[det_class]["confidence"]

            if det_class in trigger_classes and det_confidence >= min_conf:
                return True

        return False

    def get_validation_summary(self,
                               node: Dict[str, Any],
                               regions_dict: Dict[str, Any]) -> Dict[str, Any]:
        """
        Get comprehensive validation summary for a pipeline node.

        Args:
            node: Pipeline node configuration
            regions_dict: Available detection regions

        Returns:
            Dictionary with validation summary information
        """
        summary = {
            "model_id": node.get("modelId", "unknown"),
            "available_classes": list(regions_dict.keys()),
            "branch_count": len(node.get("branches", [])),
            "parallel_action_count": len(node.get("parallelActions", [])),
            "branches_that_will_execute": [],
            "branches_that_will_not_execute": []
        }

        # Check each branch
        for branch in node.get("branches", []):
            branch_id = branch["modelId"]
            will_execute = self.check_branch_execution_requirements(branch, regions_dict)

            if will_execute:
                summary["branches_that_will_execute"].append(branch_id)
            else:
                summary["branches_that_will_not_execute"].append(branch_id)

        # Pipeline validation
        pipeline_validation = self.validate_pipeline_execution(node, regions_dict)
        summary["pipeline_validation"] = pipeline_validation.to_dict()

        return summary


# Global stability validator instance
stability_validator = StabilityValidator()


# ===== CONVENIENCE FUNCTIONS =====
# These provide the same interface as the original functions in pympta.py

def validate_pipeline_execution(node: Dict[str, Any], regions_dict: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """
    Pre-validate that all required branches will execute successfully.

    Returns:
        - (True, []) if pipeline can execute completely
        - (False, missing_branches) if some required branches won't execute
    """
    result = stability_validator.validate_pipeline_execution(node, regions_dict)
    return result.is_valid, result.missing_branches


def run_lightweight_detection_with_validation(frame: np.ndarray,
                                              node: Dict[str, Any],
                                              min_confidence: float = LIGHTWEIGHT_DETECTION_MIN_CONFIDENCE,
                                              min_bbox_area_ratio: float = LIGHTWEIGHT_DETECTION_MIN_BBOX_AREA_RATIO) -> Dict[str, Any]:
    """Run lightweight detection with validation rules for session ID triggering."""
    return stability_validator.run_lightweight_detection_with_validation(
        frame, node, min_confidence, min_bbox_area_ratio
    )


def run_lightweight_detection(frame: np.ndarray, node: Dict[str, Any]) -> Dict[str, Any]:
    """Run lightweight detection for car presence validation only."""
    result = stability_validator.run_lightweight_detection(frame, node)
    return result.to_dict()


def validate_detection_thresholds(detection: Dict[str, Any],
                                  frame_shape: Tuple[int, int, int],
                                  min_confidence: float,
                                  min_bbox_area_ratio: float) -> bool:
    """Validate detection against confidence and bbox area thresholds."""
    return stability_validator.validate_detection_thresholds(
        detection, frame_shape, min_confidence, min_bbox_area_ratio
    )


def get_validation_summary(node: Dict[str, Any], regions_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Get comprehensive validation summary for a pipeline node."""
    return stability_validator.get_validation_summary(node, regions_dict)
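For reference, a minimal usage sketch of this module's convenience functions (illustrative only; `pipeline_node`, `regions` and `frame` are hypothetical stand-ins for a real pipeline node configuration, its detected regions, and a captured frame):

    is_valid, missing = validate_pipeline_execution(pipeline_node, regions)
    if is_valid:
        result = run_lightweight_detection_with_validation(frame, pipeline_node)
        if result.get("validation_passed"):
            pass  # safe to trigger session ID / downstream actions
    else:
        print(f"Branches that will not execute: {missing}")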
481 detector_worker/detection/tracking_manager.py (new file)
@@ -0,0 +1,481 @@
"""
Object tracking management and state handling.

This module provides tracking state management, session handling,
and occupancy detection functionality for the detection pipeline.
"""

import time
import logging
from typing import Dict, List, Any, Optional, Tuple, Set
from dataclasses import dataclass, field
from datetime import datetime

from ..core.constants import SESSION_TIMEOUT_SECONDS
from ..core.exceptions import TrackingError, create_detection_error

logger = logging.getLogger(__name__)


@dataclass
class SessionState:
    """Session state for tracking management."""
    active: bool = True
    waiting_for_backend_session: bool = False
    wait_start_time: float = 0.0
    reset_tracker_on_resume: bool = False
    occupancy_mode: bool = False
    occupancy_enabled_at: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format."""
        return {
            "active": self.active,
            "waiting_for_backend_session": self.waiting_for_backend_session,
            "wait_start_time": self.wait_start_time,
            "reset_tracker_on_resume": self.reset_tracker_on_resume,
            "occupancy_mode": self.occupancy_mode,
            "occupancy_enabled_at": self.occupancy_enabled_at
        }

    def from_dict(self, data: Dict[str, Any]) -> None:
        """Update from dictionary data."""
        self.active = data.get("active", True)
        self.waiting_for_backend_session = data.get("waiting_for_backend_session", False)
        self.wait_start_time = data.get("wait_start_time", 0.0)
        self.reset_tracker_on_resume = data.get("reset_tracker_on_resume", False)
        self.occupancy_mode = data.get("occupancy_mode", False)
        self.occupancy_enabled_at = data.get("occupancy_enabled_at")


@dataclass
class TrackingData:
    """Complete tracking data for a camera and model."""
    track_stability_counters: Dict[int, int] = field(default_factory=dict)
    stable_tracks: Set[int] = field(default_factory=set)
    session_state: SessionState = field(default_factory=SessionState)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format."""
        return {
            "track_stability_counters": dict(self.track_stability_counters),
            "stable_tracks": list(self.stable_tracks),
            "session_state": self.session_state.to_dict()
        }


class TrackingManager:
    """
    Manages object tracking state and session handling across cameras and models.

    This class provides centralized tracking state management including:
    - Track stability counters
    - Session state management
    - Backend session timeout handling
    - Occupancy detection mode
    """

    def __init__(self, session_timeout_seconds: int = SESSION_TIMEOUT_SECONDS):
        """
        Initialize tracking manager.

        Args:
            session_timeout_seconds: Timeout for backend session waiting
        """
        self.session_timeout_seconds = session_timeout_seconds
        self._camera_tracking_data: Dict[str, Dict[str, TrackingData]] = {}
        self._lock = None

    def _ensure_thread_safety(self):
        """Initialize thread safety if not already done."""
        if self._lock is None:
            import threading
            self._lock = threading.RLock()

    def get_camera_tracking_data(self, camera_id: str, model_id: str) -> TrackingData:
        """
        Get or create tracking data for a specific camera and model.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier

        Returns:
            Tracking data for the camera and model
        """
        self._ensure_thread_safety()

        with self._lock:
            if camera_id not in self._camera_tracking_data:
                self._camera_tracking_data[camera_id] = {}

            if model_id not in self._camera_tracking_data[camera_id]:
                logger.warning(f"🔄 Camera {camera_id}: Creating NEW tracking data for {model_id} - this will reset any cooldown!")
                self._camera_tracking_data[camera_id][model_id] = TrackingData()

            return self._camera_tracking_data[camera_id][model_id]

    def check_stable_tracks(self,
                            camera_id: str,
                            model_id: str,
                            regions_dict: Dict[str, Any]) -> Tuple[bool, List[Dict[str, Any]]]:
        """
        Check if any stable tracks match the detected classes for a specific camera.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier
            regions_dict: Dictionary mapping class names to detection regions

        Returns:
            Tuple of (has_stable_tracks, stable_detections)
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            stable_tracks = tracking_data.stable_tracks

            if not stable_tracks:
                return False, []

            # Check for track-based stability
            stable_detections = []

            for class_name, region_data in regions_dict.items():
                detection = region_data.get("detection", {})
                track_id = detection.get("id")

                if track_id in stable_tracks:
                    stable_detections.append(detection)

            has_stable = len(stable_detections) > 0
            logger.debug(f"Camera {camera_id}: Stable track check - stable_tracks: {list(stable_tracks)}, found: {has_stable}")

            return has_stable, stable_detections

    def reset_tracking_state(self,
                             camera_id: str,
                             model_id: str,
                             reason: str = "session ended") -> None:
        """
        Reset tracking state after session completion or timeout.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier
            reason: Reason for reset (for logging)
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            session_state = tracking_data.session_state

            # Clear all tracking data for fresh start
            tracking_data.track_stability_counters.clear()
            tracking_data.stable_tracks.clear()
            session_state.active = True
            session_state.waiting_for_backend_session = False
            session_state.wait_start_time = 0.0
            session_state.reset_tracker_on_resume = True

            logger.info(f"Camera {camera_id}: 🔄 Reset tracking state - {reason}")
            logger.info(f"Camera {camera_id}: 🧹 Cleared stability counters and stable tracks for fresh session")

    def is_camera_active(self, camera_id: str, model_id: str) -> bool:
        """
        Check if camera should be processing detections.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier

        Returns:
            True if camera should process detections, False otherwise
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            session_state = tracking_data.session_state

            # Check if waiting for backend sessionId has timed out
            if session_state.waiting_for_backend_session:
                current_time = time.time()
                wait_start_time = session_state.wait_start_time
                elapsed_time = current_time - wait_start_time

                if elapsed_time >= self.session_timeout_seconds:
                    logger.warning(f"Camera {camera_id}: Backend sessionId timeout ({self.session_timeout_seconds}s) - resetting tracking")
                    self.reset_tracking_state(camera_id, model_id, "backend sessionId timeout")
                    return True
                else:
                    remaining_time = self.session_timeout_seconds - elapsed_time
                    logger.debug(f"Camera {camera_id}: Still waiting for backend sessionId ({remaining_time:.1f}s remaining)")
                    return False

            return session_state.active

    def set_waiting_for_backend_session(self,
                                        camera_id: str,
                                        model_id: str,
                                        waiting: bool = True) -> None:
        """
        Set whether camera is waiting for backend session ID.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier
            waiting: Whether to wait for backend session
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            session_state = tracking_data.session_state

            session_state.waiting_for_backend_session = waiting
            if waiting:
                session_state.wait_start_time = time.time()
                logger.debug(f"Camera {camera_id}: Started waiting for backend sessionId")
            else:
                session_state.wait_start_time = 0.0
                logger.debug(f"Camera {camera_id}: Stopped waiting for backend sessionId")

    def cleanup_camera_tracking(self, camera_id: str) -> None:
        """
        Clean up tracking data when a camera is disconnected.

        Args:
            camera_id: Unique camera identifier
        """
        self._ensure_thread_safety()

        with self._lock:
            if camera_id in self._camera_tracking_data:
                del self._camera_tracking_data[camera_id]
                logger.info(f"Cleaned up tracking data for camera {camera_id}")

    def set_occupancy_mode(self,
                           camera_id: str,
                           model_id: str,
                           enable: bool = True) -> bool:
        """
        Enable or disable occupancy detection mode.

        Occupancy mode stops model inference after pipeline completion
        while backend session handling continues in background.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier
            enable: True to enable occupancy mode, False to disable

        Returns:
            Current occupancy mode state
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            session_state = tracking_data.session_state

            if enable:
                session_state.occupancy_mode = True
                session_state.occupancy_enabled_at = time.time()
                logger.debug(f"Camera {camera_id}: Occupancy mode ENABLED - model will stop after pipeline completion")
            else:
                session_state.occupancy_mode = False
                session_state.occupancy_enabled_at = None
                logger.debug(f"Camera {camera_id}: Occupancy mode DISABLED - model will continue running")

            return session_state.occupancy_mode

    def is_occupancy_mode_enabled(self, camera_id: str, model_id: str) -> bool:
        """
        Check if occupancy mode is enabled for a camera.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier

        Returns:
            True if occupancy mode is enabled
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            return tracking_data.session_state.occupancy_mode

    def get_occupancy_duration(self, camera_id: str, model_id: str) -> Optional[float]:
        """
        Get duration since occupancy mode was enabled.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier

        Returns:
            Duration in seconds since occupancy mode was enabled, or None if not enabled
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            session_state = tracking_data.session_state

            if not session_state.occupancy_mode or not session_state.occupancy_enabled_at:
                return None

            return time.time() - session_state.occupancy_enabled_at

    def get_session_state(self, camera_id: str, model_id: str) -> SessionState:
        """
        Get session state for a camera and model.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier

        Returns:
            Session state object
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            return tracking_data.session_state

    def update_session_state(self,
                             camera_id: str,
                             model_id: str,
                             **kwargs) -> None:
        """
        Update session state properties.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier
            **kwargs: Session state properties to update
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            session_state = tracking_data.session_state

            for key, value in kwargs.items():
                if hasattr(session_state, key):
                    setattr(session_state, key, value)
                    logger.debug(f"Camera {camera_id}: Updated session state {key} = {value}")
                else:
                    logger.warning(f"Camera {camera_id}: Unknown session state property: {key}")

    def get_tracking_statistics(self, camera_id: str, model_id: str) -> Dict[str, Any]:
        """
        Get comprehensive tracking statistics for a camera and model.

        Args:
            camera_id: Unique camera identifier
            model_id: Model identifier

        Returns:
            Dictionary with tracking statistics
        """
        self._ensure_thread_safety()

        with self._lock:
            tracking_data = self.get_camera_tracking_data(camera_id, model_id)
            current_time = time.time()

            stats = tracking_data.to_dict()

            # Add computed statistics
            stats["total_tracked_objects"] = len(tracking_data.track_stability_counters)
            stats["stable_track_count"] = len(tracking_data.stable_tracks)

            session_state = tracking_data.session_state
            if session_state.waiting_for_backend_session and session_state.wait_start_time > 0:
                stats["backend_wait_duration"] = current_time - session_state.wait_start_time
                stats["backend_wait_remaining"] = max(0, self.session_timeout_seconds - stats["backend_wait_duration"])

            if session_state.occupancy_mode and session_state.occupancy_enabled_at:
                stats["occupancy_duration"] = current_time - session_state.occupancy_enabled_at

            return stats

    def get_all_camera_stats(self) -> Dict[str, Dict[str, Dict[str, Any]]]:
        """
        Get tracking statistics for all monitored cameras.

        Returns:
            Nested dictionary: {camera_id: {model_id: stats}}
        """
        self._ensure_thread_safety()

        with self._lock:
            all_stats = {}
            for camera_id, models in self._camera_tracking_data.items():
                all_stats[camera_id] = {}
                for model_id in models.keys():
                    all_stats[camera_id][model_id] = self.get_tracking_statistics(camera_id, model_id)

            return all_stats


# Global tracking manager instance
tracking_manager = TrackingManager()


# ===== CONVENIENCE FUNCTIONS =====
# These provide the same interface as the original functions in pympta.py

def check_stable_tracks(camera_id: str, model_id: str, regions_dict: Dict[str, Any]) -> Tuple[bool, List[Dict[str, Any]]]:
    """Check if any stable tracks match the detected classes for a specific camera."""
    return tracking_manager.check_stable_tracks(camera_id, model_id, regions_dict)


def reset_tracking_state(camera_id: str, model_id: str, reason: str = "session ended") -> None:
    """Reset tracking state after session completion or timeout."""
    tracking_manager.reset_tracking_state(camera_id, model_id, reason)


def is_camera_active(camera_id: str, model_id: str) -> bool:
    """Check if camera should be processing detections."""
    return tracking_manager.is_camera_active(camera_id, model_id)


def cleanup_camera_stability(camera_id: str) -> None:
    """Clean up tracking data when a camera is disconnected."""
    tracking_manager.cleanup_camera_tracking(camera_id)


def occupancy_detector(camera_id: str, model_id: str, enable: bool = True) -> bool:
    """
    Enable or disable occupancy detection mode.

    When enabled:
    - Model stops inference after completing full pipeline
    - Backend sessionId handling continues in background

    Note: This is a temporary function that will be changed in the future.
    """
    return tracking_manager.set_occupancy_mode(camera_id, model_id, enable)


def get_camera_tracking_data(camera_id: str, model_id: str) -> Dict[str, Any]:
    """Get tracking data for a specific camera and model."""
    tracking_data = tracking_manager.get_camera_tracking_data(camera_id, model_id)
    return tracking_data.to_dict()


def set_waiting_for_backend_session(camera_id: str, model_id: str, waiting: bool = True) -> None:
    """Set whether camera is waiting for backend session ID."""
    tracking_manager.set_waiting_for_backend_session(camera_id, model_id, waiting)


def get_tracking_statistics(camera_id: str, model_id: str) -> Dict[str, Any]:
    """Get comprehensive tracking statistics for a camera and model."""
    return tracking_manager.get_tracking_statistics(camera_id, model_id)
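For reference, a minimal usage sketch of the tracking convenience functions (illustrative only; "cam-01", "yolo11m" and `regions` are hypothetical identifiers and detection regions, not values from this commit):

    if is_camera_active("cam-01", "yolo11m"):
        has_stable, stable_dets = check_stable_tracks("cam-01", "yolo11m", regions)
        if has_stable:
            occupancy_detector("cam-01", "yolo11m", enable=True)  # pause inference after pipeline completes
    # on camera disconnect
    cleanup_camera_stability("cam-01")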
633 detector_worker/detection/yolo_detector.py (new file)
@@ -0,0 +1,633 @@
|
|||
"""
|
||||
YOLO detection with BoT-SORT tracking integration.
|
||||
|
||||
This module provides the main detection functionality combining YOLO object detection
|
||||
with BoT-SORT tracking for stable track validation.
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Tuple, Set
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from ..core.constants import DEFAULT_STABILITY_THRESHOLD, DEFAULT_MIN_CONFIDENCE
|
||||
from ..core.exceptions import DetectionError, create_detection_error
|
||||
from ..detection.detection_result import DetectionResult, BoundingBox, DetectionSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrackingConfig:
|
||||
"""Configuration for object tracking."""
|
||||
enabled: bool = True
|
||||
method: str = "botsort"
|
||||
reid_config_path: str = "botsort.yaml"
|
||||
stability_threshold: int = DEFAULT_STABILITY_THRESHOLD
|
||||
use_tracking_for_model: bool = True # False for frontal detection models
|
||||
|
||||
|
||||
@dataclass
|
||||
class StabilityData:
|
||||
"""Track stability data for a camera and model."""
|
||||
track_stability_counters: Dict[int, int] = field(default_factory=dict)
|
||||
stable_tracks: Set[int] = field(default_factory=set)
|
||||
session_state: Dict[str, Any] = field(default_factory=lambda: {
|
||||
"active": True,
|
||||
"waiting_for_backend_session": False,
|
||||
"wait_start_time": 0.0,
|
||||
"reset_tracker_on_resume": False
|
||||
})
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"track_stability_counters": dict(self.track_stability_counters),
|
||||
"stable_tracks": list(self.stable_tracks),
|
||||
"session_state": self.session_state.copy()
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Result of track stability validation."""
|
||||
validation_complete: bool
|
||||
stable_tracks: List[int] = field(default_factory=list)
|
||||
current_tracks: List[int] = field(default_factory=list)
|
||||
newly_stable_tracks: List[int] = field(default_factory=list)
|
||||
send_none_detection: bool = False
|
||||
branch_node: bool = False
|
||||
bypass_validation: bool = False
|
||||
awaiting_stable_tracks: bool = False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"validation_complete": self.validation_complete,
|
||||
"stable_tracks": self.stable_tracks.copy(),
|
||||
"current_tracks": self.current_tracks.copy(),
|
||||
"newly_stable_tracks": self.newly_stable_tracks.copy(),
|
||||
"send_none_detection": self.send_none_detection,
|
||||
"branch_node": self.branch_node,
|
||||
"bypass_validation": self.bypass_validation,
|
||||
"awaiting_stable_tracks": self.awaiting_stable_tracks
|
||||
}
|
||||
|
||||
|
||||
class StabilityTracker:
|
||||
"""Manages track stability validation for cameras and models."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize stability tracker."""
|
||||
self._camera_stability_tracking: Dict[str, Dict[str, StabilityData]] = {}
|
||||
self._lock = None
|
||||
|
||||
def _ensure_thread_safety(self):
|
||||
"""Initialize thread safety if not already done."""
|
||||
if self._lock is None:
|
||||
import threading
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def get_camera_stability_data(self, camera_id: str, model_id: str) -> StabilityData:
|
||||
"""
|
||||
Get or create stability tracking data for a specific camera and model.
|
||||
|
||||
Args:
|
||||
camera_id: Unique camera identifier
|
||||
model_id: Model identifier
|
||||
|
||||
Returns:
|
||||
Stability data for the camera and model
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
if camera_id not in self._camera_stability_tracking:
|
||||
self._camera_stability_tracking[camera_id] = {}
|
||||
|
||||
if model_id not in self._camera_stability_tracking[camera_id]:
|
||||
logger.warning(f"🔄 Camera {camera_id}: Creating NEW stability data for {model_id} - this will reset any cooldown!")
|
||||
self._camera_stability_tracking[camera_id][model_id] = StabilityData()
|
||||
|
||||
return self._camera_stability_tracking[camera_id][model_id]
|
||||
|
||||
def reset_camera_stability_tracking(self, camera_id: str, model_id: str) -> None:
|
||||
"""
|
||||
Reset all stability tracking data for a specific camera and model.
|
||||
|
||||
Args:
|
||||
camera_id: Unique camera identifier
|
||||
model_id: Model identifier
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
if camera_id in self._camera_stability_tracking and model_id in self._camera_stability_tracking[camera_id]:
|
||||
stability_data = self._camera_stability_tracking[camera_id][model_id]
|
||||
|
||||
# Clear all tracking data
|
||||
old_counters = dict(stability_data.track_stability_counters)
|
||||
old_stable = list(stability_data.stable_tracks)
|
||||
|
||||
stability_data.track_stability_counters.clear()
|
||||
stability_data.stable_tracks.clear()
|
||||
|
||||
# IMPORTANT: Set flag to reset YOLO tracker on next detection run
|
||||
stability_data.session_state["reset_tracker_on_resume"] = True
|
||||
|
||||
logger.info(f"🧹 Camera {camera_id}: CLEARED stability tracking - old_counters={old_counters}, old_stable={old_stable}")
|
||||
logger.info(f"🔄 Camera {camera_id}: YOLO tracker will be reset on next detection - fresh track IDs will start from 1")
|
||||
|
||||
def update_single_track_stability(self,
|
||||
detection: Optional[Dict[str, Any]],
|
||||
camera_id: str,
|
||||
model_id: str,
|
||||
stability_threshold: int,
|
||||
current_mode: str,
|
||||
is_branch_node: bool = False) -> ValidationResult:
|
||||
"""
|
||||
Update track stability validation for a single highest confidence detection.
|
||||
|
||||
Args:
|
||||
detection: Detection data with track_id, or None if no detection
|
||||
camera_id: Unique camera identifier
|
||||
model_id: Model identifier
|
||||
stability_threshold: Number of consecutive frames needed for stability
|
||||
current_mode: Current pipeline mode
|
||||
is_branch_node: Whether this is a branch node (skip validation)
|
||||
|
||||
Returns:
|
||||
Validation result with track status
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
# Branch nodes should not do validation - only main pipeline should
|
||||
if is_branch_node:
|
||||
logger.debug(f"⏭️ Camera {camera_id}: Skipping validation for branch node {model_id} - validation only done at main pipeline level")
|
||||
return ValidationResult(
|
||||
validation_complete=False,
|
||||
branch_node=True,
|
||||
stable_tracks=[],
|
||||
current_tracks=[]
|
||||
)
|
||||
|
||||
# Check current mode - validation counters should increment in both validation_detecting and full_pipeline modes
|
||||
is_validation_mode = current_mode in ["validation_detecting", "full_pipeline"]
|
||||
|
||||
# Get camera-specific stability data
|
||||
stability_data = self.get_camera_stability_data(camera_id, model_id)
|
||||
track_counters = stability_data.track_stability_counters
|
||||
stable_tracks = stability_data.stable_tracks
|
||||
|
||||
current_track_id = detection.get("id") if detection else None
|
||||
|
||||
# ═══ MODE-AWARE TRACK VALIDATION ═══
|
||||
logger.debug(f"📋 Camera {camera_id}: === TRACK VALIDATION ANALYSIS ===")
|
||||
logger.debug(f"📋 Camera {camera_id}: Current mode: {current_mode} (validation_mode={is_validation_mode})")
|
||||
logger.debug(f"📋 Camera {camera_id}: Current track_id: {current_track_id} (assigned by YOLO tracking - not sequential)")
|
||||
logger.debug(f"📋 Camera {camera_id}: Existing counters: {dict(track_counters)}")
|
||||
logger.debug(f"📋 Camera {camera_id}: Stable tracks: {list(stable_tracks)}")
|
||||
|
||||
# IMPORTANT: Only modify validation counters during validation_detecting mode
|
||||
if not is_validation_mode:
|
||||
logger.debug(f"🚫 Camera {camera_id}: NOT in validation mode - skipping counter modifications")
|
||||
return ValidationResult(
|
||||
validation_complete=False,
|
||||
stable_tracks=list(stable_tracks),
|
||||
current_tracks=[current_track_id] if current_track_id is not None else []
|
||||
)
|
||||
|
||||
if current_track_id is not None:
|
||||
# Check if this is a different track than we were tracking
|
||||
previous_track_ids = list(track_counters.keys())
|
||||
|
||||
# VALIDATION MODE: Reset counter if different track OR if track was previously stable
|
||||
should_reset = (
|
||||
len(previous_track_ids) == 0 or # No previous tracking
|
||||
current_track_id not in previous_track_ids or # Different track ID
|
||||
current_track_id in stable_tracks # Track was stable - start fresh validation
|
||||
)
|
||||
|
||||
logger.debug(f"📋 Camera {camera_id}: Previous track_ids: {previous_track_ids}")
|
||||
logger.debug(f"📋 Camera {camera_id}: Track {current_track_id} was stable: {current_track_id in stable_tracks}")
|
||||
logger.debug(f"📋 Camera {camera_id}: Should reset counters: {should_reset}")
|
||||
|
||||
if should_reset:
|
||||
# Clear all previous tracking - fresh validation needed
|
||||
if previous_track_ids:
|
||||
for old_track_id in previous_track_ids:
|
||||
old_count = track_counters.pop(old_track_id, 0)
|
||||
stable_tracks.discard(old_track_id)
|
||||
logger.info(f"🔄 Camera {camera_id}: VALIDATION RESET - track {old_track_id} counter from {old_count} to 0 (reason: {'stable_track_restart' if current_track_id == old_track_id else 'different_track'})")
|
||||
|
||||
# Start fresh validation for this track
|
||||
old_count = track_counters.get(current_track_id, 0) # Store old count for logging
|
||||
track_counters[current_track_id] = 1
|
||||
current_count = 1
|
||||
logger.info(f"🆕 Camera {camera_id}: FRESH VALIDATION - Track {current_track_id} starting at 1/{stability_threshold}")
|
||||
else:
|
||||
# Continue validation for same track
|
||||
old_count = track_counters.get(current_track_id, 0)
|
||||
track_counters[current_track_id] = old_count + 1
|
||||
current_count = track_counters[current_track_id]
|
||||
|
||||
logger.debug(f"🔢 Camera {camera_id}: Track {current_track_id} counter: {old_count} → {current_count}")
|
||||
logger.info(f"🔍 Camera {camera_id}: Track ID {current_track_id} validation {current_count}/{stability_threshold}")
|
||||
|
||||
# Check if track has reached stability threshold
|
||||
logger.debug(f"📊 Camera {camera_id}: Checking stability: {current_count} >= {stability_threshold}? {current_count >= stability_threshold}")
|
||||
logger.debug(f"📊 Camera {camera_id}: Already stable: {current_track_id in stable_tracks}")
|
||||
|
||||
if current_count >= stability_threshold and current_track_id not in stable_tracks:
|
||||
stable_tracks.add(current_track_id)
|
||||
logger.info(f"✅ Camera {camera_id}: Track ID {current_track_id} STABLE after {current_count} consecutive frames")
|
||||
logger.info(f"🎯 Camera {camera_id}: TRACK VALIDATION COMPLETE")
|
||||
logger.debug(f"🎯 Camera {camera_id}: Stable tracks now: {list(stable_tracks)}")
|
||||
return ValidationResult(
|
||||
validation_complete=True,
|
||||
send_none_detection=True,
|
||||
stable_tracks=[current_track_id],
|
||||
newly_stable_tracks=[current_track_id],
|
||||
current_tracks=[current_track_id]
|
||||
)
|
||||
elif current_count >= stability_threshold:
|
||||
logger.debug(f"📊 Camera {camera_id}: Track {current_track_id} already stable - not re-adding")
|
||||
else:
|
||||
# No car detected - ALWAYS clear all tracking and reset counters
|
||||
logger.debug(f"🚫 Camera {camera_id}: NO CAR DETECTED - clearing all tracking")
|
||||
if track_counters or stable_tracks:
|
||||
logger.debug(f"🚫 Camera {camera_id}: Existing state before reset: counters={dict(track_counters)}, stable={list(stable_tracks)}")
|
||||
for track_id in list(track_counters.keys()):
|
||||
old_count = track_counters.pop(track_id, 0)
|
||||
logger.info(f"🔄 Camera {camera_id}: No car detected - RESET track {track_id} counter from {old_count} to 0")
|
||||
track_counters.clear() # Ensure complete reset
|
||||
stable_tracks.clear() # Clear all stable tracks
|
||||
logger.info(f"✅ Camera {camera_id}: RESET TO VALIDATION PHASE - All counters and stable tracks cleared")
|
||||
else:
|
||||
logger.debug(f"🚫 Camera {camera_id}: No existing counters to clear")
|
||||
logger.debug(f"Camera {camera_id}: VALIDATION - no car detected (all counters reset)")
|
||||
|
||||
# Final return - validation not complete
|
||||
result = ValidationResult(
|
||||
validation_complete=False,
|
||||
stable_tracks=list(stable_tracks),
|
||||
current_tracks=[current_track_id] if current_track_id is not None else []
|
||||
)
|
||||
|
||||
logger.debug(f"📋 Camera {camera_id}: Track stability result: {result.to_dict()}")
|
||||
logger.debug(f"📋 Camera {camera_id}: Final counters: {dict(track_counters)}")
|
||||
logger.debug(f"📋 Camera {camera_id}: Final stable tracks: {list(stable_tracks)}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class YOLODetector:
|
||||
"""
|
||||
YOLO object detector with BoT-SORT tracking integration.
|
||||
|
||||
Provides structured detection functionality with track stability validation
|
||||
and multi-class detection support.
|
||||
"""
|
||||
|
||||
def __init__(self, stability_tracker: Optional[StabilityTracker] = None):
|
||||
"""
|
||||
Initialize YOLO detector.
|
||||
|
||||
Args:
|
||||
stability_tracker: Optional stability tracker instance
|
||||
"""
|
||||
self.stability_tracker = stability_tracker or StabilityTracker()
|
||||
|
||||
def _extract_tracking_config(self, node: Dict[str, Any]) -> TrackingConfig:
|
||||
"""Extract tracking configuration from pipeline node."""
|
||||
tracking_config = node.get("tracking", {})
|
||||
model_id = node.get("modelId", "")
|
||||
|
||||
# Don't use tracking for frontal detection models
|
||||
use_tracking = not ("frontal" in model_id.lower() or "detection" in model_id.lower())
|
||||
|
||||
return TrackingConfig(
|
||||
enabled=tracking_config.get("enabled", True),
|
||||
method=tracking_config.get("method", "botsort"),
|
||||
reid_config_path=tracking_config.get("reidConfig", tracking_config.get("reidConfigPath", "botsort.yaml")),
|
||||
stability_threshold=tracking_config.get("stabilityThreshold", node.get("stabilityThreshold", DEFAULT_STABILITY_THRESHOLD)),
|
||||
use_tracking_for_model=use_tracking
|
||||
)
|
||||
|
||||
def _determine_confidence_threshold(self, node: Dict[str, Any]) -> float:
|
||||
"""Determine confidence threshold based on model type."""
|
||||
model_id = node.get("modelId", "")
|
||||
|
||||
# Use frontalMinConfidence for frontal detection models
|
||||
if "frontal" in model_id.lower() and "frontalMinConfidence" in node:
|
||||
min_confidence = node.get("frontalMinConfidence", DEFAULT_MIN_CONFIDENCE)
|
||||
logger.debug(f"Using frontalMinConfidence={min_confidence} for {model_id}")
|
||||
else:
|
||||
min_confidence = node.get("minConfidence", DEFAULT_MIN_CONFIDENCE)
|
||||
|
||||
return min_confidence
|
||||
|
||||
def _reset_yolo_tracker_if_needed(self, node: Dict[str, Any], camera_id: str, model_id: str) -> None:
|
||||
"""Reset YOLO tracker if flagged for reset."""
|
||||
stability_data = self.stability_tracker.get_camera_stability_data(camera_id, model_id)
|
||||
session_state = stability_data.session_state
|
||||
|
||||
if session_state.get("reset_tracker_on_resume", False):
|
||||
# Reset YOLO tracker to get fresh track IDs
|
||||
if hasattr(node["model"], 'trackers') and node["model"].trackers:
|
||||
node["model"].trackers.clear() # Clear tracker state
|
||||
logger.info(f"Camera {camera_id}: 🔄 Reset YOLO tracker - new cars will get fresh track IDs")
|
||||
session_state["reset_tracker_on_resume"] = False # Clear the flag
|
||||
|
||||
def _run_yolo_inference(self,
|
||||
frame,
|
||||
node: Dict[str, Any],
|
||||
tracking_config: TrackingConfig) -> Any:
|
||||
"""Run YOLO inference with or without tracking."""
|
||||
# Prepare class filtering
|
||||
trigger_class_indices = node.get("triggerClassIndices")
|
||||
class_filter = {"classes": trigger_class_indices} if trigger_class_indices else {}
|
||||
|
||||
model_id = node.get("modelId", "")
|
||||
use_tracking = tracking_config.enabled and tracking_config.use_tracking_for_model
|
||||
|
||||
logger.debug(f"Running detection for {model_id} - tracking: {use_tracking}, stability_threshold: {tracking_config.stability_threshold}, classes: {node.get('triggerClasses', 'all')}")
|
||||
|
||||
if use_tracking:
|
||||
# Use tracking for main detection models (yolo11m, etc.)
|
||||
logger.debug(f"Using tracking for {model_id}")
|
||||
res = node["model"].track(
|
||||
frame,
|
||||
stream=False,
|
||||
persist=True,
|
||||
**class_filter
|
||||
)[0]
|
||||
else:
|
||||
# Use detection only for frontal detection and other detection-only models
|
||||
logger.debug(f"Using prediction only for {model_id}")
|
||||
res = node["model"].predict(
|
||||
frame,
|
||||
stream=False,
|
||||
**class_filter
|
||||
)[0]
|
||||
|
||||
return res
|
||||
|
||||
def _process_detections(self,
|
||||
res: Any,
|
||||
node: Dict[str, Any],
|
||||
camera_id: str,
|
||||
min_confidence: float) -> List[Dict[str, Any]]:
|
||||
"""Process YOLO detection results into candidate detections."""
|
||||
candidate_detections = []
|
||||
|
||||
if res.boxes is None or len(res.boxes) == 0:
|
||||
logger.debug(f"🚫 Camera {camera_id}: YOLO returned no detections")
|
||||
return candidate_detections
|
||||
|
||||
logger.debug(f"🔍 Camera {camera_id}: YOLO detected {len(res.boxes)} raw objects - processing with tracking...")
|
||||
logger.debug(f"🔍 Camera {camera_id}: === DETECTION ANALYSIS ===")
|
||||
|
||||
for i, box in enumerate(res.boxes):
|
||||
# Extract detection data
|
||||
conf = float(box.cpu().conf[0])
|
||||
cls_id = int(box.cpu().cls[0])
|
||||
class_name = node["model"].names[cls_id]
|
||||
|
||||
# Extract bounding box
|
||||
xy = box.cpu().xyxy[0]
|
||||
x1, y1, x2, y2 = map(int, xy)
|
||||
bbox = (x1, y1, x2, y2)
|
||||
|
||||
# Extract tracking ID if available
|
||||
track_id = None
|
||||
if hasattr(box, "id") and box.id is not None:
|
||||
track_id = int(box.id.item())
|
||||
|
||||
logger.debug(f"🔍 Camera {camera_id}: Detection {i+1}: class='{class_name}' conf={conf:.3f} track_id={track_id} bbox={bbox}")
|
||||
|
||||
# Apply confidence filtering
|
||||
if conf < min_confidence:
|
||||
logger.debug(f"❌ Camera {camera_id}: Detection {i+1} REJECTED - confidence {conf:.3f} < {min_confidence}")
|
||||
continue
|
||||
|
||||
# Create detection object
|
||||
detection = {
|
||||
"class": class_name,
|
||||
"confidence": conf,
|
||||
"id": track_id,
|
||||
"bbox": bbox,
|
||||
"class_id": cls_id
|
||||
}
|
||||
|
||||
candidate_detections.append(detection)
|
||||
logger.debug(f"✅ Camera {camera_id}: Detection {i+1} ACCEPTED as candidate: {class_name} (conf={conf:.3f}, track_id={track_id})")
|
||||
|
||||
return candidate_detections
|
||||
|
||||
    def _select_best_detection(self,
                               candidate_detections: List[Dict[str, Any]],
                               camera_id: str) -> Optional[Dict[str, Any]]:
        """Select the highest confidence detection from candidates."""
        if not candidate_detections:
            logger.debug(f"🚫 Camera {camera_id}: No valid candidates after filtering - no car will be tracked")
            return None

        logger.debug(f"🏆 Camera {camera_id}: === SELECTING HIGHEST CONFIDENCE CAR ===")
        for i, detection in enumerate(candidate_detections):
            logger.debug(f"🏆 Camera {camera_id}: Candidate {i+1}: {detection['class']} conf={detection['confidence']:.3f} track_id={detection['id']}")

        # Find the single highest confidence detection across all detected classes
        best_detection = max(candidate_detections, key=lambda x: x["confidence"])

        logger.info(f"🎯 Camera {camera_id}: SELECTED WINNER: {best_detection['class']} (conf={best_detection['confidence']:.3f}, track_id={best_detection['id']}, bbox={best_detection['bbox']})")

        # Show which cars were NOT selected
        for detection in candidate_detections:
            if detection != best_detection:
                logger.debug(f"🚫 Camera {camera_id}: NOT SELECTED: {detection['class']} (conf={detection['confidence']:.3f}, track_id={detection['id']}) - lower confidence")

        return best_detection

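    # Example (illustrative values): given candidates with confidences 0.91 ("car")
    # and 0.85 ("truck"), the 0.91 "car" detection is returned regardless of class -
    # selection is purely by confidence.
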
    def _apply_class_mapping(self, detection: Dict[str, Any], node: Dict[str, Any]) -> Dict[str, Any]:
        """Apply class mapping if configured."""
        original_class = detection["class"]
        class_mapping = node.get("classMapping", {})

        if original_class in class_mapping:
            mapped_class = class_mapping[original_class]
            logger.info(f"Class mapping applied: {original_class} → {mapped_class}")
            # Update the detection object with mapped class
            detection["class"] = mapped_class
            detection["original_class"] = original_class  # Keep original for reference

        return detection

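    # Example (hypothetical node config): with node["classMapping"] == {"car": "frontal"},
    # a detection {"class": "car", ...} comes back as
    # {"class": "frontal", "original_class": "car", ...}; unmapped classes pass through unchanged.
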
    def _validate_multi_class(self,
                              regions_dict: Dict[str, Any],
                              node: Dict[str, Any]) -> bool:
        """Validate multi-class detection requirements."""
        if not (node.get("multiClass", False) and node.get("expectedClasses")):
            return True

        expected_classes = node["expectedClasses"]
        detected_classes = list(regions_dict.keys())

        logger.debug(f"Multi-class validation: expected={expected_classes}, detected={detected_classes}")

        # Check for required classes (flexible - at least one must match)
        matching_classes = [cls for cls in expected_classes if cls in detected_classes]
        if not matching_classes:
            logger.warning(f"Multi-class validation failed: no expected classes detected")
            return False

        logger.info(f"Multi-class validation passed: {matching_classes} detected")
        return True

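    # Example (hypothetical node config): with expectedClasses == ["car", "truck"] and
    # regions_dict containing only "car", validation passes (at least one expected class
    # matched); it fails only when none of the expected classes were detected.
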
    def run_detection_with_tracking(self,
                                    frame,
                                    node: Dict[str, Any],
                                    context: Optional[Dict[str, Any]] = None) -> Tuple[List[Dict[str, Any]], Dict[str, Any], Dict[str, Any]]:
        """
        Run YOLO detection with BoT-SORT tracking and stability validation.

        Args:
            frame: Input frame/image
            node: Pipeline node configuration with model and settings
            context: Optional context information (camera info, session data, etc.)

        Returns:
            tuple: (all_detections, regions_dict, track_validation_result) where:
                - all_detections: List of all detection objects
                - regions_dict: Dict mapping class names to highest confidence detections
                - track_validation_result: Dict with validation status and stable tracks
        """
        try:
            camera_id = context.get("camera_id", "unknown") if context else "unknown"
            model_id = node.get("modelId", "unknown")
            current_mode = context.get("current_mode", "unknown") if context else "unknown"

            # Extract configuration
            tracking_config = self._extract_tracking_config(node)
            min_confidence = self._determine_confidence_threshold(node)

            # Check if we need to reset tracker after cooldown
            self._reset_yolo_tracker_if_needed(node, camera_id, model_id)

            # Run YOLO inference
            res = self._run_yolo_inference(frame, node, tracking_config)

            # Process detection results
            candidate_detections = self._process_detections(res, node, camera_id, min_confidence)

            # Select best detection
            best_detection = self._select_best_detection(candidate_detections, camera_id)

            # Update track stability validation
            is_branch_node = node.get("cropClass") is not None or node.get("parallel") is True
            track_validation_result = self.stability_tracker.update_single_track_stability(
                detection=best_detection,
                camera_id=camera_id,
                model_id=model_id,
                stability_threshold=tracking_config.stability_threshold,
                current_mode=current_mode,
                is_branch_node=is_branch_node
            )

            # Handle no detection case
            if best_detection is None:
                # Store validation state in context for pipeline decisions
                if context is not None:
                    context["track_validation_result"] = track_validation_result.to_dict()

                return [], {}, track_validation_result.to_dict()

            # Apply class mapping
            best_detection = self._apply_class_mapping(best_detection, node)

            # Create regions dictionary
            mapped_class = best_detection["class"]
            track_id = best_detection["id"]

            all_detections = [best_detection]
            regions_dict = {
                mapped_class: {
                    "bbox": best_detection["bbox"],
                    "confidence": best_detection["confidence"],
                    "detection": best_detection,
                    "track_id": track_id
                }
            }

            # Multi-class validation
            if not self._validate_multi_class(regions_dict, node):
                return [], {}, track_validation_result.to_dict()

            logger.info(f"✅ Camera {camera_id}: DETECTION COMPLETE - tracking single car: track_id={track_id}, conf={best_detection['confidence']:.3f}")
            logger.debug(f"📊 Camera {camera_id}: Detection summary: {len(res.boxes)} raw → {len(candidate_detections)} candidates → 1 selected")

            # Store validation state in context for pipeline decisions
            if context is not None:
                context["track_validation_result"] = track_validation_result.to_dict()

            return all_detections, regions_dict, track_validation_result.to_dict()

        except Exception as e:
            camera_id = context.get("camera_id", "unknown") if context else "unknown"
            model_id = node.get("modelId", "unknown")
            raise create_detection_error(camera_id, model_id, "detection_with_tracking", e)
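

# Illustrative return value of run_detection_with_tracking (values are made up;
# the exact keys of the validation dict come from the stability tracker's result object):
#   all_detections == [{"class": "frontal", "confidence": 0.91, "id": 7, ...}]
#   regions_dict   == {"frontal": {"bbox": (120, 80, 480, 360), "confidence": 0.91,
#                                  "detection": {...}, "track_id": 7}}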


# Global instances for backward compatibility
_global_stability_tracker = StabilityTracker()
_global_yolo_detector = YOLODetector(_global_stability_tracker)


# ===== CONVENIENCE FUNCTIONS =====
# These provide the same interface as the original functions in pympta.py

def run_detection_with_tracking(frame, node: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Tuple[List[Dict[str, Any]], Dict[str, Any], Dict[str, Any]]:
    """Run YOLO detection with BoT-SORT tracking using global detector instance."""
    return _global_yolo_detector.run_detection_with_tracking(frame, node, context)


def get_camera_stability_data(camera_id: str, model_id: str) -> Dict[str, Any]:
    """Get stability tracking data for a specific camera and model."""
    stability_data = _global_stability_tracker.get_camera_stability_data(camera_id, model_id)
    return stability_data.to_dict()


def reset_camera_stability_tracking(camera_id: str, model_id: str) -> None:
    """Reset all stability tracking data for a specific camera and model."""
    _global_stability_tracker.reset_camera_stability_tracking(camera_id, model_id)


def update_single_track_stability(node: Dict[str, Any],
                                  detection: Optional[Dict[str, Any]],
                                  camera_id: str,
                                  frame_shape=None,
                                  stability_threshold: int = DEFAULT_STABILITY_THRESHOLD,
                                  context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Update track stability validation for a single highest confidence detection."""
    model_id = node.get("modelId", "unknown")
    current_mode = context.get("current_mode", "unknown") if context else "unknown"
    is_branch_node = node.get("cropClass") is not None or node.get("parallel") is True

    result = _global_stability_tracker.update_single_track_stability(
        detection=detection,
        camera_id=camera_id,
        model_id=model_id,
        stability_threshold=stability_threshold,
        current_mode=current_mode,
        is_branch_node=is_branch_node
    )

    return result.to_dict()


def reset_tracking_state(camera_id: str, model_id: str, reason: str = "reset_requested") -> None:
    """Reset tracking state for a camera and model."""
    logger.info(f"🔄 Camera {camera_id}: Resetting tracking state for {model_id} - reason: {reason}")
    reset_camera_stability_tracking(camera_id, model_id)
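

if __name__ == "__main__":
    # Standalone smoke-test sketch, not exercised by the worker itself. It only
    # demonstrates the ultralytics result-parsing idiom used in _process_detections;
    # the checkpoint name and the all-black frame are placeholder assumptions, and a
    # real pipeline node (.mpta config) is still required for run_detection_with_tracking.
    import numpy as np
    from ultralytics import YOLO

    demo_model = YOLO("yolov8n.pt")  # any detection checkpoint works here
    demo_frame = np.zeros((480, 640, 3), dtype=np.uint8)
    demo_res = demo_model.predict(demo_frame, stream=False)[0]

    if demo_res.boxes is None or len(demo_res.boxes) == 0:
        print("No detections on the dummy frame (expected for a black image)")
    else:
        for demo_box in demo_res.boxes:
            demo_conf = float(demo_box.cpu().conf[0])
            demo_cls = int(demo_box.cpu().cls[0])
            print(demo_model.names[demo_cls], f"{demo_conf:.3f}",
                  tuple(map(int, demo_box.cpu().xyxy[0])))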