Refactor: done phase 4
This commit is contained in:
parent
7e8034c6e5
commit
9e4c23c75c
8 changed files with 1533 additions and 37 deletions
352
core/tracking/tracker.py
Normal file
352
core/tracking/tracker.py
Normal file
|
@ -0,0 +1,352 @@
|
|||
"""
|
||||
Vehicle Tracking Module - Continuous tracking with front_rear_detection model
|
||||
Implements vehicle identification, persistence, and motion analysis.
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass, field
|
||||
import numpy as np
|
||||
from threading import Lock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrackedVehicle:
|
||||
"""Represents a tracked vehicle with all its state information."""
|
||||
track_id: int
|
||||
first_seen: float
|
||||
last_seen: float
|
||||
session_id: Optional[str] = None
|
||||
display_id: Optional[str] = None
|
||||
confidence: float = 0.0
|
||||
bbox: Tuple[int, int, int, int] = (0, 0, 0, 0) # x1, y1, x2, y2
|
||||
center: Tuple[float, float] = (0.0, 0.0)
|
||||
stable_frames: int = 0
|
||||
total_frames: int = 0
|
||||
is_stable: bool = False
|
||||
processed_pipeline: bool = False
|
||||
last_position_history: List[Tuple[float, float]] = field(default_factory=list)
|
||||
avg_confidence: float = 0.0
|
||||
|
||||
def update_position(self, bbox: Tuple[int, int, int, int], confidence: float):
|
||||
"""Update vehicle position and confidence."""
|
||||
self.bbox = bbox
|
||||
self.center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
|
||||
self.last_seen = time.time()
|
||||
self.confidence = confidence
|
||||
self.total_frames += 1
|
||||
|
||||
# Update confidence average
|
||||
self.avg_confidence = ((self.avg_confidence * (self.total_frames - 1)) + confidence) / self.total_frames
|
||||
|
||||
# Maintain position history (last 10 positions)
|
||||
self.last_position_history.append(self.center)
|
||||
if len(self.last_position_history) > 10:
|
||||
self.last_position_history.pop(0)
|
||||
|
||||
def calculate_stability(self) -> float:
|
||||
"""Calculate stability score based on position history."""
|
||||
if len(self.last_position_history) < 2:
|
||||
return 0.0
|
||||
|
||||
# Calculate movement variance
|
||||
positions = np.array(self.last_position_history)
|
||||
if len(positions) < 2:
|
||||
return 0.0
|
||||
|
||||
# Calculate standard deviation of positions
|
||||
std_x = np.std(positions[:, 0])
|
||||
std_y = np.std(positions[:, 1])
|
||||
|
||||
# Lower variance means more stable (inverse relationship)
|
||||
# Normalize to 0-1 range (assuming max reasonable std is 50 pixels)
|
||||
stability = max(0, 1 - (std_x + std_y) / 100)
|
||||
return stability
|
||||
|
||||
def is_expired(self, timeout_seconds: float = 2.0) -> bool:
|
||||
"""Check if vehicle tracking has expired."""
|
||||
return (time.time() - self.last_seen) > timeout_seconds
|
||||
|
||||
|
||||
class VehicleTracker:
|
||||
"""
|
||||
Main vehicle tracking implementation using YOLO tracking capabilities.
|
||||
Manages continuous tracking, vehicle identification, and state persistence.
|
||||
"""
|
||||
|
||||
def __init__(self, tracking_config: Optional[Dict] = None):
|
||||
"""
|
||||
Initialize the vehicle tracker.
|
||||
|
||||
Args:
|
||||
tracking_config: Configuration from pipeline.json tracking section
|
||||
"""
|
||||
self.config = tracking_config or {}
|
||||
self.trigger_classes = self.config.get('triggerClasses', ['front_rear'])
|
||||
self.min_confidence = self.config.get('minConfidence', 0.6)
|
||||
|
||||
# Tracking state
|
||||
self.tracked_vehicles: Dict[int, TrackedVehicle] = {}
|
||||
self.next_track_id = 1
|
||||
self.lock = Lock()
|
||||
|
||||
# Tracking parameters
|
||||
self.stability_threshold = 0.7
|
||||
self.min_stable_frames = 5
|
||||
self.position_tolerance = 50 # pixels
|
||||
self.timeout_seconds = 2.0
|
||||
|
||||
logger.info(f"VehicleTracker initialized with trigger_classes={self.trigger_classes}, "
|
||||
f"min_confidence={self.min_confidence}")
|
||||
|
||||
def process_detections(self,
|
||||
results: Any,
|
||||
display_id: str,
|
||||
frame: np.ndarray) -> List[TrackedVehicle]:
|
||||
"""
|
||||
Process YOLO detection results and update tracking state.
|
||||
|
||||
Args:
|
||||
results: YOLO detection results with tracking
|
||||
display_id: Display identifier for this stream
|
||||
frame: Current frame being processed
|
||||
|
||||
Returns:
|
||||
List of currently tracked vehicles
|
||||
"""
|
||||
current_time = time.time()
|
||||
active_tracks = []
|
||||
|
||||
with self.lock:
|
||||
# Clean up expired tracks
|
||||
expired_ids = [
|
||||
track_id for track_id, vehicle in self.tracked_vehicles.items()
|
||||
if vehicle.is_expired(self.timeout_seconds)
|
||||
]
|
||||
for track_id in expired_ids:
|
||||
logger.debug(f"Removing expired track {track_id}")
|
||||
del self.tracked_vehicles[track_id]
|
||||
|
||||
# Process new detections
|
||||
if hasattr(results, 'boxes') and results.boxes is not None:
|
||||
boxes = results.boxes
|
||||
|
||||
# Check if tracking is available
|
||||
if hasattr(boxes, 'id') and boxes.id is not None:
|
||||
# Process tracked objects
|
||||
for i, box in enumerate(boxes):
|
||||
# Get tracking ID
|
||||
track_id = int(boxes.id[i].item()) if boxes.id[i] is not None else None
|
||||
if track_id is None:
|
||||
continue
|
||||
|
||||
# Get class and confidence
|
||||
cls_id = int(box.cls.item())
|
||||
confidence = float(box.conf.item())
|
||||
|
||||
# Check if class is in trigger classes
|
||||
class_name = results.names[cls_id] if hasattr(results, 'names') else str(cls_id)
|
||||
if class_name not in self.trigger_classes and confidence < self.min_confidence:
|
||||
continue
|
||||
|
||||
# Get bounding box
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
|
||||
bbox = (x1, y1, x2, y2)
|
||||
|
||||
# Update or create tracked vehicle
|
||||
if track_id in self.tracked_vehicles:
|
||||
# Update existing track
|
||||
vehicle = self.tracked_vehicles[track_id]
|
||||
vehicle.update_position(bbox, confidence)
|
||||
vehicle.display_id = display_id
|
||||
|
||||
# Check stability
|
||||
stability = vehicle.calculate_stability()
|
||||
if stability > self.stability_threshold:
|
||||
vehicle.stable_frames += 1
|
||||
if vehicle.stable_frames >= self.min_stable_frames:
|
||||
vehicle.is_stable = True
|
||||
else:
|
||||
vehicle.stable_frames = max(0, vehicle.stable_frames - 1)
|
||||
if vehicle.stable_frames < self.min_stable_frames:
|
||||
vehicle.is_stable = False
|
||||
|
||||
logger.debug(f"Updated track {track_id}: conf={confidence:.2f}, "
|
||||
f"stable={vehicle.is_stable}, stability={stability:.2f}")
|
||||
else:
|
||||
# Create new track
|
||||
vehicle = TrackedVehicle(
|
||||
track_id=track_id,
|
||||
first_seen=current_time,
|
||||
last_seen=current_time,
|
||||
display_id=display_id,
|
||||
confidence=confidence,
|
||||
bbox=bbox,
|
||||
center=((x1 + x2) / 2, (y1 + y2) / 2),
|
||||
total_frames=1
|
||||
)
|
||||
vehicle.last_position_history.append(vehicle.center)
|
||||
self.tracked_vehicles[track_id] = vehicle
|
||||
logger.info(f"New vehicle tracked: ID={track_id}, display={display_id}")
|
||||
|
||||
active_tracks.append(self.tracked_vehicles[track_id])
|
||||
else:
|
||||
# No tracking available, process as detections only
|
||||
logger.debug("No tracking IDs available, processing as detections only")
|
||||
for i, box in enumerate(boxes):
|
||||
cls_id = int(box.cls.item())
|
||||
confidence = float(box.conf.item())
|
||||
|
||||
# Check confidence threshold
|
||||
if confidence < self.min_confidence:
|
||||
continue
|
||||
|
||||
# Get bounding box
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
|
||||
bbox = (x1, y1, x2, y2)
|
||||
center = ((x1 + x2) / 2, (y1 + y2) / 2)
|
||||
|
||||
# Try to match with existing tracks by position
|
||||
matched_track = self._find_closest_track(center)
|
||||
|
||||
if matched_track:
|
||||
matched_track.update_position(bbox, confidence)
|
||||
matched_track.display_id = display_id
|
||||
active_tracks.append(matched_track)
|
||||
else:
|
||||
# Create new track with generated ID
|
||||
track_id = self.next_track_id
|
||||
self.next_track_id += 1
|
||||
|
||||
vehicle = TrackedVehicle(
|
||||
track_id=track_id,
|
||||
first_seen=current_time,
|
||||
last_seen=current_time,
|
||||
display_id=display_id,
|
||||
confidence=confidence,
|
||||
bbox=bbox,
|
||||
center=center,
|
||||
total_frames=1
|
||||
)
|
||||
vehicle.last_position_history.append(center)
|
||||
self.tracked_vehicles[track_id] = vehicle
|
||||
active_tracks.append(vehicle)
|
||||
logger.info(f"New vehicle detected (no tracking): ID={track_id}")
|
||||
|
||||
return active_tracks
|
||||
|
||||
def _find_closest_track(self, center: Tuple[float, float]) -> Optional[TrackedVehicle]:
|
||||
"""
|
||||
Find the closest existing track to a given position.
|
||||
|
||||
Args:
|
||||
center: Center position to match
|
||||
|
||||
Returns:
|
||||
Closest tracked vehicle if within tolerance, None otherwise
|
||||
"""
|
||||
min_distance = float('inf')
|
||||
closest_track = None
|
||||
|
||||
for vehicle in self.tracked_vehicles.values():
|
||||
if vehicle.is_expired(0.5): # Shorter timeout for matching
|
||||
continue
|
||||
|
||||
distance = np.sqrt(
|
||||
(center[0] - vehicle.center[0]) ** 2 +
|
||||
(center[1] - vehicle.center[1]) ** 2
|
||||
)
|
||||
|
||||
if distance < min_distance and distance < self.position_tolerance:
|
||||
min_distance = distance
|
||||
closest_track = vehicle
|
||||
|
||||
return closest_track
|
||||
|
||||
def get_stable_vehicles(self, display_id: Optional[str] = None) -> List[TrackedVehicle]:
|
||||
"""
|
||||
Get all stable vehicles, optionally filtered by display.
|
||||
|
||||
Args:
|
||||
display_id: Optional display ID to filter by
|
||||
|
||||
Returns:
|
||||
List of stable tracked vehicles
|
||||
"""
|
||||
with self.lock:
|
||||
stable = [
|
||||
v for v in self.tracked_vehicles.values()
|
||||
if v.is_stable and not v.is_expired(self.timeout_seconds)
|
||||
and (display_id is None or v.display_id == display_id)
|
||||
]
|
||||
return stable
|
||||
|
||||
def get_vehicle_by_session(self, session_id: str) -> Optional[TrackedVehicle]:
|
||||
"""
|
||||
Get a tracked vehicle by its session ID.
|
||||
|
||||
Args:
|
||||
session_id: Session ID to look up
|
||||
|
||||
Returns:
|
||||
Tracked vehicle if found, None otherwise
|
||||
"""
|
||||
with self.lock:
|
||||
for vehicle in self.tracked_vehicles.values():
|
||||
if vehicle.session_id == session_id:
|
||||
return vehicle
|
||||
return None
|
||||
|
||||
def mark_processed(self, track_id: int, session_id: str):
|
||||
"""
|
||||
Mark a vehicle as processed through the pipeline.
|
||||
|
||||
Args:
|
||||
track_id: Track ID of the vehicle
|
||||
session_id: Session ID assigned to this vehicle
|
||||
"""
|
||||
with self.lock:
|
||||
if track_id in self.tracked_vehicles:
|
||||
vehicle = self.tracked_vehicles[track_id]
|
||||
vehicle.processed_pipeline = True
|
||||
vehicle.session_id = session_id
|
||||
logger.info(f"Marked vehicle {track_id} as processed with session {session_id}")
|
||||
|
||||
def clear_session(self, session_id: str):
|
||||
"""
|
||||
Clear session ID from a tracked vehicle (post-fueling).
|
||||
|
||||
Args:
|
||||
session_id: Session ID to clear
|
||||
"""
|
||||
with self.lock:
|
||||
for vehicle in self.tracked_vehicles.values():
|
||||
if vehicle.session_id == session_id:
|
||||
logger.info(f"Clearing session {session_id} from vehicle {vehicle.track_id}")
|
||||
vehicle.session_id = None
|
||||
# Keep processed_pipeline=True to prevent re-processing
|
||||
|
||||
def reset_tracking(self):
|
||||
"""Reset all tracking state."""
|
||||
with self.lock:
|
||||
self.tracked_vehicles.clear()
|
||||
self.next_track_id = 1
|
||||
logger.info("Vehicle tracking state reset")
|
||||
|
||||
def get_statistics(self) -> Dict:
|
||||
"""Get tracking statistics."""
|
||||
with self.lock:
|
||||
total = len(self.tracked_vehicles)
|
||||
stable = sum(1 for v in self.tracked_vehicles.values() if v.is_stable)
|
||||
processed = sum(1 for v in self.tracked_vehicles.values() if v.processed_pipeline)
|
||||
|
||||
return {
|
||||
'total_tracked': total,
|
||||
'stable_vehicles': stable,
|
||||
'processed_vehicles': processed,
|
||||
'avg_confidence': np.mean([v.avg_confidence for v in self.tracked_vehicles.values()])
|
||||
if self.tracked_vehicles else 0.0
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue