python-detector-worker/core/tracking/tracker.py
2025-09-23 17:56:40 +07:00

352 lines
No EOL
14 KiB
Python

"""
Vehicle Tracking Module - Continuous tracking with front_rear_detection model
Implements vehicle identification, persistence, and motion analysis.
"""
import logging
import time
import uuid
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
import numpy as np
from threading import Lock
logger = logging.getLogger(__name__)
@dataclass
class TrackedVehicle:
    """Represents a tracked vehicle with all its state information."""

    # Plain (unannotated) class attributes: deliberately NOT dataclass fields,
    # so the generated __init__ signature is unchanged.
    # Number of most recent centre positions kept for the stability estimate.
    _HISTORY_LENGTH = 10
    # Sum of the x and y position standard deviations (pixels) at which the
    # stability score reaches 0.
    _STD_SUM_AT_ZERO_STABILITY = 100.0

    track_id: int
    first_seen: float  # timestamp in seconds, same clock as time.time()
    last_seen: float  # timestamp in seconds; compared against time.time() in is_expired
    session_id: Optional[str] = None
    display_id: Optional[str] = None
    confidence: float = 0.0  # confidence of the most recent observation
    bbox: Tuple[int, int, int, int] = (0, 0, 0, 0)  # x1, y1, x2, y2
    center: Tuple[float, float] = (0.0, 0.0)  # centre of bbox
    stable_frames: int = 0
    total_frames: int = 0  # number of observations folded into avg_confidence
    is_stable: bool = False
    processed_pipeline: bool = False
    last_position_history: List[Tuple[float, float]] = field(default_factory=list)  # oldest first
    avg_confidence: float = 0.0  # running mean of confidence over total_frames observations

    def update_position(self, bbox: Tuple[int, int, int, int], confidence: float):
        """
        Record a new observation of this vehicle.

        Updates the bbox, centre, last-seen time and latest confidence. It also
        increments ``total_frames`` and folds ``confidence`` into the running
        mean ``avg_confidence``. Finally it appends the new centre to the
        position history, keeping only the most recent ``_HISTORY_LENGTH``
        entries.

        Args:
            bbox: Bounding box as (x1, y1, x2, y2).
            confidence: Detection confidence for this observation.
        """
        self.bbox = bbox
        self.center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
        self.last_seen = time.time()
        self.confidence = confidence
        self.total_frames += 1
        # Running mean: previous sum (avg * (n-1)) plus the new sample, divided by n.
        self.avg_confidence = ((self.avg_confidence * (self.total_frames - 1)) + confidence) / self.total_frames

        self.last_position_history.append(self.center)
        # Trim from the front so only the newest positions remain. This also
        # copes with a history that was passed in longer than the limit.
        excess = len(self.last_position_history) - self._HISTORY_LENGTH
        if excess > 0:
            del self.last_position_history[:excess]

    def calculate_stability(self) -> float:
        """
        Score how stationary the vehicle has been over its recent positions.

        Returns:
            A float in [0, 1]. The score is ``1 - (std_x + std_y) / 100``, clamped
            at 0, so 1.0 means no movement at all. Returns 0.0 when fewer than
            two positions are known.
        """
        if len(self.last_position_history) < 2:
            return 0.0

        positions = np.asarray(self.last_position_history, dtype=float)
        # Lower positional spread means more stable (inverse relationship).
        spread = float(np.std(positions[:, 0]) + np.std(positions[:, 1]))
        # Always return a Python float; max(0, ...) used to yield int 0.
        return max(0.0, 1.0 - spread / self._STD_SUM_AT_ZERO_STABILITY)

    def is_expired(self, timeout_seconds: float = 2.0) -> bool:
        """Return True if more than ``timeout_seconds`` have passed since ``last_seen``."""
        return (time.time() - self.last_seen) > timeout_seconds
