Refactor: done phase 4
This commit is contained in:
parent
7e8034c6e5
commit
9e4c23c75c
8 changed files with 1533 additions and 37 deletions
|
@ -1 +1,14 @@
|
|||
# Tracking module for vehicle tracking and validation
|
||||
# Tracking module for vehicle tracking and validation
|
||||
|
||||
from .tracker import VehicleTracker, TrackedVehicle
|
||||
from .validator import StableCarValidator, ValidationResult, VehicleState
|
||||
from .integration import TrackingPipelineIntegration
|
||||
|
||||
__all__ = [
|
||||
'VehicleTracker',
|
||||
'TrackedVehicle',
|
||||
'StableCarValidator',
|
||||
'ValidationResult',
|
||||
'VehicleState',
|
||||
'TrackingPipelineIntegration'
|
||||
]
|
369
core/tracking/integration.py
Normal file
369
core/tracking/integration.py
Normal file
|
@ -0,0 +1,369 @@
|
|||
"""
|
||||
Tracking-Pipeline Integration Module.
|
||||
Connects the tracking system with the main detection pipeline and manages the flow.
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Optional, Any, List, Tuple
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import numpy as np
|
||||
|
||||
from .tracker import VehicleTracker, TrackedVehicle
|
||||
from .validator import StableCarValidator, ValidationResult, VehicleState
|
||||
from ..models.inference import YOLOWrapper
|
||||
from ..models.pipeline import PipelineParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrackingPipelineIntegration:
|
||||
"""
|
||||
Integrates vehicle tracking with the detection pipeline.
|
||||
Manages tracking state transitions and pipeline execution triggers.
|
||||
"""
|
||||
|
||||
def __init__(self, pipeline_parser: PipelineParser, model_manager: Any):
|
||||
"""
|
||||
Initialize tracking-pipeline integration.
|
||||
|
||||
Args:
|
||||
pipeline_parser: Pipeline parser with loaded configuration
|
||||
model_manager: Model manager for loading models
|
||||
"""
|
||||
self.pipeline_parser = pipeline_parser
|
||||
self.model_manager = model_manager
|
||||
|
||||
# Initialize tracking components
|
||||
tracking_config = pipeline_parser.tracking_config.__dict__ if pipeline_parser.tracking_config else {}
|
||||
self.tracker = VehicleTracker(tracking_config)
|
||||
self.validator = StableCarValidator()
|
||||
|
||||
# Tracking model
|
||||
self.tracking_model: Optional[YOLOWrapper] = None
|
||||
self.tracking_model_id = None
|
||||
|
||||
# Session management
|
||||
self.active_sessions: Dict[str, str] = {} # display_id -> session_id
|
||||
self.session_vehicles: Dict[str, int] = {} # session_id -> track_id
|
||||
self.cleared_sessions: Dict[str, float] = {} # session_id -> clear_time
|
||||
|
||||
# Thread pool for pipeline execution
|
||||
self.executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
# Statistics
|
||||
self.stats = {
|
||||
'frames_processed': 0,
|
||||
'vehicles_detected': 0,
|
||||
'vehicles_validated': 0,
|
||||
'pipelines_executed': 0
|
||||
}
|
||||
|
||||
logger.info("TrackingPipelineIntegration initialized")
|
||||
|
||||
async def initialize_tracking_model(self) -> bool:
|
||||
"""
|
||||
Load and initialize the tracking model.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
if not self.pipeline_parser.tracking_config:
|
||||
logger.warning("No tracking configuration found in pipeline")
|
||||
return False
|
||||
|
||||
model_file = self.pipeline_parser.tracking_config.model_file
|
||||
model_id = self.pipeline_parser.tracking_config.model_id
|
||||
|
||||
if not model_file:
|
||||
logger.warning("No tracking model file specified")
|
||||
return False
|
||||
|
||||
# Load tracking model
|
||||
logger.info(f"Loading tracking model: {model_id} ({model_file})")
|
||||
# Get the model ID from the ModelManager context
|
||||
# We need the actual model ID, not the model string identifier
|
||||
# For now, let's extract it from the model manager
|
||||
pipeline_models = list(self.model_manager.get_all_downloaded_models())
|
||||
if pipeline_models:
|
||||
actual_model_id = pipeline_models[0] # Use the first available model
|
||||
self.tracking_model = self.model_manager.get_yolo_model(actual_model_id, model_file)
|
||||
else:
|
||||
logger.error("No models available in ModelManager")
|
||||
return False
|
||||
self.tracking_model_id = model_id
|
||||
|
||||
if self.tracking_model:
|
||||
logger.info(f"Tracking model {model_id} loaded successfully")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to load tracking model {model_id}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing tracking model: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def process_frame(self,
|
||||
frame: np.ndarray,
|
||||
display_id: str,
|
||||
subscription_id: str,
|
||||
session_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a frame through tracking and potentially the detection pipeline.
|
||||
|
||||
Args:
|
||||
frame: Input frame to process
|
||||
display_id: Display identifier
|
||||
subscription_id: Full subscription identifier
|
||||
session_id: Optional session ID from backend
|
||||
|
||||
Returns:
|
||||
Dictionary with processing results
|
||||
"""
|
||||
start_time = time.time()
|
||||
result = {
|
||||
'tracked_vehicles': [],
|
||||
'validated_vehicle': None,
|
||||
'pipeline_result': None,
|
||||
'session_id': session_id,
|
||||
'processing_time': 0.0
|
||||
}
|
||||
|
||||
try:
|
||||
# Update stats
|
||||
self.stats['frames_processed'] += 1
|
||||
|
||||
# Run tracking model
|
||||
if self.tracking_model:
|
||||
# Run inference with tracking
|
||||
tracking_results = self.tracking_model.track(
|
||||
frame,
|
||||
confidence_threshold=self.tracker.min_confidence,
|
||||
trigger_classes=self.tracker.trigger_classes,
|
||||
persist=True
|
||||
)
|
||||
|
||||
# Process tracking results
|
||||
tracked_vehicles = self.tracker.process_detections(
|
||||
tracking_results,
|
||||
display_id,
|
||||
frame
|
||||
)
|
||||
|
||||
result['tracked_vehicles'] = [
|
||||
{
|
||||
'track_id': v.track_id,
|
||||
'bbox': v.bbox,
|
||||
'confidence': v.confidence,
|
||||
'is_stable': v.is_stable,
|
||||
'session_id': v.session_id
|
||||
}
|
||||
for v in tracked_vehicles
|
||||
]
|
||||
|
||||
# Log tracking info periodically
|
||||
if self.stats['frames_processed'] % 30 == 0: # Every 30 frames
|
||||
logger.debug(f"Tracking: {len(tracked_vehicles)} vehicles, "
|
||||
f"display={display_id}")
|
||||
|
||||
# Get stable vehicles for validation
|
||||
stable_vehicles = self.tracker.get_stable_vehicles(display_id)
|
||||
|
||||
# Validate and potentially process stable vehicles
|
||||
for vehicle in stable_vehicles:
|
||||
# Check if vehicle is already processed or has session
|
||||
if vehicle.processed_pipeline:
|
||||
continue
|
||||
|
||||
# Check for session cleared (post-fueling)
|
||||
if session_id and vehicle.session_id == session_id:
|
||||
# Same vehicle with same session, skip
|
||||
continue
|
||||
|
||||
# Check if this was a recently cleared session
|
||||
session_cleared = False
|
||||
if vehicle.session_id in self.cleared_sessions:
|
||||
clear_time = self.cleared_sessions[vehicle.session_id]
|
||||
if (time.time() - clear_time) < 30: # 30 second cooldown
|
||||
session_cleared = True
|
||||
|
||||
# Skip same car after session clear
|
||||
if self.validator.should_skip_same_car(vehicle, session_cleared):
|
||||
continue
|
||||
|
||||
# Validate vehicle
|
||||
validation_result = self.validator.validate_vehicle(vehicle, frame.shape)
|
||||
|
||||
if validation_result.is_valid and validation_result.should_process:
|
||||
logger.info(f"Vehicle {vehicle.track_id} validated for processing: "
|
||||
f"{validation_result.reason}")
|
||||
|
||||
result['validated_vehicle'] = {
|
||||
'track_id': vehicle.track_id,
|
||||
'state': validation_result.state.value,
|
||||
'confidence': validation_result.confidence
|
||||
}
|
||||
|
||||
# Generate session ID if not provided
|
||||
if not session_id:
|
||||
session_id = str(uuid.uuid4())
|
||||
logger.info(f"Generated session ID: {session_id}")
|
||||
|
||||
# Mark vehicle as processed
|
||||
self.tracker.mark_processed(vehicle.track_id, session_id)
|
||||
self.session_vehicles[session_id] = vehicle.track_id
|
||||
self.active_sessions[display_id] = session_id
|
||||
|
||||
# Execute detection pipeline (placeholder for Phase 5)
|
||||
pipeline_result = await self._execute_pipeline(
|
||||
frame,
|
||||
vehicle,
|
||||
display_id,
|
||||
session_id,
|
||||
subscription_id
|
||||
)
|
||||
|
||||
result['pipeline_result'] = pipeline_result
|
||||
result['session_id'] = session_id
|
||||
self.stats['pipelines_executed'] += 1
|
||||
|
||||
# Only process one vehicle per frame
|
||||
break
|
||||
|
||||
self.stats['vehicles_detected'] = len(tracked_vehicles)
|
||||
self.stats['vehicles_validated'] = len(stable_vehicles)
|
||||
|
||||
else:
|
||||
logger.warning("No tracking model available")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in tracking pipeline: {e}", exc_info=True)
|
||||
|
||||
result['processing_time'] = time.time() - start_time
|
||||
return result
|
||||
|
||||
async def _execute_pipeline(self,
|
||||
frame: np.ndarray,
|
||||
vehicle: TrackedVehicle,
|
||||
display_id: str,
|
||||
session_id: str,
|
||||
subscription_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute the main detection pipeline for a validated vehicle.
|
||||
This is a placeholder for Phase 5 implementation.
|
||||
|
||||
Args:
|
||||
frame: Input frame
|
||||
vehicle: Validated tracked vehicle
|
||||
display_id: Display identifier
|
||||
session_id: Session identifier
|
||||
subscription_id: Full subscription identifier
|
||||
|
||||
Returns:
|
||||
Pipeline execution results
|
||||
"""
|
||||
logger.info(f"Executing pipeline for vehicle {vehicle.track_id}, "
|
||||
f"session={session_id}, display={display_id}")
|
||||
|
||||
# Placeholder for Phase 5 pipeline execution
|
||||
# This will be implemented when we create the detection module
|
||||
pipeline_result = {
|
||||
'status': 'pending',
|
||||
'message': 'Pipeline execution will be implemented in Phase 5',
|
||||
'vehicle_id': vehicle.track_id,
|
||||
'session_id': session_id,
|
||||
'bbox': vehicle.bbox,
|
||||
'confidence': vehicle.confidence
|
||||
}
|
||||
|
||||
# Simulate pipeline execution
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return pipeline_result
|
||||
|
||||
def set_session_id(self, display_id: str, session_id: str):
|
||||
"""
|
||||
Set session ID for a display (from backend).
|
||||
|
||||
Args:
|
||||
display_id: Display identifier
|
||||
session_id: Session identifier
|
||||
"""
|
||||
self.active_sessions[display_id] = session_id
|
||||
logger.info(f"Set session {session_id} for display {display_id}")
|
||||
|
||||
# Find vehicle with this session
|
||||
vehicle = self.tracker.get_vehicle_by_session(session_id)
|
||||
if vehicle:
|
||||
self.session_vehicles[session_id] = vehicle.track_id
|
||||
|
||||
def clear_session_id(self, session_id: str):
|
||||
"""
|
||||
Clear session ID (post-fueling).
|
||||
|
||||
Args:
|
||||
session_id: Session identifier to clear
|
||||
"""
|
||||
# Mark session as cleared
|
||||
self.cleared_sessions[session_id] = time.time()
|
||||
|
||||
# Clear from tracker
|
||||
self.tracker.clear_session(session_id)
|
||||
|
||||
# Remove from active sessions
|
||||
display_to_remove = None
|
||||
for display_id, sess_id in self.active_sessions.items():
|
||||
if sess_id == session_id:
|
||||
display_to_remove = display_id
|
||||
break
|
||||
|
||||
if display_to_remove:
|
||||
del self.active_sessions[display_to_remove]
|
||||
|
||||
if session_id in self.session_vehicles:
|
||||
del self.session_vehicles[session_id]
|
||||
|
||||
logger.info(f"Cleared session {session_id}")
|
||||
|
||||
# Clean old cleared sessions (older than 5 minutes)
|
||||
current_time = time.time()
|
||||
old_sessions = [
|
||||
sid for sid, clear_time in self.cleared_sessions.items()
|
||||
if (current_time - clear_time) > 300
|
||||
]
|
||||
for sid in old_sessions:
|
||||
del self.cleared_sessions[sid]
|
||||
|
||||
def get_session_for_display(self, display_id: str) -> Optional[str]:
|
||||
"""Get active session for a display."""
|
||||
return self.active_sessions.get(display_id)
|
||||
|
||||
def reset_tracking(self):
|
||||
"""Reset all tracking state."""
|
||||
self.tracker.reset_tracking()
|
||||
self.active_sessions.clear()
|
||||
self.session_vehicles.clear()
|
||||
self.cleared_sessions.clear()
|
||||
logger.info("Tracking pipeline integration reset")
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive statistics."""
|
||||
tracker_stats = self.tracker.get_statistics()
|
||||
validator_stats = self.validator.get_statistics()
|
||||
|
||||
return {
|
||||
'integration': self.stats,
|
||||
'tracker': tracker_stats,
|
||||
'validator': validator_stats,
|
||||
'active_sessions': len(self.active_sessions),
|
||||
'cleared_sessions': len(self.cleared_sessions)
|
||||
}
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources."""
|
||||
self.executor.shutdown(wait=False)
|
||||
self.reset_tracking()
|
||||
logger.info("Tracking pipeline integration cleaned up")
|
352
core/tracking/tracker.py
Normal file
352
core/tracking/tracker.py
Normal file
|
@ -0,0 +1,352 @@
|
|||
"""
|
||||
Vehicle Tracking Module - Continuous tracking with front_rear_detection model
|
||||
Implements vehicle identification, persistence, and motion analysis.
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass, field
|
||||
import numpy as np
|
||||
from threading import Lock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrackedVehicle:
|
||||
"""Represents a tracked vehicle with all its state information."""
|
||||
track_id: int
|
||||
first_seen: float
|
||||
last_seen: float
|
||||
session_id: Optional[str] = None
|
||||
display_id: Optional[str] = None
|
||||
confidence: float = 0.0
|
||||
bbox: Tuple[int, int, int, int] = (0, 0, 0, 0) # x1, y1, x2, y2
|
||||
center: Tuple[float, float] = (0.0, 0.0)
|
||||
stable_frames: int = 0
|
||||
total_frames: int = 0
|
||||
is_stable: bool = False
|
||||
processed_pipeline: bool = False
|
||||
last_position_history: List[Tuple[float, float]] = field(default_factory=list)
|
||||
avg_confidence: float = 0.0
|
||||
|
||||
def update_position(self, bbox: Tuple[int, int, int, int], confidence: float):
|
||||
"""Update vehicle position and confidence."""
|
||||
self.bbox = bbox
|
||||
self.center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
|
||||
self.last_seen = time.time()
|
||||
self.confidence = confidence
|
||||
self.total_frames += 1
|
||||
|
||||
# Update confidence average
|
||||
self.avg_confidence = ((self.avg_confidence * (self.total_frames - 1)) + confidence) / self.total_frames
|
||||
|
||||
# Maintain position history (last 10 positions)
|
||||
self.last_position_history.append(self.center)
|
||||
if len(self.last_position_history) > 10:
|
||||
self.last_position_history.pop(0)
|
||||
|
||||
def calculate_stability(self) -> float:
|
||||
"""Calculate stability score based on position history."""
|
||||
if len(self.last_position_history) < 2:
|
||||
return 0.0
|
||||
|
||||
# Calculate movement variance
|
||||
positions = np.array(self.last_position_history)
|
||||
if len(positions) < 2:
|
||||
return 0.0
|
||||
|
||||
# Calculate standard deviation of positions
|
||||
std_x = np.std(positions[:, 0])
|
||||
std_y = np.std(positions[:, 1])
|
||||
|
||||
# Lower variance means more stable (inverse relationship)
|
||||
# Normalize to 0-1 range (assuming max reasonable std is 50 pixels)
|
||||
stability = max(0, 1 - (std_x + std_y) / 100)
|
||||
return stability
|
||||
|
||||
def is_expired(self, timeout_seconds: float = 2.0) -> bool:
|
||||
"""Check if vehicle tracking has expired."""
|
||||
return (time.time() - self.last_seen) > timeout_seconds
|
||||
|
||||
|
||||
class VehicleTracker:
|
||||
"""
|
||||
Main vehicle tracking implementation using YOLO tracking capabilities.
|
||||
Manages continuous tracking, vehicle identification, and state persistence.
|
||||
"""
|
||||
|
||||
def __init__(self, tracking_config: Optional[Dict] = None):
|
||||
"""
|
||||
Initialize the vehicle tracker.
|
||||
|
||||
Args:
|
||||
tracking_config: Configuration from pipeline.json tracking section
|
||||
"""
|
||||
self.config = tracking_config or {}
|
||||
self.trigger_classes = self.config.get('triggerClasses', ['front_rear'])
|
||||
self.min_confidence = self.config.get('minConfidence', 0.6)
|
||||
|
||||
# Tracking state
|
||||
self.tracked_vehicles: Dict[int, TrackedVehicle] = {}
|
||||
self.next_track_id = 1
|
||||
self.lock = Lock()
|
||||
|
||||
# Tracking parameters
|
||||
self.stability_threshold = 0.7
|
||||
self.min_stable_frames = 5
|
||||
self.position_tolerance = 50 # pixels
|
||||
self.timeout_seconds = 2.0
|
||||
|
||||
logger.info(f"VehicleTracker initialized with trigger_classes={self.trigger_classes}, "
|
||||
f"min_confidence={self.min_confidence}")
|
||||
|
||||
def process_detections(self,
|
||||
results: Any,
|
||||
display_id: str,
|
||||
frame: np.ndarray) -> List[TrackedVehicle]:
|
||||
"""
|
||||
Process YOLO detection results and update tracking state.
|
||||
|
||||
Args:
|
||||
results: YOLO detection results with tracking
|
||||
display_id: Display identifier for this stream
|
||||
frame: Current frame being processed
|
||||
|
||||
Returns:
|
||||
List of currently tracked vehicles
|
||||
"""
|
||||
current_time = time.time()
|
||||
active_tracks = []
|
||||
|
||||
with self.lock:
|
||||
# Clean up expired tracks
|
||||
expired_ids = [
|
||||
track_id for track_id, vehicle in self.tracked_vehicles.items()
|
||||
if vehicle.is_expired(self.timeout_seconds)
|
||||
]
|
||||
for track_id in expired_ids:
|
||||
logger.debug(f"Removing expired track {track_id}")
|
||||
del self.tracked_vehicles[track_id]
|
||||
|
||||
# Process new detections
|
||||
if hasattr(results, 'boxes') and results.boxes is not None:
|
||||
boxes = results.boxes
|
||||
|
||||
# Check if tracking is available
|
||||
if hasattr(boxes, 'id') and boxes.id is not None:
|
||||
# Process tracked objects
|
||||
for i, box in enumerate(boxes):
|
||||
# Get tracking ID
|
||||
track_id = int(boxes.id[i].item()) if boxes.id[i] is not None else None
|
||||
if track_id is None:
|
||||
continue
|
||||
|
||||
# Get class and confidence
|
||||
cls_id = int(box.cls.item())
|
||||
confidence = float(box.conf.item())
|
||||
|
||||
# Check if class is in trigger classes
|
||||
class_name = results.names[cls_id] if hasattr(results, 'names') else str(cls_id)
|
||||
if class_name not in self.trigger_classes and confidence < self.min_confidence:
|
||||
continue
|
||||
|
||||
# Get bounding box
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
|
||||
bbox = (x1, y1, x2, y2)
|
||||
|
||||
# Update or create tracked vehicle
|
||||
if track_id in self.tracked_vehicles:
|
||||
# Update existing track
|
||||
vehicle = self.tracked_vehicles[track_id]
|
||||
vehicle.update_position(bbox, confidence)
|
||||
vehicle.display_id = display_id
|
||||
|
||||
# Check stability
|
||||
stability = vehicle.calculate_stability()
|
||||
if stability > self.stability_threshold:
|
||||
vehicle.stable_frames += 1
|
||||
if vehicle.stable_frames >= self.min_stable_frames:
|
||||
vehicle.is_stable = True
|
||||
else:
|
||||
vehicle.stable_frames = max(0, vehicle.stable_frames - 1)
|
||||
if vehicle.stable_frames < self.min_stable_frames:
|
||||
vehicle.is_stable = False
|
||||
|
||||
logger.debug(f"Updated track {track_id}: conf={confidence:.2f}, "
|
||||
f"stable={vehicle.is_stable}, stability={stability:.2f}")
|
||||
else:
|
||||
# Create new track
|
||||
vehicle = TrackedVehicle(
|
||||
track_id=track_id,
|
||||
first_seen=current_time,
|
||||
last_seen=current_time,
|
||||
display_id=display_id,
|
||||
confidence=confidence,
|
||||
bbox=bbox,
|
||||
center=((x1 + x2) / 2, (y1 + y2) / 2),
|
||||
total_frames=1
|
||||
)
|
||||
vehicle.last_position_history.append(vehicle.center)
|
||||
self.tracked_vehicles[track_id] = vehicle
|
||||
logger.info(f"New vehicle tracked: ID={track_id}, display={display_id}")
|
||||
|
||||
active_tracks.append(self.tracked_vehicles[track_id])
|
||||
else:
|
||||
# No tracking available, process as detections only
|
||||
logger.debug("No tracking IDs available, processing as detections only")
|
||||
for i, box in enumerate(boxes):
|
||||
cls_id = int(box.cls.item())
|
||||
confidence = float(box.conf.item())
|
||||
|
||||
# Check confidence threshold
|
||||
if confidence < self.min_confidence:
|
||||
continue
|
||||
|
||||
# Get bounding box
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
|
||||
bbox = (x1, y1, x2, y2)
|
||||
center = ((x1 + x2) / 2, (y1 + y2) / 2)
|
||||
|
||||
# Try to match with existing tracks by position
|
||||
matched_track = self._find_closest_track(center)
|
||||
|
||||
if matched_track:
|
||||
matched_track.update_position(bbox, confidence)
|
||||
matched_track.display_id = display_id
|
||||
active_tracks.append(matched_track)
|
||||
else:
|
||||
# Create new track with generated ID
|
||||
track_id = self.next_track_id
|
||||
self.next_track_id += 1
|
||||
|
||||
vehicle = TrackedVehicle(
|
||||
track_id=track_id,
|
||||
first_seen=current_time,
|
||||
last_seen=current_time,
|
||||
display_id=display_id,
|
||||
confidence=confidence,
|
||||
bbox=bbox,
|
||||
center=center,
|
||||
total_frames=1
|
||||
)
|
||||
vehicle.last_position_history.append(center)
|
||||
self.tracked_vehicles[track_id] = vehicle
|
||||
active_tracks.append(vehicle)
|
||||
logger.info(f"New vehicle detected (no tracking): ID={track_id}")
|
||||
|
||||
return active_tracks
|
||||
|
||||
def _find_closest_track(self, center: Tuple[float, float]) -> Optional[TrackedVehicle]:
|
||||
"""
|
||||
Find the closest existing track to a given position.
|
||||
|
||||
Args:
|
||||
center: Center position to match
|
||||
|
||||
Returns:
|
||||
Closest tracked vehicle if within tolerance, None otherwise
|
||||
"""
|
||||
min_distance = float('inf')
|
||||
closest_track = None
|
||||
|
||||
for vehicle in self.tracked_vehicles.values():
|
||||
if vehicle.is_expired(0.5): # Shorter timeout for matching
|
||||
continue
|
||||
|
||||
distance = np.sqrt(
|
||||
(center[0] - vehicle.center[0]) ** 2 +
|
||||
(center[1] - vehicle.center[1]) ** 2
|
||||
)
|
||||
|
||||
if distance < min_distance and distance < self.position_tolerance:
|
||||
min_distance = distance
|
||||
closest_track = vehicle
|
||||
|
||||
return closest_track
|
||||
|
||||
def get_stable_vehicles(self, display_id: Optional[str] = None) -> List[TrackedVehicle]:
|
||||
"""
|
||||
Get all stable vehicles, optionally filtered by display.
|
||||
|
||||
Args:
|
||||
display_id: Optional display ID to filter by
|
||||
|
||||
Returns:
|
||||
List of stable tracked vehicles
|
||||
"""
|
||||
with self.lock:
|
||||
stable = [
|
||||
v for v in self.tracked_vehicles.values()
|
||||
if v.is_stable and not v.is_expired(self.timeout_seconds)
|
||||
and (display_id is None or v.display_id == display_id)
|
||||
]
|
||||
return stable
|
||||
|
||||
def get_vehicle_by_session(self, session_id: str) -> Optional[TrackedVehicle]:
|
||||
"""
|
||||
Get a tracked vehicle by its session ID.
|
||||
|
||||
Args:
|
||||
session_id: Session ID to look up
|
||||
|
||||
Returns:
|
||||
Tracked vehicle if found, None otherwise
|
||||
"""
|
||||
with self.lock:
|
||||
for vehicle in self.tracked_vehicles.values():
|
||||
if vehicle.session_id == session_id:
|
||||
return vehicle
|
||||
return None
|
||||
|
||||
def mark_processed(self, track_id: int, session_id: str):
|
||||
"""
|
||||
Mark a vehicle as processed through the pipeline.
|
||||
|
||||
Args:
|
||||
track_id: Track ID of the vehicle
|
||||
session_id: Session ID assigned to this vehicle
|
||||
"""
|
||||
with self.lock:
|
||||
if track_id in self.tracked_vehicles:
|
||||
vehicle = self.tracked_vehicles[track_id]
|
||||
vehicle.processed_pipeline = True
|
||||
vehicle.session_id = session_id
|
||||
logger.info(f"Marked vehicle {track_id} as processed with session {session_id}")
|
||||
|
||||
def clear_session(self, session_id: str):
|
||||
"""
|
||||
Clear session ID from a tracked vehicle (post-fueling).
|
||||
|
||||
Args:
|
||||
session_id: Session ID to clear
|
||||
"""
|
||||
with self.lock:
|
||||
for vehicle in self.tracked_vehicles.values():
|
||||
if vehicle.session_id == session_id:
|
||||
logger.info(f"Clearing session {session_id} from vehicle {vehicle.track_id}")
|
||||
vehicle.session_id = None
|
||||
# Keep processed_pipeline=True to prevent re-processing
|
||||
|
||||
def reset_tracking(self):
|
||||
"""Reset all tracking state."""
|
||||
with self.lock:
|
||||
self.tracked_vehicles.clear()
|
||||
self.next_track_id = 1
|
||||
logger.info("Vehicle tracking state reset")
|
||||
|
||||
def get_statistics(self) -> Dict:
|
||||
"""Get tracking statistics."""
|
||||
with self.lock:
|
||||
total = len(self.tracked_vehicles)
|
||||
stable = sum(1 for v in self.tracked_vehicles.values() if v.is_stable)
|
||||
processed = sum(1 for v in self.tracked_vehicles.values() if v.processed_pipeline)
|
||||
|
||||
return {
|
||||
'total_tracked': total,
|
||||
'stable_vehicles': stable,
|
||||
'processed_vehicles': processed,
|
||||
'avg_confidence': np.mean([v.avg_confidence for v in self.tracked_vehicles.values()])
|
||||
if self.tracked_vehicles else 0.0
|
||||
}
|
408
core/tracking/validator.py
Normal file
408
core/tracking/validator.py
Normal file
|
@ -0,0 +1,408 @@
|
|||
"""
|
||||
Vehicle Validation Module - Stable car detection and validation logic.
|
||||
Differentiates between stable (fueling) cars and passing-by vehicles.
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from .tracker import TrackedVehicle
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VehicleState(Enum):
|
||||
"""Vehicle state classification."""
|
||||
UNKNOWN = "unknown"
|
||||
ENTERING = "entering"
|
||||
STABLE = "stable"
|
||||
LEAVING = "leaving"
|
||||
PASSING_BY = "passing_by"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Result of vehicle validation."""
|
||||
is_valid: bool
|
||||
state: VehicleState
|
||||
confidence: float
|
||||
reason: str
|
||||
should_process: bool = False
|
||||
track_id: Optional[int] = None
|
||||
|
||||
|
||||
class StableCarValidator:
|
||||
"""
|
||||
Validates whether a tracked vehicle is stable (fueling) or just passing by.
|
||||
Uses multiple criteria including position stability, duration, and movement patterns.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[Dict] = None):
|
||||
"""
|
||||
Initialize the validator with configuration.
|
||||
|
||||
Args:
|
||||
config: Optional configuration dictionary
|
||||
"""
|
||||
self.config = config or {}
|
||||
|
||||
# Validation thresholds
|
||||
self.min_stable_duration = self.config.get('min_stable_duration', 3.0) # seconds
|
||||
self.min_stable_frames = self.config.get('min_stable_frames', 10)
|
||||
self.position_variance_threshold = self.config.get('position_variance_threshold', 25.0) # pixels
|
||||
self.min_confidence = self.config.get('min_confidence', 0.7)
|
||||
self.velocity_threshold = self.config.get('velocity_threshold', 5.0) # pixels/frame
|
||||
self.entering_zone_ratio = self.config.get('entering_zone_ratio', 0.3) # 30% of frame
|
||||
self.leaving_zone_ratio = self.config.get('leaving_zone_ratio', 0.3)
|
||||
|
||||
# Frame dimensions (will be updated on first frame)
|
||||
self.frame_width = 1920
|
||||
self.frame_height = 1080
|
||||
|
||||
# History for validation
|
||||
self.validation_history: Dict[int, List[VehicleState]] = {}
|
||||
self.last_processed_vehicles: Dict[int, float] = {} # track_id -> last_process_time
|
||||
|
||||
logger.info(f"StableCarValidator initialized with min_duration={self.min_stable_duration}s, "
|
||||
f"min_frames={self.min_stable_frames}, position_variance={self.position_variance_threshold}")
|
||||
|
||||
def update_frame_dimensions(self, width: int, height: int):
|
||||
"""Update frame dimensions for zone calculations."""
|
||||
self.frame_width = width
|
||||
self.frame_height = height
|
||||
logger.debug(f"Updated frame dimensions: {width}x{height}")
|
||||
|
||||
def validate_vehicle(self, vehicle: TrackedVehicle, frame_shape: Optional[Tuple] = None) -> ValidationResult:
|
||||
"""
|
||||
Validate whether a tracked vehicle is stable and should be processed.
|
||||
|
||||
Args:
|
||||
vehicle: The tracked vehicle to validate
|
||||
frame_shape: Optional frame shape (height, width, channels)
|
||||
|
||||
Returns:
|
||||
ValidationResult with validation status and reasoning
|
||||
"""
|
||||
# Update frame dimensions if provided
|
||||
if frame_shape:
|
||||
self.update_frame_dimensions(frame_shape[1], frame_shape[0])
|
||||
|
||||
# Initialize validation history for new vehicles
|
||||
if vehicle.track_id not in self.validation_history:
|
||||
self.validation_history[vehicle.track_id] = []
|
||||
|
||||
# Check if already processed
|
||||
if vehicle.processed_pipeline:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
state=VehicleState.STABLE,
|
||||
confidence=1.0,
|
||||
reason="Already processed through pipeline",
|
||||
should_process=False,
|
||||
track_id=vehicle.track_id
|
||||
)
|
||||
|
||||
# Check if recently processed (cooldown period)
|
||||
if vehicle.track_id in self.last_processed_vehicles:
|
||||
time_since_process = time.time() - self.last_processed_vehicles[vehicle.track_id]
|
||||
if time_since_process < 10.0: # 10 second cooldown
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
state=VehicleState.STABLE,
|
||||
confidence=1.0,
|
||||
reason=f"Recently processed ({time_since_process:.1f}s ago)",
|
||||
should_process=False,
|
||||
track_id=vehicle.track_id
|
||||
)
|
||||
|
||||
# Determine vehicle state
|
||||
state = self._determine_vehicle_state(vehicle)
|
||||
|
||||
# Update history
|
||||
self.validation_history[vehicle.track_id].append(state)
|
||||
if len(self.validation_history[vehicle.track_id]) > 20:
|
||||
self.validation_history[vehicle.track_id].pop(0)
|
||||
|
||||
# Validate based on state
|
||||
if state == VehicleState.STABLE:
|
||||
return self._validate_stable_vehicle(vehicle)
|
||||
elif state == VehicleState.PASSING_BY:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
state=state,
|
||||
confidence=0.8,
|
||||
reason="Vehicle is passing by",
|
||||
should_process=False,
|
||||
track_id=vehicle.track_id
|
||||
)
|
||||
elif state == VehicleState.ENTERING:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
state=state,
|
||||
confidence=0.5,
|
||||
reason="Vehicle is entering, waiting for stability",
|
||||
should_process=False,
|
||||
track_id=vehicle.track_id
|
||||
)
|
||||
elif state == VehicleState.LEAVING:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
state=state,
|
||||
confidence=0.5,
|
||||
reason="Vehicle is leaving",
|
||||
should_process=False,
|
||||
track_id=vehicle.track_id
|
||||
)
|
||||
else:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
state=state,
|
||||
confidence=0.0,
|
||||
reason="Unknown vehicle state",
|
||||
should_process=False,
|
||||
track_id=vehicle.track_id
|
||||
)
|
||||
|
||||
def _determine_vehicle_state(self, vehicle: TrackedVehicle) -> VehicleState:
|
||||
"""
|
||||
Determine the current state of the vehicle based on movement patterns.
|
||||
|
||||
Args:
|
||||
vehicle: The tracked vehicle
|
||||
|
||||
Returns:
|
||||
Current vehicle state
|
||||
"""
|
||||
# Not enough data
|
||||
if len(vehicle.last_position_history) < 3:
|
||||
return VehicleState.UNKNOWN
|
||||
|
||||
# Calculate velocity
|
||||
velocity = self._calculate_velocity(vehicle)
|
||||
|
||||
# Get position zones
|
||||
x_position = vehicle.center[0] / self.frame_width
|
||||
y_position = vehicle.center[1] / self.frame_height
|
||||
|
||||
# Check if vehicle is stable
|
||||
stability = vehicle.calculate_stability()
|
||||
if stability > 0.7 and velocity < self.velocity_threshold:
|
||||
# Check if it's been stable long enough
|
||||
duration = time.time() - vehicle.first_seen
|
||||
if duration > self.min_stable_duration and vehicle.stable_frames >= self.min_stable_frames:
|
||||
return VehicleState.STABLE
|
||||
else:
|
||||
return VehicleState.ENTERING
|
||||
|
||||
# Check if vehicle is entering or leaving
|
||||
if velocity > self.velocity_threshold:
|
||||
# Determine direction based on position history
|
||||
positions = np.array(vehicle.last_position_history)
|
||||
if len(positions) >= 2:
|
||||
direction = positions[-1] - positions[0]
|
||||
|
||||
# Entering: moving towards center
|
||||
if x_position < self.entering_zone_ratio or x_position > (1 - self.entering_zone_ratio):
|
||||
if abs(direction[0]) > abs(direction[1]): # Horizontal movement
|
||||
if (x_position < 0.5 and direction[0] > 0) or (x_position > 0.5 and direction[0] < 0):
|
||||
return VehicleState.ENTERING
|
||||
|
||||
# Leaving: moving away from center
|
||||
if 0.3 < x_position < 0.7: # In center zone
|
||||
if abs(direction[0]) > abs(direction[1]): # Horizontal movement
|
||||
if abs(direction[0]) > 10: # Significant movement
|
||||
return VehicleState.LEAVING
|
||||
|
||||
return VehicleState.PASSING_BY
|
||||
|
||||
return VehicleState.UNKNOWN
|
||||
|
||||
def _validate_stable_vehicle(self, vehicle: TrackedVehicle) -> ValidationResult:
|
||||
"""
|
||||
Perform detailed validation of a stable vehicle.
|
||||
|
||||
Args:
|
||||
vehicle: The stable vehicle to validate
|
||||
|
||||
Returns:
|
||||
Detailed validation result
|
||||
"""
|
||||
# Check duration
|
||||
duration = time.time() - vehicle.first_seen
|
||||
if duration < self.min_stable_duration:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
state=VehicleState.STABLE,
|
||||
confidence=0.6,
|
||||
reason=f"Not stable long enough ({duration:.1f}s < {self.min_stable_duration}s)",
|
||||
should_process=False,
|
||||
track_id=vehicle.track_id
|
||||
)
|
||||
|
||||
# Check frame count
|
||||
if vehicle.stable_frames < self.min_stable_frames:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
state=VehicleState.STABLE,
|
||||
confidence=0.6,
|
||||
reason=f"Not enough stable frames ({vehicle.stable_frames} < {self.min_stable_frames})",
|
||||
should_process=False,
|
||||
track_id=vehicle.track_id
|
||||
)
|
||||
|
||||
# Check confidence
|
||||
if vehicle.avg_confidence < self.min_confidence:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
state=VehicleState.STABLE,
|
||||
confidence=vehicle.avg_confidence,
|
||||
reason=f"Confidence too low ({vehicle.avg_confidence:.2f} < {self.min_confidence})",
|
||||
should_process=False,
|
||||
track_id=vehicle.track_id
|
||||
)
|
||||
|
||||
# Check position variance
|
||||
variance = self._calculate_position_variance(vehicle)
|
||||
if variance > self.position_variance_threshold:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
state=VehicleState.STABLE,
|
||||
confidence=0.7,
|
||||
reason=f"Position variance too high ({variance:.1f} > {self.position_variance_threshold})",
|
||||
should_process=False,
|
||||
track_id=vehicle.track_id
|
||||
)
|
||||
|
||||
# Check state history consistency
|
||||
if vehicle.track_id in self.validation_history:
|
||||
history = self.validation_history[vehicle.track_id][-5:] # Last 5 states
|
||||
stable_count = sum(1 for s in history if s == VehicleState.STABLE)
|
||||
if stable_count < 3:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
state=VehicleState.STABLE,
|
||||
confidence=0.7,
|
||||
reason="Inconsistent state history",
|
||||
should_process=False,
|
||||
track_id=vehicle.track_id
|
||||
)
|
||||
|
||||
# All checks passed - vehicle is valid for processing
|
||||
self.last_processed_vehicles[vehicle.track_id] = time.time()
|
||||
|
||||
return ValidationResult(
|
||||
is_valid=True,
|
||||
state=VehicleState.STABLE,
|
||||
confidence=vehicle.avg_confidence,
|
||||
reason="Vehicle is stable and ready for processing",
|
||||
should_process=True,
|
||||
track_id=vehicle.track_id
|
||||
)
|
||||
|
||||
def _calculate_velocity(self, vehicle: TrackedVehicle) -> float:
|
||||
"""
|
||||
Calculate the velocity of the vehicle based on position history.
|
||||
|
||||
Args:
|
||||
vehicle: The tracked vehicle
|
||||
|
||||
Returns:
|
||||
Velocity in pixels per frame
|
||||
"""
|
||||
if len(vehicle.last_position_history) < 2:
|
||||
return 0.0
|
||||
|
||||
positions = np.array(vehicle.last_position_history)
|
||||
if len(positions) < 2:
|
||||
return 0.0
|
||||
|
||||
# Calculate velocity over last 3 frames
|
||||
recent_positions = positions[-min(3, len(positions)):]
|
||||
velocities = []
|
||||
|
||||
for i in range(1, len(recent_positions)):
|
||||
dx = recent_positions[i][0] - recent_positions[i-1][0]
|
||||
dy = recent_positions[i][1] - recent_positions[i-1][1]
|
||||
velocity = np.sqrt(dx**2 + dy**2)
|
||||
velocities.append(velocity)
|
||||
|
||||
return np.mean(velocities) if velocities else 0.0
|
||||
|
||||
def _calculate_position_variance(self, vehicle: TrackedVehicle) -> float:
|
||||
"""
|
||||
Calculate the position variance of the vehicle.
|
||||
|
||||
Args:
|
||||
vehicle: The tracked vehicle
|
||||
|
||||
Returns:
|
||||
Position variance in pixels
|
||||
"""
|
||||
if len(vehicle.last_position_history) < 2:
|
||||
return 0.0
|
||||
|
||||
positions = np.array(vehicle.last_position_history)
|
||||
variance_x = np.var(positions[:, 0])
|
||||
variance_y = np.var(positions[:, 1])
|
||||
|
||||
return np.sqrt(variance_x + variance_y)
|
||||
|
||||
def should_skip_same_car(self,
|
||||
vehicle: TrackedVehicle,
|
||||
session_cleared: bool = False) -> bool:
|
||||
"""
|
||||
Determine if we should skip processing for the same car after session clear.
|
||||
|
||||
Args:
|
||||
vehicle: The tracked vehicle
|
||||
session_cleared: Whether the session was recently cleared
|
||||
|
||||
Returns:
|
||||
True if we should skip this vehicle
|
||||
"""
|
||||
# If vehicle has a session_id but it was cleared, skip for a period
|
||||
if vehicle.session_id is None and vehicle.processed_pipeline and session_cleared:
|
||||
# Check if enough time has passed since processing
|
||||
if vehicle.track_id in self.last_processed_vehicles:
|
||||
time_since = time.time() - self.last_processed_vehicles[vehicle.track_id]
|
||||
if time_since < 30.0: # 30 second cooldown after session clear
|
||||
logger.debug(f"Skipping same car {vehicle.track_id} after session clear "
|
||||
f"({time_since:.1f}s since processing)")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def reset_vehicle(self, track_id: int):
|
||||
"""
|
||||
Reset validation state for a specific vehicle.
|
||||
|
||||
Args:
|
||||
track_id: Track ID of the vehicle to reset
|
||||
"""
|
||||
if track_id in self.validation_history:
|
||||
del self.validation_history[track_id]
|
||||
if track_id in self.last_processed_vehicles:
|
||||
del self.last_processed_vehicles[track_id]
|
||||
logger.debug(f"Reset validation state for vehicle {track_id}")
|
||||
|
||||
def get_statistics(self) -> Dict:
|
||||
"""Get validation statistics."""
|
||||
return {
|
||||
'vehicles_in_history': len(self.validation_history),
|
||||
'recently_processed': len(self.last_processed_vehicles),
|
||||
'state_distribution': self._get_state_distribution()
|
||||
}
|
||||
|
||||
def _get_state_distribution(self) -> Dict[str, int]:
|
||||
"""Get distribution of current vehicle states."""
|
||||
distribution = {state.value: 0 for state in VehicleState}
|
||||
|
||||
for history in self.validation_history.values():
|
||||
if history:
|
||||
current_state = history[-1]
|
||||
distribution[current_state.value] += 1
|
||||
|
||||
return distribution
|
Loading…
Add table
Add a link
Reference in a new issue