"""
Detection result data structures and models.

This module defines the data structures used to represent detection
results throughout the system.
"""

from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Optional, Dict, Any, Tuple


@dataclass
class BoundingBox:
    """Represents a bounding box for detected objects.

    The box is stored as two corner points, ``(x1, y1)`` and ``(x2, y2)``.
    No ordering of the corners is enforced here, so ``width`` and ``height``
    are plain differences and can be zero or negative.
    """

    x1: int
    y1: int
    x2: int
    y2: int

    @property
    def width(self) -> int:
        """Horizontal extent, computed as ``x2 - x1``."""
        return self.x2 - self.x1

    @property
    def height(self) -> int:
        """Vertical extent, computed as ``y2 - y1``."""
        return self.y2 - self.y1

    @property
    def area(self) -> int:
        """Product of ``width`` and ``height``."""
        return self.width * self.height

    @property
    def center(self) -> Tuple[int, int]:
        """Center point as ``(x, y)``, using floor division for the half extents."""
        return (self.x1 + self.width // 2, self.y1 + self.height // 2)

    def to_list(self) -> List[int]:
        """Convert to [x1, y1, x2, y2] format."""
        return [self.x1, self.y1, self.x2, self.y2]

    def to_xyxy(self) -> Tuple[int, int, int, int]:
        """Convert to (x1, y1, x2, y2) tuple."""
        return (self.x1, self.y1, self.x2, self.y2)


@dataclass
class DetectionResult:
    """Represents a single detection result.

    Only ``class_name``, ``confidence`` and ``bbox`` are required; the
    remaining fields default to ``None`` and are left out of ``to_dict()``
    output while they are ``None``.
    """

    class_name: str
    confidence: float
    bbox: BoundingBox
    track_id: Optional[int] = None
    class_id: Optional[int] = None
    original_class: Optional[str] = None  # For class mapping

    # Additional detection metadata
    model_id: Optional[str] = None
    timestamp: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format.

        Returns:
            A dict with the keys ``"class"``, ``"confidence"`` and ``"bbox"``
            (as ``[x1, y1, x2, y2]``), plus ``"id"``, ``"class_id"``,
            ``"original_class"``, ``"model_id"`` and ``"timestamp"`` (ISO 8601
            string) for each of those fields that is not ``None``.
        """
        result: Dict[str, Any] = {
            "class": self.class_name,
            "confidence": self.confidence,
            "bbox": self.bbox.to_list(),
        }

        # Note that track_id is exported under the key "id", not "track_id".
        optional_items = (
            ("id", self.track_id),
            ("class_id", self.class_id),
            ("original_class", self.original_class),
            ("model_id", self.model_id),
        )
        for key, value in optional_items:
            if value is not None:
                result[key] = value

        if self.timestamp is not None:
            result["timestamp"] = self.timestamp.isoformat()

        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'DetectionResult':
        """Create DetectionResult from dictionary.

        Args:
            data: Mapping in the layout produced by ``to_dict()``. ``"class"``,
                ``"confidence"`` and ``"bbox"`` are required; ``"bbox"`` must
                be a list or tuple of four values ``[x1, y1, x2, y2]``.
                ``"timestamp"`` may be an ISO 8601 string or a ``datetime``;
                any other value is treated as absent.

        Returns:
            The reconstructed ``DetectionResult``.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If ``"bbox"`` is not a four-element list/tuple, or a
                timestamp string cannot be parsed by ``datetime.fromisoformat``.
        """
        bbox_data = data["bbox"]
        # Tuples are accepted alongside lists so in-process callers can pass
        # the output of BoundingBox.to_xyxy() directly.
        if not (isinstance(bbox_data, (list, tuple)) and len(bbox_data) == 4):
            raise ValueError(f"Invalid bbox format: {bbox_data}")
        bbox = BoundingBox(*bbox_data)

        raw_timestamp = data.get("timestamp")
        if isinstance(raw_timestamp, str):
            timestamp = datetime.fromisoformat(raw_timestamp)
        elif isinstance(raw_timestamp, datetime):
            timestamp = raw_timestamp
        else:
            timestamp = None

        return cls(
            class_name=data["class"],
            confidence=data["confidence"],
            bbox=bbox,
            track_id=data.get("id"),
            class_id=data.get("class_id"),
            original_class=data.get("original_class"),
            model_id=data.get("model_id"),
            timestamp=timestamp,
        )


@dataclass
class RegionData:
    """Represents detection data for a specific region/class.

    Holds a bounding box, confidence and optional track id next to the
    ``DetectionResult`` they belong to.
    """

    bbox: BoundingBox
    confidence: float
    detection: DetectionResult
    track_id: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format.

        Returns:
            A dict with the keys ``"bbox"`` (as ``[x1, y1, x2, y2]``),
            ``"confidence"``, ``"detection"`` (the nested detection's own
            dict form) and ``"track_id"``. ``"track_id"`` is always present,
            even when it is ``None``.
        """
        return {
            "bbox": self.bbox.to_list(),
            "confidence": self.confidence,
            "detection": self.detection.to_dict(),
            "track_id": self.track_id,
        }


@dataclass
class TrackValidationResult:
    """Represents track stability validation results.

    ``validation_complete`` is the only required field. The three track-id
    lists default to fresh empty lists per instance and the metadata flags
    default to ``False``.
    """

    validation_complete: bool
    stable_tracks: List[int] = field(default_factory=list)
    current_tracks: List[int] = field(default_factory=list)
    newly_stable_tracks: List[int] = field(default_factory=list)

    # Additional metadata
    send_none_detection: bool = False
    branch_node: bool = False
    bypass_validation: bool = False
    awaiting_stable_tracks: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format.

        Returns:
            A dict containing every field under its own name. The track-id
            lists are shallow copies, so mutating the returned lists does not
            affect this instance.
        """
        return {
            "validation_complete": self.validation_complete,
            "stable_tracks": self.stable_tracks.copy(),
            "current_tracks": self.current_tracks.copy(),
            "newly_stable_tracks": self.newly_stable_tracks.copy(),
            "send_none_detection": self.send_none_detection,
            "branch_node": self.branch_node,
            "bypass_validation": self.bypass_validation,
            "awaiting_stable_tracks": self.awaiting_stable_tracks,
        }


@dataclass
class DetectionSession:
    """Represents a detection session with multiple results.

    Every field has a default, so an empty session can be created with no
    arguments.
    """

    detections: List[DetectionResult] = field(default_factory=list)
    regions: Dict[str, RegionData] = field(default_factory=dict)
    validation_result: Optional[TrackValidationResult] = None

    # Session metadata
    camera_id: Optional[str] = None
    session_id: Optional[str] = None
    backend_session_id: Optional[str] = None
    timestamp: Optional[datetime] = None

    # Branch results for pipeline processing
    branch_results: Dict[str, Any] = field(default_factory=dict)

    def get_primary_detection(self) -> Optional[DetectionResult]:
        """Get the highest confidence detection.

        Returns:
            The detection with the largest ``confidence`` (the earliest one
            on ties), or ``None`` when the session has no detections.
        """
        return max(self.detections, key=lambda d: d.confidence, default=None)

    def get_detection_by_class(self, class_name: str) -> Optional[DetectionResult]:
        """Get detection for specific class.

        Args:
            class_name: Class label to look for.

        Returns:
            The first detection whose ``class_name`` equals ``class_name``,
            or ``None`` if there is no match.
        """
        for detection in self.detections:
            if detection.class_name == class_name:
                return detection
        return None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format.

        Returns:
            A dict that always contains ``"detections"``, ``"regions"`` and
            ``"branch_results"`` (a shallow copy). ``"validation_result"``,
            ``"camera_id"``, ``"session_id"``, ``"backend_session_id"`` and
            ``"timestamp"`` (ISO 8601 string) are added only when the
            corresponding field is truthy, so empty-string ids are omitted
            as well as ``None``.
        """
        result: Dict[str, Any] = {
            "detections": [d.to_dict() for d in self.detections],
            "regions": {k: v.to_dict() for k, v in self.regions.items()},
            "branch_results": self.branch_results.copy(),
        }

        if self.validation_result:
            result["validation_result"] = self.validation_result.to_dict()
        if self.camera_id:
            result["camera_id"] = self.camera_id
        if self.session_id:
            result["session_id"] = self.session_id
        if self.backend_session_id:
            result["backend_session_id"] = self.backend_session_id
        if self.timestamp:
            result["timestamp"] = self.timestamp.isoformat()

        return result


@dataclass
class LightweightDetectionResult:
    """Result from lightweight detection for validation.

    ``validation_passed`` is required; the detection details are optional
    and default to ``None``.
    """

    validation_passed: bool
    class_name: Optional[str] = None
    confidence: Optional[float] = None
    bbox: Optional[BoundingBox] = None
    bbox_area_ratio: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format.

        Returns:
            A dict that always contains ``"validation_passed"``. ``"class"``
            and ``"bbox"`` (as ``[x1, y1, x2, y2]``) are added when the field
            is truthy; ``"confidence"`` and ``"bbox_area_ratio"`` are added
            when the field is not ``None``, so a value of ``0.0`` is kept.
        """
        result: Dict[str, Any] = {"validation_passed": self.validation_passed}

        if self.class_name:
            result["class"] = self.class_name
        if self.confidence is not None:
            result["confidence"] = self.confidence
        if self.bbox:
            result["bbox"] = self.bbox.to_list()
        if self.bbox_area_ratio is not None:
            result["bbox_area_ratio"] = self.bbox_area_ratio

        return result


# ===== HELPER FUNCTIONS =====

def create_none_detection() -> DetectionResult:
    """Create a 'none' detection result.

    Returns:
        A ``DetectionResult`` with class name ``"none"``, confidence ``1.0``,
        an all-zero bounding box, and ``timestamp`` set to the time of the
        call (``datetime.now()``, i.e. naive local time).
    """
    return DetectionResult(
        class_name="none",
        confidence=1.0,
        bbox=BoundingBox(0, 0, 0, 0),
        timestamp=datetime.now(),
    )


def create_regions_dict_from_detections(detections: List[DetectionResult]) -> Dict[str, RegionData]:
    """Create regions dictionary from detection list.

    Args:
        detections: Detections to index by class name.

    Returns:
        A dict mapping each detection's ``class_name`` to a ``RegionData``
        built from that detection's bbox, confidence and track id. When
        several detections share a class name, the last one in
        ``detections`` wins.
    """
    return {
        detection.class_name: RegionData(
            bbox=detection.bbox,
            confidence=detection.confidence,
            detection=detection,
            track_id=detection.track_id,
        )
        for detection in detections
    }


def merge_detection_results(results: List[DetectionResult]) -> List[DetectionResult]:
    """Merge multiple detection results, keeping highest confidence per class.

    Args:
        results: Detections to merge; may contain several per class name.

    Returns:
        One detection per distinct ``class_name``, ordered by the first
        appearance of each class in ``results``. Within a class the
        detection with the highest confidence is kept; on a tie the
        earliest one is kept.
    """
    best_by_class: Dict[str, DetectionResult] = {}

    for result in results:
        current_best = best_by_class.get(result.class_name)
        # Strict ">" so an equal-confidence later detection does not replace
        # the one already stored.
        if current_best is None or result.confidence > current_best.confidence:
            best_by_class[result.class_name] = result

    return list(best_by_class.values())