class VehicleTracker:
    """
    Main vehicle tracking implementation using YOLO tracking capabilities.

    Manages continuous tracking, vehicle identification, and state persistence.
    The public methods hold ``self.lock`` while reading or mutating
    ``tracked_vehicles`` and ``next_track_id``.
    """

    def __init__(self, tracking_config: Optional[Dict] = None):
        """
        Initialize the vehicle tracker.

        Args:
            tracking_config: Configuration from pipeline.json tracking section.
                Keys read here: ``triggerClasses`` (default ``['front_rear']``)
                and ``minConfidence`` (default ``0.6``).
        """
        self.config = tracking_config or {}
        self.trigger_classes = self.config.get('triggerClasses', ['front_rear'])
        self.min_confidence = self.config.get('minConfidence', 0.6)

        # Tracking state
        self.tracked_vehicles: Dict[int, TrackedVehicle] = {}
        self.next_track_id = 1  # next synthetic ID for detections without tracker IDs
        self.lock = Lock()

        # Tracking parameters
        self.stability_threshold = 0.7
        self.min_stable_frames = 5
        self.position_tolerance = 50  # pixels
        self.timeout_seconds = 2.0

        logger.info("VehicleTracker initialized with trigger_classes=%s, min_confidence=%s",
                    self.trigger_classes, self.min_confidence)

    def process_detections(self,
                           results: Any,
                           display_id: str,
                           frame: np.ndarray) -> List[TrackedVehicle]:
        """
        Process YOLO detection results and update tracking state.

        Expired tracks are dropped first. The boxes are then processed in one
        of two modes. If the results carry tracker IDs (``results.boxes.id``),
        those IDs are used. Otherwise detections are associated with existing
        tracks by centre distance.

        Args:
            results: YOLO detection results with tracking
            display_id: Display identifier for this stream
            frame: Current frame being processed (currently unused)

        Returns:
            List of vehicles observed in this frame (each at most once in the
            no-tracking mode).
        """
        current_time = time.time()

        with self.lock:
            self._remove_expired_tracks()

            boxes = getattr(results, 'boxes', None)
            if boxes is None:
                return []

            if getattr(boxes, 'id', None) is not None:
                return self._process_tracked_boxes(results, boxes, display_id, current_time)

            logger.debug("No tracking IDs available, processing as detections only")
            return self._process_untracked_boxes(results, boxes, display_id, current_time)

    # ------------------------------------------------------------------
    # Internal helpers (all expect the caller to hold self.lock)
    # ------------------------------------------------------------------

    def _remove_expired_tracks(self) -> None:
        """Drop every track not seen within ``timeout_seconds``."""
        expired_ids = [
            track_id for track_id, vehicle in self.tracked_vehicles.items()
            if vehicle.is_expired(self.timeout_seconds)
        ]
        for track_id in expired_ids:
            logger.debug("Removing expired track %s", track_id)
            del self.tracked_vehicles[track_id]

    def _read_detection(self, results: Any,
                        box: Any) -> Optional[Tuple[float, Tuple[int, int, int, int]]]:
        """
        Extract ``(confidence, bbox)`` from one YOLO box, or None if it is filtered out.

        A detection is kept only if its class is in ``trigger_classes`` AND its
        confidence is at least ``min_confidence``.
        """
        cls_id = int(box.cls.item())
        confidence = float(box.conf.item())

        names = getattr(results, 'names', None)
        class_name = names[cls_id] if names is not None else str(cls_id)
        # Both conditions must hold; with ``and`` here, confident non-trigger
        # objects and low-confidence trigger objects would slip through.
        if class_name not in self.trigger_classes or confidence < self.min_confidence:
            return None

        # Convert to plain Python ints so bboxes do not carry numpy scalar types.
        x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].cpu().numpy().astype(int))
        return confidence, (x1, y1, x2, y2)

    def _update_stability(self, vehicle: TrackedVehicle) -> float:
        """
        Update ``stable_frames`` / ``is_stable`` from the vehicle's stability score.

        A frame above ``stability_threshold`` increments ``stable_frames`` (and
        marks the vehicle stable once it reaches ``min_stable_frames``). Any
        other frame decrements it (floored at 0) and clears ``is_stable`` when it
        drops below ``min_stable_frames``.

        Returns:
            The stability score computed for this frame.
        """
        stability = vehicle.calculate_stability()
        if stability > self.stability_threshold:
            vehicle.stable_frames += 1
            if vehicle.stable_frames >= self.min_stable_frames:
                vehicle.is_stable = True
        else:
            vehicle.stable_frames = max(0, vehicle.stable_frames - 1)
            if vehicle.stable_frames < self.min_stable_frames:
                vehicle.is_stable = False
        return stability

    def _create_track(self, track_id: int, display_id: str, confidence: float,
                      bbox: Tuple[int, int, int, int], current_time: float) -> TrackedVehicle:
        """Create, register and return a new ``TrackedVehicle`` for its first observation."""
        x1, y1, x2, y2 = bbox
        vehicle = TrackedVehicle(
            track_id=track_id,
            first_seen=current_time,
            last_seen=current_time,
            display_id=display_id,
            confidence=confidence,
            # Seed the running mean with the first sample; leaving it at 0.0
            # while total_frames=1 would bias every later average low.
            avg_confidence=confidence,
            bbox=bbox,
            center=((x1 + x2) / 2, (y1 + y2) / 2),
            total_frames=1,
        )
        vehicle.last_position_history.append(vehicle.center)
        self.tracked_vehicles[track_id] = vehicle
        return vehicle

    def _allocate_track_id(self) -> int:
        """Return a synthetic track ID that is not already in ``tracked_vehicles``."""
        # Skip IDs currently held by tracks (e.g. IDs assigned by the YOLO
        # tracker) so a synthetic track never overwrites an existing one.
        while self.next_track_id in self.tracked_vehicles:
            self.next_track_id += 1
        track_id = self.next_track_id
        self.next_track_id += 1
        return track_id

    def _process_tracked_boxes(self, results: Any, boxes: Any, display_id: str,
                               current_time: float) -> List[TrackedVehicle]:
        """Update/create tracks keyed by the tracker IDs carried in ``boxes.id``."""
        active_tracks: List[TrackedVehicle] = []
        for i, box in enumerate(boxes):
            raw_id = boxes.id[i]
            if raw_id is None:
                continue
            track_id = int(raw_id.item())

            detection = self._read_detection(results, box)
            if detection is None:
                continue
            confidence, bbox = detection

            vehicle = self.tracked_vehicles.get(track_id)
            if vehicle is None:
                vehicle = self._create_track(track_id, display_id, confidence, bbox, current_time)
                logger.info("New vehicle tracked: ID=%s, display=%s", track_id, display_id)
            else:
                vehicle.update_position(bbox, confidence)
                vehicle.display_id = display_id
                stability = self._update_stability(vehicle)
                logger.debug("Updated track %s: conf=%.2f, stable=%s, stability=%.2f",
                             track_id, confidence, vehicle.is_stable, stability)
            active_tracks.append(vehicle)
        return active_tracks

    def _process_untracked_boxes(self, results: Any, boxes: Any, display_id: str,
                                 current_time: float) -> List[TrackedVehicle]:
        """Associate ID-less detections with tracks by proximity, creating tracks as needed."""
        active_tracks: List[TrackedVehicle] = []
        # Tracks already claimed in this frame; one track may absorb at most one
        # detection per frame (otherwise frames/averages are double-counted).
        claimed_ids: set = set()
        for box in boxes:
            detection = self._read_detection(results, box)
            if detection is None:
                continue
            confidence, bbox = detection
            x1, y1, x2, y2 = bbox
            center = ((x1 + x2) / 2, (y1 + y2) / 2)

            vehicle = self._find_closest_track(center, display_id=display_id,
                                               exclude_ids=claimed_ids)
            if vehicle is not None:
                vehicle.update_position(bbox, confidence)
                vehicle.display_id = display_id
                # Same stability bookkeeping as the tracked path, so these
                # vehicles can also become stable.
                self._update_stability(vehicle)
            else:
                track_id = self._allocate_track_id()
                vehicle = self._create_track(track_id, display_id, confidence, bbox, current_time)
                logger.info("New vehicle detected (no tracking): ID=%s", track_id)

            claimed_ids.add(vehicle.track_id)
            active_tracks.append(vehicle)
        return active_tracks

    def _find_closest_track(self, center: Tuple[float, float],
                            display_id: Optional[str] = None,
                            exclude_ids: Optional[set] = None) -> Optional[TrackedVehicle]:
        """
        Find the closest existing track to a given position.

        Args:
            center: Center position to match
            display_id: If given, only tracks last seen on this display are considered.
            exclude_ids: Track IDs to skip (e.g. already matched in this frame).

        Returns:
            Closest tracked vehicle within ``position_tolerance`` pixels, None otherwise.
        """
        min_distance = float('inf')
        closest_track = None
        for vehicle in self.tracked_vehicles.values():
            if vehicle.is_expired(0.5):  # Shorter timeout for matching
                continue
            if exclude_ids and vehicle.track_id in exclude_ids:
                continue
            if display_id is not None and vehicle.display_id != display_id:
                continue
            distance = np.sqrt(
                (center[0] - vehicle.center[0]) ** 2 +
                (center[1] - vehicle.center[1]) ** 2
            )
            if distance < min_distance and distance < self.position_tolerance:
                min_distance = distance
                closest_track = vehicle
        return closest_track

    # ------------------------------------------------------------------
    # Public query / mutation API
    # ------------------------------------------------------------------

    def get_stable_vehicles(self, display_id: Optional[str] = None) -> List[TrackedVehicle]:
        """
        Get all stable vehicles, optionally filtered by display.

        Args:
            display_id: Optional display ID to filter by

        Returns:
            List of stable, non-expired tracked vehicles
        """
        with self.lock:
            return [
                v for v in self.tracked_vehicles.values()
                if v.is_stable and not v.is_expired(self.timeout_seconds)
                and (display_id is None or v.display_id == display_id)
            ]

    def get_vehicle_by_session(self, session_id: str) -> Optional[TrackedVehicle]:
        """
        Get a tracked vehicle by its session ID.

        Note: the returned object is the live tracked instance, not a copy.

        Args:
            session_id: Session ID to look up

        Returns:
            Tracked vehicle if found, None otherwise
        """
        with self.lock:
            for vehicle in self.tracked_vehicles.values():
                if vehicle.session_id == session_id:
                    return vehicle
            return None

    def mark_processed(self, track_id: int, session_id: str):
        """
        Mark a vehicle as processed through the pipeline.

        Args:
            track_id: Track ID of the vehicle
            session_id: Session ID assigned to this vehicle
        """
        with self.lock:
            vehicle = self.tracked_vehicles.get(track_id)
            if vehicle is None:
                # Previously a silent no-op; surface it (e.g. track expired meanwhile).
                logger.warning("Cannot mark vehicle %s as processed: track not found", track_id)
                return
            vehicle.processed_pipeline = True
            vehicle.session_id = session_id
            logger.info("Marked vehicle %s as processed with session %s", track_id, session_id)

    def clear_session(self, session_id: str):
        """
        Clear session ID from a tracked vehicle (post-fueling).

        ``processed_pipeline`` is intentionally left True so the vehicle is not
        re-processed.

        Args:
            session_id: Session ID to clear
        """
        with self.lock:
            for vehicle in self.tracked_vehicles.values():
                if vehicle.session_id == session_id:
                    logger.info("Clearing session %s from vehicle %s", session_id, vehicle.track_id)
                    vehicle.session_id = None
                    # Keep processed_pipeline=True to prevent re-processing

    def reset_tracking(self):
        """Reset all tracking state."""
        with self.lock:
            self.tracked_vehicles.clear()
            self.next_track_id = 1
            logger.info("Vehicle tracking state reset")

    def get_statistics(self) -> Dict:
        """
        Get tracking statistics.

        Returns:
            Dict with ``total_tracked``, ``stable_vehicles``, ``processed_vehicles``
            and ``avg_confidence`` (mean of per-vehicle averages, 0.0 if none).
        """
        with self.lock:
            vehicles = list(self.tracked_vehicles.values())
            return {
                'total_tracked': len(vehicles),
                'stable_vehicles': sum(1 for v in vehicles if v.is_stable),
                'processed_vehicles': sum(1 for v in vehicles if v.processed_pipeline),
                'avg_confidence': float(np.mean([v.avg_confidence for v in vehicles]))
                if vehicles else 0.0,
            }