380 lines
No EOL
14 KiB
Python
380 lines
No EOL
14 KiB
Python
"""
|
|
Tracking-Pipeline Integration Module.
|
|
Connects the tracking system with the main detection pipeline and manages the flow.
|
|
"""
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from typing import Dict, Optional, Any, List, Tuple
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
import numpy as np
|
|
|
|
from .tracker import VehicleTracker, TrackedVehicle
|
|
from .validator import StableCarValidator, ValidationResult, VehicleState
|
|
from ..models.inference import YOLOWrapper
|
|
from ..models.pipeline import PipelineParser
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class TrackingPipelineIntegration:
    """
    Integrates vehicle tracking with the detection pipeline.

    Manages tracking state transitions and pipeline execution triggers.
    """

    # Seconds after clear_session_id() during which a vehicle still carrying
    # that session id is reported as "recently cleared" to the validator's
    # should_skip_same_car() check.
    SESSION_CLEAR_COOLDOWN: float = 30.0

    # Entries in ``cleared_sessions`` older than this many seconds are pruned
    # whenever clear_session_id() runs.
    CLEARED_SESSION_RETENTION: float = 300.0

    # Emit the periodic tracking summary (debug level) every N processed frames.
    TRACKING_LOG_INTERVAL: int = 30

    def __init__(self, pipeline_parser: PipelineParser, model_manager: Any):
        """
        Initialize tracking-pipeline integration.

        Args:
            pipeline_parser: Pipeline parser with loaded configuration
            model_manager: Model manager for loading models
        """
        self.pipeline_parser = pipeline_parser
        self.model_manager = model_manager

        # Initialize tracking components. The tracker receives a shallow copy of
        # the config object's attributes so it cannot mutate the parser's config.
        tracking_config = (
            dict(vars(pipeline_parser.tracking_config))
            if pipeline_parser.tracking_config
            else {}
        )
        self.tracker = VehicleTracker(tracking_config)
        self.validator = StableCarValidator()

        # Tracking model
        self.tracking_model: Optional[YOLOWrapper] = None
        self.tracking_model_id = None

        # Session management
        self.active_sessions: Dict[str, str] = {}  # display_id -> session_id
        self.session_vehicles: Dict[str, int] = {}  # session_id -> track_id
        self.cleared_sessions: Dict[str, float] = {}  # session_id -> clear_time (time.time())

        # Thread pool for pipeline execution (created here, shut down in cleanup();
        # not otherwise used by this class yet).
        self.executor = ThreadPoolExecutor(max_workers=2)

        # Statistics
        self.stats = {
            'frames_processed': 0,
            'vehicles_detected': 0,
            'vehicles_validated': 0,
            'pipelines_executed': 0,
        }

        logger.info("TrackingPipelineIntegration initialized")

    async def initialize_tracking_model(self) -> bool:
        """
        Load and initialize the tracking model.

        Returns:
            True if successful, False otherwise (errors are logged, not raised)
        """
        try:
            if not self.pipeline_parser.tracking_config:
                logger.warning("No tracking configuration found in pipeline")
                return False

            model_file = self.pipeline_parser.tracking_config.model_file
            model_id = self.pipeline_parser.tracking_config.model_id

            if not model_file:
                logger.warning("No tracking model file specified")
                return False

            # Load tracking model
            logger.info(f"Loading tracking model: {model_id} ({model_file})")

            # NOTE(review): the ModelManager is keyed by an actual model ID, not the
            # model string identifier from the config, so for now the first
            # downloaded model is used. If several models are downloaded, this may
            # pick the wrong one; confirm the intended lookup.
            pipeline_models = list(self.model_manager.get_all_downloaded_models())
            if not pipeline_models:
                logger.error("No models available in ModelManager")
                return False

            actual_model_id = pipeline_models[0]
            self.tracking_model = self.model_manager.get_yolo_model(actual_model_id, model_file)
            self.tracking_model_id = model_id

            if self.tracking_model:
                logger.info(f"Tracking model {model_id} loaded successfully")
                return True

            logger.error(f"Failed to load tracking model {model_id}")
            return False

        except Exception as e:
            logger.error(f"Error initializing tracking model: {e}", exc_info=True)
            return False

    async def process_frame(self,
                            frame: np.ndarray,
                            display_id: str,
                            subscription_id: str,
                            session_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Process a frame through tracking and potentially the detection pipeline.

        At most one stable vehicle per frame is validated and sent to the
        pipeline. Errors are logged and a partially filled result is returned.

        Args:
            frame: Input frame to process (only ``frame.shape`` is read directly here)
            display_id: Display identifier
            subscription_id: Full subscription identifier
            session_id: Optional session ID from backend; a UUID4 is generated if
                a vehicle is processed and none was given

        Returns:
            Dictionary with keys 'tracked_vehicles', 'validated_vehicle',
            'pipeline_result', 'session_id' and 'processing_time' (seconds)
        """
        # perf_counter is monotonic: the elapsed time cannot jump or go negative
        # if the wall clock is adjusted while the frame is being processed.
        start_time = time.perf_counter()
        result: Dict[str, Any] = {
            'tracked_vehicles': [],
            'validated_vehicle': None,
            'pipeline_result': None,
            'session_id': session_id,
            'processing_time': 0.0,
        }

        try:
            self.stats['frames_processed'] += 1

            if not self.tracking_model:
                logger.warning("No tracking model available")
            else:
                # NOTE(review): track() is a synchronous call made from a coroutine, so
                # it blocks the event loop while inference runs; consider offloading
                # it (e.g. loop.run_in_executor) once its thread-safety is confirmed.
                tracking_results = self.tracking_model.track(
                    frame,
                    confidence_threshold=self.tracker.min_confidence,
                    trigger_classes=self.tracker.trigger_classes,
                    persist=True,
                )
                self._log_raw_detections(tracking_results)

                tracked_vehicles = self.tracker.process_detections(
                    tracking_results,
                    display_id,
                    frame,
                )
                result['tracked_vehicles'] = [
                    {
                        'track_id': v.track_id,
                        'bbox': v.bbox,
                        'confidence': v.confidence,
                        'is_stable': v.is_stable,
                        'session_id': v.session_id,
                    }
                    for v in tracked_vehicles
                ]

                if self.stats['frames_processed'] % self.TRACKING_LOG_INTERVAL == 0:
                    logger.debug(f"Tracking: {len(tracked_vehicles)} vehicles, "
                                 f"display={display_id}")

                stable_vehicles = self.tracker.get_stable_vehicles(display_id)

                for vehicle in stable_vehicles:
                    # Already sent through the pipeline.
                    if vehicle.processed_pipeline:
                        continue

                    # Same vehicle with the same (backend-provided) session: skip.
                    if session_id and vehicle.session_id == session_id:
                        continue

                    # After a session is cleared (post-fueling), let the validator
                    # decide whether this is still the same car.
                    session_cleared = self._is_recently_cleared(vehicle.session_id)
                    if self.validator.should_skip_same_car(vehicle, session_cleared):
                        continue

                    validation_result = self.validator.validate_vehicle(vehicle, frame.shape)
                    if not (validation_result.is_valid and validation_result.should_process):
                        continue

                    logger.info(f"Vehicle {vehicle.track_id} validated for processing: "
                                f"{validation_result.reason}")

                    result['validated_vehicle'] = {
                        'track_id': vehicle.track_id,
                        'state': validation_result.state.value,
                        'confidence': validation_result.confidence,
                    }

                    # Generate session ID if not provided
                    if not session_id:
                        session_id = str(uuid.uuid4())
                        logger.info(f"Generated session ID: {session_id}")

                    # Mark vehicle as processed and record the session mappings.
                    self.tracker.mark_processed(vehicle.track_id, session_id)
                    self.session_vehicles[session_id] = vehicle.track_id
                    self.active_sessions[display_id] = session_id

                    # Execute detection pipeline (placeholder for Phase 5)
                    pipeline_result = await self._execute_pipeline(
                        frame,
                        vehicle,
                        display_id,
                        session_id,
                        subscription_id,
                    )

                    result['pipeline_result'] = pipeline_result
                    result['session_id'] = session_id
                    self.stats['pipelines_executed'] += 1

                    # Only process one vehicle per frame
                    break

                # NOTE(review): these are per-frame snapshots, not cumulative counts;
                # 'vehicles_validated' holds the number of *stable* vehicles.
                self.stats['vehicles_detected'] = len(tracked_vehicles)
                self.stats['vehicles_validated'] = len(stable_vehicles)

        except Exception as e:
            logger.error(f"Error in tracking pipeline: {e}", exc_info=True)

        result['processing_time'] = time.perf_counter() - start_time
        return result

    @staticmethod
    def _log_raw_detections(tracking_results: Any) -> None:
        """Log (at debug level) a summary of the raw detections returned by the tracking model."""
        # Skip building the summary entirely on the per-frame hot path when
        # debug logging is off.
        if not logger.isEnabledFor(logging.DEBUG):
            return

        if not (tracking_results and hasattr(tracking_results, 'detections')):
            logger.debug("[DEBUG] No tracking results or detections attribute")
            return

        detections = tracking_results.detections
        raw_detections = len(detections)
        if raw_detections > 0:
            class_names = [detection.class_name for detection in detections]
            logger.debug(f"[DEBUG] Raw detections: {raw_detections}, classes: {class_names}")
        else:
            logger.debug("[DEBUG] No raw detections found")

    def _is_recently_cleared(self, session_id: Optional[str]) -> bool:
        """Return True if ``session_id`` was cleared less than SESSION_CLEAR_COOLDOWN seconds ago."""
        clear_time = self.cleared_sessions.get(session_id)
        return clear_time is not None and (time.time() - clear_time) < self.SESSION_CLEAR_COOLDOWN

    async def _execute_pipeline(self,
                                frame: np.ndarray,
                                vehicle: TrackedVehicle,
                                display_id: str,
                                session_id: str,
                                subscription_id: str) -> Dict[str, Any]:
        """
        Execute the main detection pipeline for a validated vehicle.

        This is a placeholder for Phase 5 implementation: it returns a 'pending'
        result after a short simulated delay and does not use ``frame`` or
        ``subscription_id`` yet.

        Args:
            frame: Input frame
            vehicle: Validated tracked vehicle
            display_id: Display identifier
            session_id: Session identifier
            subscription_id: Full subscription identifier

        Returns:
            Pipeline execution results
        """
        logger.info(f"Executing pipeline for vehicle {vehicle.track_id}, "
                    f"session={session_id}, display={display_id}")

        # Placeholder for Phase 5 pipeline execution
        # This will be implemented when we create the detection module
        pipeline_result = {
            'status': 'pending',
            'message': 'Pipeline execution will be implemented in Phase 5',
            'vehicle_id': vehicle.track_id,
            'session_id': session_id,
            'bbox': vehicle.bbox,
            'confidence': vehicle.confidence,
        }

        # Simulate pipeline execution
        await asyncio.sleep(0.1)

        return pipeline_result

    def set_session_id(self, display_id: str, session_id: str):
        """
        Set session ID for a display (from backend).

        Also records the session -> track_id mapping if the tracker already has
        a vehicle carrying this session.

        Args:
            display_id: Display identifier
            session_id: Session identifier
        """
        self.active_sessions[display_id] = session_id
        logger.info(f"Set session {session_id} for display {display_id}")

        # Find vehicle with this session
        vehicle = self.tracker.get_vehicle_by_session(session_id)
        if vehicle:
            self.session_vehicles[session_id] = vehicle.track_id

    def clear_session_id(self, session_id: str):
        """
        Clear session ID (post-fueling).

        Records the clear time (used for the same-car cooldown), clears the
        session in the tracker, removes every display and vehicle mapping for
        it, and prunes cleared-session records older than
        CLEARED_SESSION_RETENTION seconds.

        Args:
            session_id: Session identifier to clear
        """
        # Mark session as cleared
        self.cleared_sessions[session_id] = time.time()

        # Clear from tracker
        self.tracker.clear_session(session_id)

        # Remove every display mapped to this session (collected first so the
        # dict is not mutated while iterating it). Unlike a truthiness check,
        # this also handles falsy display ids such as "".
        stale_displays = [
            display_id for display_id, sess_id in self.active_sessions.items()
            if sess_id == session_id
        ]
        for display_id in stale_displays:
            del self.active_sessions[display_id]

        self.session_vehicles.pop(session_id, None)

        logger.info(f"Cleared session {session_id}")

        self._prune_cleared_sessions()

    def _prune_cleared_sessions(self) -> None:
        """Drop cleared-session records older than CLEARED_SESSION_RETENTION seconds (in place)."""
        current_time = time.time()
        expired = [
            sid for sid, clear_time in self.cleared_sessions.items()
            if (current_time - clear_time) > self.CLEARED_SESSION_RETENTION
        ]
        for sid in expired:
            del self.cleared_sessions[sid]

    def get_session_for_display(self, display_id: str) -> Optional[str]:
        """Get active session for a display (None if there is none)."""
        return self.active_sessions.get(display_id)

    def reset_tracking(self):
        """Reset all tracking state (tracker plus all session bookkeeping)."""
        self.tracker.reset_tracking()
        self.active_sessions.clear()
        self.session_vehicles.clear()
        self.cleared_sessions.clear()
        logger.info("Tracking pipeline integration reset")

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get comprehensive statistics.

        The 'integration' entry is a snapshot copy, so callers cannot mutate the
        internal counters through the returned dict.
        """
        tracker_stats = self.tracker.get_statistics()
        validator_stats = self.validator.get_statistics()

        return {
            'integration': dict(self.stats),
            'tracker': tracker_stats,
            'validator': validator_stats,
            'active_sessions': len(self.active_sessions),
            'cleared_sessions': len(self.cleared_sessions),
        }

    def cleanup(self):
        """Cleanup resources: shut down the executor without waiting, then reset all state."""
        self.executor.shutdown(wait=False)
        self.reset_tracking()
        logger.info("Tracking pipeline integration cleaned up")