Refactor: Phase 1: High-Level Restructuring
This commit is contained in:
parent
21700dce52
commit
96bedae80a
13 changed files with 787 additions and 0 deletions
17
detector_worker/__init__.py
Normal file
17
detector_worker/__init__.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
"""
|
||||
Detector Worker - Refactored FastAPI Computer Vision Detection System
|
||||
|
||||
This package contains the refactored detector worker system with modular architecture:
|
||||
|
||||
- core: Configuration, constants, and exceptions
|
||||
- models: Model loading and pipeline management
|
||||
- detection: YOLO detection, tracking, and validation
|
||||
- pipeline: Pipeline execution and action processing
|
||||
- streams: Video stream management and frame processing
|
||||
- communication: WebSocket and message handling
|
||||
- storage: Database, Redis, and session management
|
||||
- utils: Utility functions and helpers
|
||||
"""
|
||||
|
||||
__version__ = "2.0.0"
|
||||
__author__ = "Detector Worker Team"
|
9
detector_worker/communication/__init__.py
Normal file
9
detector_worker/communication/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
"""
|
||||
Communication module - WebSocket and message handling
|
||||
|
||||
This module handles:
|
||||
- WebSocket connection management
|
||||
- Message parsing and routing
|
||||
- Response formatting
|
||||
- Real-time communication protocols
|
||||
"""
|
11
detector_worker/core/__init__.py
Normal file
11
detector_worker/core/__init__.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
"""
|
||||
Core module - Configuration, constants, and exceptions
|
||||
|
||||
This module contains the core infrastructure components:
|
||||
- Configuration management
|
||||
- Application constants
|
||||
- Custom exceptions
|
||||
- Base classes and interfaces
|
||||
"""
|
||||
|
||||
# Core exports will be added as modules are implemented
|
115
detector_worker/core/config.py
Normal file
115
detector_worker/core/config.py
Normal file
|
@ -0,0 +1,115 @@
|
|||
"""
|
||||
Configuration management for detector worker.
|
||||
|
||||
This module handles application configuration loading, validation,
|
||||
and provides centralized access to configuration parameters.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectorConfig:
|
||||
"""Configuration class for detector worker parameters."""
|
||||
|
||||
# Frame processing settings
|
||||
poll_interval_ms: int = 100
|
||||
target_fps: int = 10
|
||||
|
||||
# Stream management settings
|
||||
max_streams: int = 5
|
||||
reconnect_interval_sec: int = 5
|
||||
max_retries: int = 3
|
||||
|
||||
# Logging settings
|
||||
log_level: str = "INFO"
|
||||
log_file: str = "detector_worker.log"
|
||||
websocket_log_file: str = "websocket_comm.log"
|
||||
|
||||
@property
|
||||
def poll_interval(self) -> float:
|
||||
"""Calculate poll interval based on target FPS."""
|
||||
return 1000 / self.target_fps if self.target_fps > 0 else self.poll_interval_ms
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""Centralized configuration manager."""
|
||||
|
||||
def __init__(self, config_file: str = "config.json"):
|
||||
self.config_file = config_file
|
||||
self._config: Optional[DetectorConfig] = None
|
||||
|
||||
def load_config(self) -> DetectorConfig:
|
||||
"""Load configuration from file with defaults fallback."""
|
||||
if self._config is not None:
|
||||
return self._config
|
||||
|
||||
config_data = {}
|
||||
|
||||
# Try to load from config file
|
||||
if os.path.exists(self.config_file):
|
||||
try:
|
||||
with open(self.config_file, "r") as f:
|
||||
config_data = json.load(f)
|
||||
logger.info(f"Loaded configuration from {self.config_file}")
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
logger.warning(f"Failed to load config from {self.config_file}: {e}")
|
||||
logger.info("Using default configuration")
|
||||
else:
|
||||
logger.info(f"Config file {self.config_file} not found, using defaults")
|
||||
|
||||
# Create config with defaults + loaded values
|
||||
self._config = DetectorConfig(
|
||||
poll_interval_ms=config_data.get("poll_interval_ms", 100),
|
||||
target_fps=config_data.get("target_fps", 10),
|
||||
max_streams=config_data.get("max_streams", 5),
|
||||
reconnect_interval_sec=config_data.get("reconnect_interval_sec", 5),
|
||||
max_retries=config_data.get("max_retries", 3),
|
||||
log_level=config_data.get("log_level", "INFO"),
|
||||
log_file=config_data.get("log_file", "detector_worker.log"),
|
||||
websocket_log_file=config_data.get("websocket_log_file", "websocket_comm.log")
|
||||
)
|
||||
|
||||
# Log configuration summary
|
||||
self._log_config_summary()
|
||||
|
||||
return self._config
|
||||
|
||||
def _log_config_summary(self):
|
||||
"""Log configuration summary for debugging."""
|
||||
if self._config:
|
||||
logger.info(f"Configuration loaded:")
|
||||
logger.info(f" Target FPS: {self._config.target_fps}")
|
||||
logger.info(f" Poll interval: {self._config.poll_interval}ms")
|
||||
logger.info(f" Max streams: {self._config.max_streams}")
|
||||
logger.info(f" Max retries: {self._config.max_retries}")
|
||||
logger.info(f" Log level: {self._config.log_level}")
|
||||
|
||||
def get_config(self) -> DetectorConfig:
|
||||
"""Get current configuration, loading if necessary."""
|
||||
if self._config is None:
|
||||
return self.load_config()
|
||||
return self._config
|
||||
|
||||
def reload_config(self) -> DetectorConfig:
|
||||
"""Force reload configuration from file."""
|
||||
self._config = None
|
||||
return self.load_config()
|
||||
|
||||
|
||||
# Global config manager instance
|
||||
_config_manager = ConfigManager()
|
||||
|
||||
def get_config() -> DetectorConfig:
|
||||
"""Get the global configuration instance."""
|
||||
return _config_manager.get_config()
|
||||
|
||||
def reload_config() -> DetectorConfig:
|
||||
"""Reload configuration from file."""
|
||||
return _config_manager.reload_config()
|
120
detector_worker/core/constants.py
Normal file
120
detector_worker/core/constants.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
"""
|
||||
Application constants for detector worker.
|
||||
|
||||
This module contains all application-wide constants used throughout
|
||||
the detector worker system.
|
||||
"""
|
||||
|
||||
# ===== TIMING CONSTANTS =====
|
||||
HEARTBEAT_INTERVAL = 2 # seconds between heartbeat messages
|
||||
DEFAULT_POLL_INTERVAL_MS = 100 # default frame polling interval
|
||||
DEFAULT_TARGET_FPS = 10 # default target frames per second
|
||||
|
||||
# ===== SESSION MANAGEMENT =====
|
||||
SESSION_TIMEOUT_SECONDS = 15 # timeout for backend sessionId waiting
|
||||
SESSION_CACHE_TTL_MINUTES = 10 # TTL for cached session data
|
||||
DETECTION_CACHE_CLEANUP_INTERVAL = 300 # seconds between cache cleanup
|
||||
|
||||
# ===== STREAM SETTINGS =====
|
||||
DEFAULT_MAX_STREAMS = 5 # maximum concurrent camera streams
|
||||
DEFAULT_RECONNECT_INTERVAL_SEC = 5 # seconds between reconnection attempts
|
||||
DEFAULT_MAX_RETRIES = 3 # maximum retry attempts (-1 for unlimited)
|
||||
SHARED_STREAM_BUFFER_SIZE = 1 # frames in shared stream buffer
|
||||
|
||||
# ===== DETECTION & TRACKING =====
|
||||
DEFAULT_STABILITY_THRESHOLD = 4 # frames needed for track stability
|
||||
DEFAULT_MIN_CONFIDENCE = 0.5 # minimum detection confidence
|
||||
DEFAULT_MIN_BBOX_AREA_RATIO = 0.0 # minimum bbox area ratio
|
||||
MAX_ABSENCE_FRAMES = 3 # frames without detection before reset
|
||||
|
||||
# ===== PIPELINE PROCESSING =====
|
||||
DEFAULT_THREAD_POOL_SIZE = 4 # max workers for parallel processing
|
||||
CLASSIFICATION_TIMEOUT_SECONDS = 30 # timeout for classification branches
|
||||
PIPELINE_EXECUTION_TIMEOUT = 60 # timeout for full pipeline execution
|
||||
|
||||
# ===== REDIS SETTINGS =====
|
||||
REDIS_CONNECTION_TIMEOUT = 5 # connection timeout in seconds
|
||||
REDIS_SOCKET_TIMEOUT = 5 # socket timeout in seconds
|
||||
REDIS_IMAGE_DEFAULT_QUALITY = 90 # JPEG quality for Redis image storage
|
||||
REDIS_IMAGE_DEFAULT_FORMAT = "jpeg" # default image format
|
||||
|
||||
# ===== DATABASE SETTINGS =====
|
||||
DB_CONNECTION_TIMEOUT = 30 # database connection timeout
|
||||
DB_OPERATION_TIMEOUT = 60 # database operation timeout
|
||||
DB_RETRY_ATTEMPTS = 3 # retry attempts for failed operations
|
||||
|
||||
# ===== LOGGING SETTINGS =====
|
||||
DEFAULT_LOG_LEVEL = "INFO"
|
||||
DEFAULT_LOG_FILE = "detector_worker.log"
|
||||
DEFAULT_WEBSOCKET_LOG_FILE = "websocket_comm.log"
|
||||
LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
||||
|
||||
# ===== HTTP SETTINGS =====
|
||||
HTTP_SNAPSHOT_TIMEOUT = 10 # seconds for HTTP snapshot requests
|
||||
HTTP_REQUEST_RETRIES = 3 # retry attempts for HTTP requests
|
||||
|
||||
# ===== MODEL SETTINGS =====
|
||||
MODEL_LOAD_TIMEOUT = 120 # seconds to load model
|
||||
MODEL_INFERENCE_TIMEOUT = 30 # seconds for single inference
|
||||
MPTA_EXTRACTION_TIMEOUT = 60 # seconds to extract MPTA files
|
||||
|
||||
# ===== VALIDATION THRESHOLDS =====
|
||||
LIGHTWEIGHT_DETECTION_MIN_CONFIDENCE = 0.7
|
||||
LIGHTWEIGHT_DETECTION_MIN_BBOX_AREA_RATIO = 0.3
|
||||
|
||||
# ===== MESSAGE TYPES =====
|
||||
class MessageTypes:
|
||||
"""WebSocket message types."""
|
||||
SUBSCRIBE = "subscribe"
|
||||
UNSUBSCRIBE = "unsubscribe"
|
||||
REQUEST_STATE = "requestState"
|
||||
SET_SESSION_ID = "setSessionId"
|
||||
PATCH_SESSION = "patchSession"
|
||||
SET_PROGRESSION_STAGE = "setProgressionStage"
|
||||
STATE_REPORT = "stateReport"
|
||||
IMAGE_DETECTION = "imageDetection"
|
||||
|
||||
# ===== PROGRESSION STAGES =====
|
||||
class ProgressionStages:
|
||||
"""Pipeline progression stages."""
|
||||
WELCOME = "welcome"
|
||||
CAR_FUELING = "car_fueling"
|
||||
CAR_WAITPAYMENT = "car_waitpayment"
|
||||
CAR_WAIT_STAFF = "car_wait_staff"
|
||||
|
||||
# ===== DETECTION MODES =====
|
||||
class DetectionModes:
|
||||
"""Detection pipeline modes."""
|
||||
VALIDATION_DETECTING = "validation_detecting"
|
||||
LIGHTWEIGHT = "lightweight"
|
||||
FULL_PIPELINE = "full_pipeline"
|
||||
|
||||
# ===== MODEL TASKS =====
|
||||
class ModelTasks:
|
||||
"""YOLO model task types."""
|
||||
DETECT = "detect"
|
||||
CLASSIFY = "classify"
|
||||
SEGMENT = "segment"
|
||||
|
||||
# ===== IMAGE FORMATS =====
|
||||
SUPPORTED_IMAGE_FORMATS = ["jpeg", "jpg", "png", "webp"]
|
||||
DEFAULT_IMAGE_ENCODING_PARAMS = {
|
||||
"jpeg": {"quality": REDIS_IMAGE_DEFAULT_QUALITY},
|
||||
"png": {"compression": 9},
|
||||
"webp": {"quality": REDIS_IMAGE_DEFAULT_QUALITY}
|
||||
}
|
||||
|
||||
# ===== DIRECTORY PATHS =====
|
||||
MODELS_DIR = "models"
|
||||
TEMP_DEBUG_DIR = "temp_debug"
|
||||
LOG_DIR = "logs"
|
||||
|
||||
# ===== ERROR MESSAGES =====
|
||||
class ErrorMessages:
|
||||
"""Standard error messages."""
|
||||
MODEL_LOAD_FAILED = "Failed to load model"
|
||||
STREAM_CONNECTION_FAILED = "Failed to connect to camera stream"
|
||||
DATABASE_CONNECTION_FAILED = "Failed to connect to database"
|
||||
REDIS_CONNECTION_FAILED = "Failed to connect to Redis"
|
||||
PIPELINE_EXECUTION_FAILED = "Pipeline execution failed"
|
||||
INVALID_CONFIGURATION = "Invalid configuration provided"
|
190
detector_worker/core/exceptions.py
Normal file
190
detector_worker/core/exceptions.py
Normal file
|
@ -0,0 +1,190 @@
|
|||
"""
|
||||
Custom exceptions for detector worker.
|
||||
|
||||
This module defines all custom exceptions used throughout the detector
|
||||
worker system to provide better error handling and debugging.
|
||||
"""
|
||||
|
||||
from typing import Optional, Any
|
||||
|
||||
|
||||
class DetectorWorkerError(Exception):
|
||||
"""Base exception for all detector worker errors."""
|
||||
|
||||
def __init__(self, message: str, details: Optional[dict] = None):
|
||||
self.message = message
|
||||
self.details = details or {}
|
||||
super().__init__(self.message)
|
||||
|
||||
def __str__(self):
|
||||
if self.details:
|
||||
return f"{self.message} (Details: {self.details})"
|
||||
return self.message
|
||||
|
||||
|
||||
class ConfigurationError(DetectorWorkerError):
|
||||
"""Raised when configuration is invalid or missing."""
|
||||
pass
|
||||
|
||||
|
||||
class ModelLoadError(DetectorWorkerError):
|
||||
"""Raised when model loading fails."""
|
||||
pass
|
||||
|
||||
|
||||
class PipelineError(DetectorWorkerError):
|
||||
"""Raised when pipeline execution fails."""
|
||||
pass
|
||||
|
||||
|
||||
class StreamError(DetectorWorkerError):
|
||||
"""Raised when stream operations fail."""
|
||||
pass
|
||||
|
||||
|
||||
class CameraConnectionError(StreamError):
|
||||
"""Raised when camera connection fails."""
|
||||
pass
|
||||
|
||||
|
||||
class FrameReadError(StreamError):
|
||||
"""Raised when frame reading fails."""
|
||||
pass
|
||||
|
||||
|
||||
class DetectionError(DetectorWorkerError):
|
||||
"""Raised when detection operations fail."""
|
||||
pass
|
||||
|
||||
|
||||
class TrackingError(DetectorWorkerError):
|
||||
"""Raised when tracking operations fail."""
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(DetectorWorkerError):
|
||||
"""Raised when validation fails."""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseError(DetectorWorkerError):
|
||||
"""Raised when database operations fail."""
|
||||
pass
|
||||
|
||||
|
||||
class RedisError(DetectorWorkerError):
|
||||
"""Raised when Redis operations fail."""
|
||||
pass
|
||||
|
||||
|
||||
class SessionError(DetectorWorkerError):
|
||||
"""Raised when session management fails."""
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketError(DetectorWorkerError):
|
||||
"""Raised when WebSocket operations fail."""
|
||||
pass
|
||||
|
||||
|
||||
class MessageProcessingError(DetectorWorkerError):
|
||||
"""Raised when message processing fails."""
|
||||
pass
|
||||
|
||||
|
||||
class ActionExecutionError(DetectorWorkerError):
|
||||
"""Raised when action execution fails."""
|
||||
pass
|
||||
|
||||
|
||||
class FieldMappingError(DetectorWorkerError):
|
||||
"""Raised when field mapping fails."""
|
||||
pass
|
||||
|
||||
|
||||
class ImageProcessingError(DetectorWorkerError):
|
||||
"""Raised when image processing fails."""
|
||||
pass
|
||||
|
||||
|
||||
class TimeoutError(DetectorWorkerError):
|
||||
"""Raised when operations timeout."""
|
||||
pass
|
||||
|
||||
|
||||
class ResourceExhaustionError(DetectorWorkerError):
|
||||
"""Raised when system resources are exhausted."""
|
||||
pass
|
||||
|
||||
|
||||
class InvalidStateError(DetectorWorkerError):
|
||||
"""Raised when system is in invalid state for operation."""
|
||||
pass
|
||||
|
||||
|
||||
# ===== ERROR CONTEXT HELPERS =====
|
||||
|
||||
def add_error_context(exception: Exception, **context) -> DetectorWorkerError:
|
||||
"""Add context information to an exception."""
|
||||
if isinstance(exception, DetectorWorkerError):
|
||||
exception.details.update(context)
|
||||
return exception
|
||||
else:
|
||||
return DetectorWorkerError(
|
||||
message=str(exception),
|
||||
details=context
|
||||
)
|
||||
|
||||
|
||||
def create_model_error(model_id: str, operation: str, original_error: Exception) -> ModelLoadError:
|
||||
"""Create a model-specific error with context."""
|
||||
return ModelLoadError(
|
||||
message=f"Model {operation} failed for {model_id}: {str(original_error)}",
|
||||
details={
|
||||
"model_id": model_id,
|
||||
"operation": operation,
|
||||
"original_error": type(original_error).__name__,
|
||||
"original_message": str(original_error)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def create_stream_error(camera_id: str, stream_url: str, operation: str, original_error: Exception) -> StreamError:
|
||||
"""Create a stream-specific error with context."""
|
||||
return StreamError(
|
||||
message=f"Stream {operation} failed for camera {camera_id}: {str(original_error)}",
|
||||
details={
|
||||
"camera_id": camera_id,
|
||||
"stream_url": stream_url,
|
||||
"operation": operation,
|
||||
"original_error": type(original_error).__name__,
|
||||
"original_message": str(original_error)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def create_detection_error(camera_id: str, model_id: str, operation: str, original_error: Exception) -> DetectionError:
|
||||
"""Create a detection-specific error with context."""
|
||||
return DetectionError(
|
||||
message=f"Detection {operation} failed for camera {camera_id}, model {model_id}: {str(original_error)}",
|
||||
details={
|
||||
"camera_id": camera_id,
|
||||
"model_id": model_id,
|
||||
"operation": operation,
|
||||
"original_error": type(original_error).__name__,
|
||||
"original_message": str(original_error)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def create_pipeline_error(pipeline_id: str, stage: str, original_error: Exception) -> PipelineError:
|
||||
"""Create a pipeline-specific error with context."""
|
||||
return PipelineError(
|
||||
message=f"Pipeline execution failed at {stage}: {str(original_error)}",
|
||||
details={
|
||||
"pipeline_id": pipeline_id,
|
||||
"stage": stage,
|
||||
"original_error": type(original_error).__name__,
|
||||
"original_message": str(original_error)
|
||||
}
|
||||
)
|
9
detector_worker/detection/__init__.py
Normal file
9
detector_worker/detection/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
"""
|
||||
Detection module - YOLO detection, tracking, and validation
|
||||
|
||||
This module handles:
|
||||
- YOLO model inference and object detection
|
||||
- Object tracking with BoT-SORT integration
|
||||
- Track stability validation
|
||||
- Detection result data structures
|
||||
"""
|
271
detector_worker/detection/detection_result.py
Normal file
271
detector_worker/detection/detection_result.py
Normal file
|
@ -0,0 +1,271 @@
|
|||
"""
|
||||
Detection result data structures and models.
|
||||
|
||||
This module defines the data structures used to represent detection
|
||||
results throughout the system.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class BoundingBox:
|
||||
"""Represents a bounding box for detected objects."""
|
||||
x1: int
|
||||
y1: int
|
||||
x2: int
|
||||
y2: int
|
||||
|
||||
@property
|
||||
def width(self) -> int:
|
||||
return self.x2 - self.x1
|
||||
|
||||
@property
|
||||
def height(self) -> int:
|
||||
return self.y2 - self.y1
|
||||
|
||||
@property
|
||||
def area(self) -> int:
|
||||
return self.width * self.height
|
||||
|
||||
@property
|
||||
def center(self) -> Tuple[int, int]:
|
||||
return (self.x1 + self.width // 2, self.y1 + self.height // 2)
|
||||
|
||||
def to_list(self) -> List[int]:
|
||||
"""Convert to [x1, y1, x2, y2] format."""
|
||||
return [self.x1, self.y1, self.x2, self.y2]
|
||||
|
||||
def to_xyxy(self) -> Tuple[int, int, int, int]:
|
||||
"""Convert to (x1, y1, x2, y2) tuple."""
|
||||
return (self.x1, self.y1, self.x2, self.y2)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectionResult:
|
||||
"""Represents a single detection result."""
|
||||
class_name: str
|
||||
confidence: float
|
||||
bbox: BoundingBox
|
||||
track_id: Optional[int] = None
|
||||
class_id: Optional[int] = None
|
||||
original_class: Optional[str] = None # For class mapping
|
||||
|
||||
# Additional detection metadata
|
||||
model_id: Optional[str] = None
|
||||
timestamp: Optional[datetime] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
result = {
|
||||
"class": self.class_name,
|
||||
"confidence": self.confidence,
|
||||
"bbox": self.bbox.to_list()
|
||||
}
|
||||
|
||||
if self.track_id is not None:
|
||||
result["id"] = self.track_id
|
||||
if self.class_id is not None:
|
||||
result["class_id"] = self.class_id
|
||||
if self.original_class is not None:
|
||||
result["original_class"] = self.original_class
|
||||
if self.model_id is not None:
|
||||
result["model_id"] = self.model_id
|
||||
if self.timestamp is not None:
|
||||
result["timestamp"] = self.timestamp.isoformat()
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'DetectionResult':
|
||||
"""Create DetectionResult from dictionary."""
|
||||
bbox_data = data["bbox"]
|
||||
if isinstance(bbox_data, list) and len(bbox_data) == 4:
|
||||
bbox = BoundingBox(bbox_data[0], bbox_data[1], bbox_data[2], bbox_data[3])
|
||||
else:
|
||||
raise ValueError(f"Invalid bbox format: {bbox_data}")
|
||||
|
||||
timestamp = None
|
||||
if "timestamp" in data:
|
||||
if isinstance(data["timestamp"], str):
|
||||
timestamp = datetime.fromisoformat(data["timestamp"])
|
||||
elif isinstance(data["timestamp"], datetime):
|
||||
timestamp = data["timestamp"]
|
||||
|
||||
return cls(
|
||||
class_name=data["class"],
|
||||
confidence=data["confidence"],
|
||||
bbox=bbox,
|
||||
track_id=data.get("id"),
|
||||
class_id=data.get("class_id"),
|
||||
original_class=data.get("original_class"),
|
||||
model_id=data.get("model_id"),
|
||||
timestamp=timestamp
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegionData:
|
||||
"""Represents detection data for a specific region/class."""
|
||||
bbox: BoundingBox
|
||||
confidence: float
|
||||
detection: DetectionResult
|
||||
track_id: Optional[int] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"bbox": self.bbox.to_list(),
|
||||
"confidence": self.confidence,
|
||||
"detection": self.detection.to_dict(),
|
||||
"track_id": self.track_id
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrackValidationResult:
|
||||
"""Represents track stability validation results."""
|
||||
validation_complete: bool
|
||||
stable_tracks: List[int] = field(default_factory=list)
|
||||
current_tracks: List[int] = field(default_factory=list)
|
||||
newly_stable_tracks: List[int] = field(default_factory=list)
|
||||
|
||||
# Additional metadata
|
||||
send_none_detection: bool = False
|
||||
branch_node: bool = False
|
||||
bypass_validation: bool = False
|
||||
awaiting_stable_tracks: bool = False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"validation_complete": self.validation_complete,
|
||||
"stable_tracks": self.stable_tracks.copy(),
|
||||
"current_tracks": self.current_tracks.copy(),
|
||||
"newly_stable_tracks": self.newly_stable_tracks.copy(),
|
||||
"send_none_detection": self.send_none_detection,
|
||||
"branch_node": self.branch_node,
|
||||
"bypass_validation": self.bypass_validation,
|
||||
"awaiting_stable_tracks": self.awaiting_stable_tracks
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectionSession:
|
||||
"""Represents a detection session with multiple results."""
|
||||
detections: List[DetectionResult] = field(default_factory=list)
|
||||
regions: Dict[str, RegionData] = field(default_factory=dict)
|
||||
validation_result: Optional[TrackValidationResult] = None
|
||||
|
||||
# Session metadata
|
||||
camera_id: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
backend_session_id: Optional[str] = None
|
||||
timestamp: Optional[datetime] = None
|
||||
|
||||
# Branch results for pipeline processing
|
||||
branch_results: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def get_primary_detection(self) -> Optional[DetectionResult]:
|
||||
"""Get the highest confidence detection."""
|
||||
if not self.detections:
|
||||
return None
|
||||
return max(self.detections, key=lambda x: x.confidence)
|
||||
|
||||
def get_detection_by_class(self, class_name: str) -> Optional[DetectionResult]:
|
||||
"""Get detection for specific class."""
|
||||
for detection in self.detections:
|
||||
if detection.class_name == class_name:
|
||||
return detection
|
||||
return None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
result = {
|
||||
"detections": [d.to_dict() for d in self.detections],
|
||||
"regions": {k: v.to_dict() for k, v in self.regions.items()},
|
||||
"branch_results": self.branch_results.copy()
|
||||
}
|
||||
|
||||
if self.validation_result:
|
||||
result["validation_result"] = self.validation_result.to_dict()
|
||||
if self.camera_id:
|
||||
result["camera_id"] = self.camera_id
|
||||
if self.session_id:
|
||||
result["session_id"] = self.session_id
|
||||
if self.backend_session_id:
|
||||
result["backend_session_id"] = self.backend_session_id
|
||||
if self.timestamp:
|
||||
result["timestamp"] = self.timestamp.isoformat()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class LightweightDetectionResult:
|
||||
"""Result from lightweight detection for validation."""
|
||||
validation_passed: bool
|
||||
class_name: Optional[str] = None
|
||||
confidence: Optional[float] = None
|
||||
bbox: Optional[BoundingBox] = None
|
||||
bbox_area_ratio: Optional[float] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
result = {"validation_passed": self.validation_passed}
|
||||
|
||||
if self.class_name:
|
||||
result["class"] = self.class_name
|
||||
if self.confidence is not None:
|
||||
result["confidence"] = self.confidence
|
||||
if self.bbox:
|
||||
result["bbox"] = self.bbox.to_list()
|
||||
if self.bbox_area_ratio is not None:
|
||||
result["bbox_area_ratio"] = self.bbox_area_ratio
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ===== HELPER FUNCTIONS =====
|
||||
|
||||
def create_none_detection() -> DetectionResult:
|
||||
"""Create a 'none' detection result."""
|
||||
return DetectionResult(
|
||||
class_name="none",
|
||||
confidence=1.0,
|
||||
bbox=BoundingBox(0, 0, 0, 0),
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
|
||||
def create_regions_dict_from_detections(detections: List[DetectionResult]) -> Dict[str, RegionData]:
|
||||
"""Create regions dictionary from detection list."""
|
||||
regions = {}
|
||||
|
||||
for detection in detections:
|
||||
region_data = RegionData(
|
||||
bbox=detection.bbox,
|
||||
confidence=detection.confidence,
|
||||
detection=detection,
|
||||
track_id=detection.track_id
|
||||
)
|
||||
regions[detection.class_name] = region_data
|
||||
|
||||
return regions
|
||||
|
||||
|
||||
def merge_detection_results(results: List[DetectionResult]) -> List[DetectionResult]:
|
||||
"""Merge multiple detection results, keeping highest confidence per class."""
|
||||
class_detections = {}
|
||||
|
||||
for result in results:
|
||||
if result.class_name not in class_detections:
|
||||
class_detections[result.class_name] = result
|
||||
else:
|
||||
# Keep higher confidence detection
|
||||
if result.confidence > class_detections[result.class_name].confidence:
|
||||
class_detections[result.class_name] = result
|
||||
|
||||
return list(class_detections.values())
|
9
detector_worker/models/__init__.py
Normal file
9
detector_worker/models/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
"""
|
||||
Models module - Model loading and pipeline management
|
||||
|
||||
This module handles:
|
||||
- MPTA file loading and parsing
|
||||
- Model lifecycle management
|
||||
- Pipeline node creation
|
||||
- Model registry integration
|
||||
"""
|
9
detector_worker/pipeline/__init__.py
Normal file
9
detector_worker/pipeline/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
"""
|
||||
Pipeline module - Pipeline execution and action processing
|
||||
|
||||
This module handles:
|
||||
- Main pipeline execution coordination
|
||||
- Branch processing (parallel classification)
|
||||
- Action execution (Redis, Database)
|
||||
- Field mapping and template resolution
|
||||
"""
|
9
detector_worker/storage/__init__.py
Normal file
9
detector_worker/storage/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
"""
|
||||
Storage module - Database, Redis, and session management
|
||||
|
||||
This module handles:
|
||||
- PostgreSQL database operations
|
||||
- Redis client and operations
|
||||
- Session caching and management
|
||||
- Data persistence layer
|
||||
"""
|
9
detector_worker/streams/__init__.py
Normal file
9
detector_worker/streams/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
"""
|
||||
Streams module - Video stream management and frame processing
|
||||
|
||||
This module handles:
|
||||
- RTSP and HTTP video stream management
|
||||
- Frame reading and buffering
|
||||
- Camera connection monitoring
|
||||
- Shared stream optimization
|
||||
"""
|
9
detector_worker/utils/__init__.py
Normal file
9
detector_worker/utils/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
"""
|
||||
Utils module - Utility functions and helpers
|
||||
|
||||
This module contains:
|
||||
- Image processing utilities (cropping, encoding)
|
||||
- Input validation functions
|
||||
- Logging configuration and utilities
|
||||
- Common helper functions
|
||||
"""
|
Loading…
Add table
Add a link
Reference in a new issue