Refactor: Phase 1: High-Level Restructuring
This commit is contained in:
parent
21700dce52
commit
96bedae80a
13 changed files with 787 additions and 0 deletions
271
detector_worker/detection/detection_result.py
Normal file
271
detector_worker/detection/detection_result.py
Normal file
|
@ -0,0 +1,271 @@
|
|||
"""
|
||||
Detection result data structures and models.
|
||||
|
||||
This module defines the data structures used to represent detection
|
||||
results throughout the system.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class BoundingBox:
|
||||
"""Represents a bounding box for detected objects."""
|
||||
x1: int
|
||||
y1: int
|
||||
x2: int
|
||||
y2: int
|
||||
|
||||
@property
|
||||
def width(self) -> int:
|
||||
return self.x2 - self.x1
|
||||
|
||||
@property
|
||||
def height(self) -> int:
|
||||
return self.y2 - self.y1
|
||||
|
||||
@property
|
||||
def area(self) -> int:
|
||||
return self.width * self.height
|
||||
|
||||
@property
|
||||
def center(self) -> Tuple[int, int]:
|
||||
return (self.x1 + self.width // 2, self.y1 + self.height // 2)
|
||||
|
||||
def to_list(self) -> List[int]:
|
||||
"""Convert to [x1, y1, x2, y2] format."""
|
||||
return [self.x1, self.y1, self.x2, self.y2]
|
||||
|
||||
def to_xyxy(self) -> Tuple[int, int, int, int]:
|
||||
"""Convert to (x1, y1, x2, y2) tuple."""
|
||||
return (self.x1, self.y1, self.x2, self.y2)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectionResult:
|
||||
"""Represents a single detection result."""
|
||||
class_name: str
|
||||
confidence: float
|
||||
bbox: BoundingBox
|
||||
track_id: Optional[int] = None
|
||||
class_id: Optional[int] = None
|
||||
original_class: Optional[str] = None # For class mapping
|
||||
|
||||
# Additional detection metadata
|
||||
model_id: Optional[str] = None
|
||||
timestamp: Optional[datetime] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
result = {
|
||||
"class": self.class_name,
|
||||
"confidence": self.confidence,
|
||||
"bbox": self.bbox.to_list()
|
||||
}
|
||||
|
||||
if self.track_id is not None:
|
||||
result["id"] = self.track_id
|
||||
if self.class_id is not None:
|
||||
result["class_id"] = self.class_id
|
||||
if self.original_class is not None:
|
||||
result["original_class"] = self.original_class
|
||||
if self.model_id is not None:
|
||||
result["model_id"] = self.model_id
|
||||
if self.timestamp is not None:
|
||||
result["timestamp"] = self.timestamp.isoformat()
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'DetectionResult':
|
||||
"""Create DetectionResult from dictionary."""
|
||||
bbox_data = data["bbox"]
|
||||
if isinstance(bbox_data, list) and len(bbox_data) == 4:
|
||||
bbox = BoundingBox(bbox_data[0], bbox_data[1], bbox_data[2], bbox_data[3])
|
||||
else:
|
||||
raise ValueError(f"Invalid bbox format: {bbox_data}")
|
||||
|
||||
timestamp = None
|
||||
if "timestamp" in data:
|
||||
if isinstance(data["timestamp"], str):
|
||||
timestamp = datetime.fromisoformat(data["timestamp"])
|
||||
elif isinstance(data["timestamp"], datetime):
|
||||
timestamp = data["timestamp"]
|
||||
|
||||
return cls(
|
||||
class_name=data["class"],
|
||||
confidence=data["confidence"],
|
||||
bbox=bbox,
|
||||
track_id=data.get("id"),
|
||||
class_id=data.get("class_id"),
|
||||
original_class=data.get("original_class"),
|
||||
model_id=data.get("model_id"),
|
||||
timestamp=timestamp
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegionData:
|
||||
"""Represents detection data for a specific region/class."""
|
||||
bbox: BoundingBox
|
||||
confidence: float
|
||||
detection: DetectionResult
|
||||
track_id: Optional[int] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"bbox": self.bbox.to_list(),
|
||||
"confidence": self.confidence,
|
||||
"detection": self.detection.to_dict(),
|
||||
"track_id": self.track_id
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrackValidationResult:
|
||||
"""Represents track stability validation results."""
|
||||
validation_complete: bool
|
||||
stable_tracks: List[int] = field(default_factory=list)
|
||||
current_tracks: List[int] = field(default_factory=list)
|
||||
newly_stable_tracks: List[int] = field(default_factory=list)
|
||||
|
||||
# Additional metadata
|
||||
send_none_detection: bool = False
|
||||
branch_node: bool = False
|
||||
bypass_validation: bool = False
|
||||
awaiting_stable_tracks: bool = False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"validation_complete": self.validation_complete,
|
||||
"stable_tracks": self.stable_tracks.copy(),
|
||||
"current_tracks": self.current_tracks.copy(),
|
||||
"newly_stable_tracks": self.newly_stable_tracks.copy(),
|
||||
"send_none_detection": self.send_none_detection,
|
||||
"branch_node": self.branch_node,
|
||||
"bypass_validation": self.bypass_validation,
|
||||
"awaiting_stable_tracks": self.awaiting_stable_tracks
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectionSession:
|
||||
"""Represents a detection session with multiple results."""
|
||||
detections: List[DetectionResult] = field(default_factory=list)
|
||||
regions: Dict[str, RegionData] = field(default_factory=dict)
|
||||
validation_result: Optional[TrackValidationResult] = None
|
||||
|
||||
# Session metadata
|
||||
camera_id: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
backend_session_id: Optional[str] = None
|
||||
timestamp: Optional[datetime] = None
|
||||
|
||||
# Branch results for pipeline processing
|
||||
branch_results: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def get_primary_detection(self) -> Optional[DetectionResult]:
|
||||
"""Get the highest confidence detection."""
|
||||
if not self.detections:
|
||||
return None
|
||||
return max(self.detections, key=lambda x: x.confidence)
|
||||
|
||||
def get_detection_by_class(self, class_name: str) -> Optional[DetectionResult]:
|
||||
"""Get detection for specific class."""
|
||||
for detection in self.detections:
|
||||
if detection.class_name == class_name:
|
||||
return detection
|
||||
return None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
result = {
|
||||
"detections": [d.to_dict() for d in self.detections],
|
||||
"regions": {k: v.to_dict() for k, v in self.regions.items()},
|
||||
"branch_results": self.branch_results.copy()
|
||||
}
|
||||
|
||||
if self.validation_result:
|
||||
result["validation_result"] = self.validation_result.to_dict()
|
||||
if self.camera_id:
|
||||
result["camera_id"] = self.camera_id
|
||||
if self.session_id:
|
||||
result["session_id"] = self.session_id
|
||||
if self.backend_session_id:
|
||||
result["backend_session_id"] = self.backend_session_id
|
||||
if self.timestamp:
|
||||
result["timestamp"] = self.timestamp.isoformat()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class LightweightDetectionResult:
|
||||
"""Result from lightweight detection for validation."""
|
||||
validation_passed: bool
|
||||
class_name: Optional[str] = None
|
||||
confidence: Optional[float] = None
|
||||
bbox: Optional[BoundingBox] = None
|
||||
bbox_area_ratio: Optional[float] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
result = {"validation_passed": self.validation_passed}
|
||||
|
||||
if self.class_name:
|
||||
result["class"] = self.class_name
|
||||
if self.confidence is not None:
|
||||
result["confidence"] = self.confidence
|
||||
if self.bbox:
|
||||
result["bbox"] = self.bbox.to_list()
|
||||
if self.bbox_area_ratio is not None:
|
||||
result["bbox_area_ratio"] = self.bbox_area_ratio
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ===== HELPER FUNCTIONS =====
|
||||
|
||||
def create_none_detection() -> DetectionResult:
|
||||
"""Create a 'none' detection result."""
|
||||
return DetectionResult(
|
||||
class_name="none",
|
||||
confidence=1.0,
|
||||
bbox=BoundingBox(0, 0, 0, 0),
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
|
||||
def create_regions_dict_from_detections(detections: List[DetectionResult]) -> Dict[str, RegionData]:
|
||||
"""Create regions dictionary from detection list."""
|
||||
regions = {}
|
||||
|
||||
for detection in detections:
|
||||
region_data = RegionData(
|
||||
bbox=detection.bbox,
|
||||
confidence=detection.confidence,
|
||||
detection=detection,
|
||||
track_id=detection.track_id
|
||||
)
|
||||
regions[detection.class_name] = region_data
|
||||
|
||||
return regions
|
||||
|
||||
|
||||
def merge_detection_results(results: List[DetectionResult]) -> List[DetectionResult]:
|
||||
"""Merge multiple detection results, keeping highest confidence per class."""
|
||||
class_detections = {}
|
||||
|
||||
for result in results:
|
||||
if result.class_name not in class_detections:
|
||||
class_detections[result.class_name] = result
|
||||
else:
|
||||
# Keep higher confidence detection
|
||||
if result.confidence > class_detections[result.class_name].confidence:
|
||||
class_detections[result.class_name] = result
|
||||
|
||||
return list(class_detections.values())
|
Loading…
Add table
Add a link
Reference in a new issue