Refactor: done phase 4
This commit is contained in:
parent
7e8034c6e5
commit
9e4c23c75c
8 changed files with 1533 additions and 37 deletions
369
core/tracking/integration.py
Normal file
369
core/tracking/integration.py
Normal file
|
@ -0,0 +1,369 @@
|
|||
"""
|
||||
Tracking-Pipeline Integration Module.
|
||||
Connects the tracking system with the main detection pipeline and manages the flow.
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Optional, Any, List, Tuple
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import numpy as np
|
||||
|
||||
from .tracker import VehicleTracker, TrackedVehicle
|
||||
from .validator import StableCarValidator, ValidationResult, VehicleState
|
||||
from ..models.inference import YOLOWrapper
|
||||
from ..models.pipeline import PipelineParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrackingPipelineIntegration:
|
||||
"""
|
||||
Integrates vehicle tracking with the detection pipeline.
|
||||
Manages tracking state transitions and pipeline execution triggers.
|
||||
"""
|
||||
|
||||
def __init__(self, pipeline_parser: PipelineParser, model_manager: Any):
|
||||
"""
|
||||
Initialize tracking-pipeline integration.
|
||||
|
||||
Args:
|
||||
pipeline_parser: Pipeline parser with loaded configuration
|
||||
model_manager: Model manager for loading models
|
||||
"""
|
||||
self.pipeline_parser = pipeline_parser
|
||||
self.model_manager = model_manager
|
||||
|
||||
# Initialize tracking components
|
||||
tracking_config = pipeline_parser.tracking_config.__dict__ if pipeline_parser.tracking_config else {}
|
||||
self.tracker = VehicleTracker(tracking_config)
|
||||
self.validator = StableCarValidator()
|
||||
|
||||
# Tracking model
|
||||
self.tracking_model: Optional[YOLOWrapper] = None
|
||||
self.tracking_model_id = None
|
||||
|
||||
# Session management
|
||||
self.active_sessions: Dict[str, str] = {} # display_id -> session_id
|
||||
self.session_vehicles: Dict[str, int] = {} # session_id -> track_id
|
||||
self.cleared_sessions: Dict[str, float] = {} # session_id -> clear_time
|
||||
|
||||
# Thread pool for pipeline execution
|
||||
self.executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
# Statistics
|
||||
self.stats = {
|
||||
'frames_processed': 0,
|
||||
'vehicles_detected': 0,
|
||||
'vehicles_validated': 0,
|
||||
'pipelines_executed': 0
|
||||
}
|
||||
|
||||
logger.info("TrackingPipelineIntegration initialized")
|
||||
|
||||
async def initialize_tracking_model(self) -> bool:
|
||||
"""
|
||||
Load and initialize the tracking model.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
if not self.pipeline_parser.tracking_config:
|
||||
logger.warning("No tracking configuration found in pipeline")
|
||||
return False
|
||||
|
||||
model_file = self.pipeline_parser.tracking_config.model_file
|
||||
model_id = self.pipeline_parser.tracking_config.model_id
|
||||
|
||||
if not model_file:
|
||||
logger.warning("No tracking model file specified")
|
||||
return False
|
||||
|
||||
# Load tracking model
|
||||
logger.info(f"Loading tracking model: {model_id} ({model_file})")
|
||||
# Get the model ID from the ModelManager context
|
||||
# We need the actual model ID, not the model string identifier
|
||||
# For now, let's extract it from the model manager
|
||||
pipeline_models = list(self.model_manager.get_all_downloaded_models())
|
||||
if pipeline_models:
|
||||
actual_model_id = pipeline_models[0] # Use the first available model
|
||||
self.tracking_model = self.model_manager.get_yolo_model(actual_model_id, model_file)
|
||||
else:
|
||||
logger.error("No models available in ModelManager")
|
||||
return False
|
||||
self.tracking_model_id = model_id
|
||||
|
||||
if self.tracking_model:
|
||||
logger.info(f"Tracking model {model_id} loaded successfully")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to load tracking model {model_id}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing tracking model: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def process_frame(self,
|
||||
frame: np.ndarray,
|
||||
display_id: str,
|
||||
subscription_id: str,
|
||||
session_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a frame through tracking and potentially the detection pipeline.
|
||||
|
||||
Args:
|
||||
frame: Input frame to process
|
||||
display_id: Display identifier
|
||||
subscription_id: Full subscription identifier
|
||||
session_id: Optional session ID from backend
|
||||
|
||||
Returns:
|
||||
Dictionary with processing results
|
||||
"""
|
||||
start_time = time.time()
|
||||
result = {
|
||||
'tracked_vehicles': [],
|
||||
'validated_vehicle': None,
|
||||
'pipeline_result': None,
|
||||
'session_id': session_id,
|
||||
'processing_time': 0.0
|
||||
}
|
||||
|
||||
try:
|
||||
# Update stats
|
||||
self.stats['frames_processed'] += 1
|
||||
|
||||
# Run tracking model
|
||||
if self.tracking_model:
|
||||
# Run inference with tracking
|
||||
tracking_results = self.tracking_model.track(
|
||||
frame,
|
||||
confidence_threshold=self.tracker.min_confidence,
|
||||
trigger_classes=self.tracker.trigger_classes,
|
||||
persist=True
|
||||
)
|
||||
|
||||
# Process tracking results
|
||||
tracked_vehicles = self.tracker.process_detections(
|
||||
tracking_results,
|
||||
display_id,
|
||||
frame
|
||||
)
|
||||
|
||||
result['tracked_vehicles'] = [
|
||||
{
|
||||
'track_id': v.track_id,
|
||||
'bbox': v.bbox,
|
||||
'confidence': v.confidence,
|
||||
'is_stable': v.is_stable,
|
||||
'session_id': v.session_id
|
||||
}
|
||||
for v in tracked_vehicles
|
||||
]
|
||||
|
||||
# Log tracking info periodically
|
||||
if self.stats['frames_processed'] % 30 == 0: # Every 30 frames
|
||||
logger.debug(f"Tracking: {len(tracked_vehicles)} vehicles, "
|
||||
f"display={display_id}")
|
||||
|
||||
# Get stable vehicles for validation
|
||||
stable_vehicles = self.tracker.get_stable_vehicles(display_id)
|
||||
|
||||
# Validate and potentially process stable vehicles
|
||||
for vehicle in stable_vehicles:
|
||||
# Check if vehicle is already processed or has session
|
||||
if vehicle.processed_pipeline:
|
||||
continue
|
||||
|
||||
# Check for session cleared (post-fueling)
|
||||
if session_id and vehicle.session_id == session_id:
|
||||
# Same vehicle with same session, skip
|
||||
continue
|
||||
|
||||
# Check if this was a recently cleared session
|
||||
session_cleared = False
|
||||
if vehicle.session_id in self.cleared_sessions:
|
||||
clear_time = self.cleared_sessions[vehicle.session_id]
|
||||
if (time.time() - clear_time) < 30: # 30 second cooldown
|
||||
session_cleared = True
|
||||
|
||||
# Skip same car after session clear
|
||||
if self.validator.should_skip_same_car(vehicle, session_cleared):
|
||||
continue
|
||||
|
||||
# Validate vehicle
|
||||
validation_result = self.validator.validate_vehicle(vehicle, frame.shape)
|
||||
|
||||
if validation_result.is_valid and validation_result.should_process:
|
||||
logger.info(f"Vehicle {vehicle.track_id} validated for processing: "
|
||||
f"{validation_result.reason}")
|
||||
|
||||
result['validated_vehicle'] = {
|
||||
'track_id': vehicle.track_id,
|
||||
'state': validation_result.state.value,
|
||||
'confidence': validation_result.confidence
|
||||
}
|
||||
|
||||
# Generate session ID if not provided
|
||||
if not session_id:
|
||||
session_id = str(uuid.uuid4())
|
||||
logger.info(f"Generated session ID: {session_id}")
|
||||
|
||||
# Mark vehicle as processed
|
||||
self.tracker.mark_processed(vehicle.track_id, session_id)
|
||||
self.session_vehicles[session_id] = vehicle.track_id
|
||||
self.active_sessions[display_id] = session_id
|
||||
|
||||
# Execute detection pipeline (placeholder for Phase 5)
|
||||
pipeline_result = await self._execute_pipeline(
|
||||
frame,
|
||||
vehicle,
|
||||
display_id,
|
||||
session_id,
|
||||
subscription_id
|
||||
)
|
||||
|
||||
result['pipeline_result'] = pipeline_result
|
||||
result['session_id'] = session_id
|
||||
self.stats['pipelines_executed'] += 1
|
||||
|
||||
# Only process one vehicle per frame
|
||||
break
|
||||
|
||||
self.stats['vehicles_detected'] = len(tracked_vehicles)
|
||||
self.stats['vehicles_validated'] = len(stable_vehicles)
|
||||
|
||||
else:
|
||||
logger.warning("No tracking model available")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in tracking pipeline: {e}", exc_info=True)
|
||||
|
||||
result['processing_time'] = time.time() - start_time
|
||||
return result
|
||||
|
||||
async def _execute_pipeline(self,
|
||||
frame: np.ndarray,
|
||||
vehicle: TrackedVehicle,
|
||||
display_id: str,
|
||||
session_id: str,
|
||||
subscription_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute the main detection pipeline for a validated vehicle.
|
||||
This is a placeholder for Phase 5 implementation.
|
||||
|
||||
Args:
|
||||
frame: Input frame
|
||||
vehicle: Validated tracked vehicle
|
||||
display_id: Display identifier
|
||||
session_id: Session identifier
|
||||
subscription_id: Full subscription identifier
|
||||
|
||||
Returns:
|
||||
Pipeline execution results
|
||||
"""
|
||||
logger.info(f"Executing pipeline for vehicle {vehicle.track_id}, "
|
||||
f"session={session_id}, display={display_id}")
|
||||
|
||||
# Placeholder for Phase 5 pipeline execution
|
||||
# This will be implemented when we create the detection module
|
||||
pipeline_result = {
|
||||
'status': 'pending',
|
||||
'message': 'Pipeline execution will be implemented in Phase 5',
|
||||
'vehicle_id': vehicle.track_id,
|
||||
'session_id': session_id,
|
||||
'bbox': vehicle.bbox,
|
||||
'confidence': vehicle.confidence
|
||||
}
|
||||
|
||||
# Simulate pipeline execution
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return pipeline_result
|
||||
|
||||
def set_session_id(self, display_id: str, session_id: str):
|
||||
"""
|
||||
Set session ID for a display (from backend).
|
||||
|
||||
Args:
|
||||
display_id: Display identifier
|
||||
session_id: Session identifier
|
||||
"""
|
||||
self.active_sessions[display_id] = session_id
|
||||
logger.info(f"Set session {session_id} for display {display_id}")
|
||||
|
||||
# Find vehicle with this session
|
||||
vehicle = self.tracker.get_vehicle_by_session(session_id)
|
||||
if vehicle:
|
||||
self.session_vehicles[session_id] = vehicle.track_id
|
||||
|
||||
def clear_session_id(self, session_id: str):
|
||||
"""
|
||||
Clear session ID (post-fueling).
|
||||
|
||||
Args:
|
||||
session_id: Session identifier to clear
|
||||
"""
|
||||
# Mark session as cleared
|
||||
self.cleared_sessions[session_id] = time.time()
|
||||
|
||||
# Clear from tracker
|
||||
self.tracker.clear_session(session_id)
|
||||
|
||||
# Remove from active sessions
|
||||
display_to_remove = None
|
||||
for display_id, sess_id in self.active_sessions.items():
|
||||
if sess_id == session_id:
|
||||
display_to_remove = display_id
|
||||
break
|
||||
|
||||
if display_to_remove:
|
||||
del self.active_sessions[display_to_remove]
|
||||
|
||||
if session_id in self.session_vehicles:
|
||||
del self.session_vehicles[session_id]
|
||||
|
||||
logger.info(f"Cleared session {session_id}")
|
||||
|
||||
# Clean old cleared sessions (older than 5 minutes)
|
||||
current_time = time.time()
|
||||
old_sessions = [
|
||||
sid for sid, clear_time in self.cleared_sessions.items()
|
||||
if (current_time - clear_time) > 300
|
||||
]
|
||||
for sid in old_sessions:
|
||||
del self.cleared_sessions[sid]
|
||||
|
||||
def get_session_for_display(self, display_id: str) -> Optional[str]:
|
||||
"""Get active session for a display."""
|
||||
return self.active_sessions.get(display_id)
|
||||
|
||||
def reset_tracking(self):
|
||||
"""Reset all tracking state."""
|
||||
self.tracker.reset_tracking()
|
||||
self.active_sessions.clear()
|
||||
self.session_vehicles.clear()
|
||||
self.cleared_sessions.clear()
|
||||
logger.info("Tracking pipeline integration reset")
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive statistics."""
|
||||
tracker_stats = self.tracker.get_statistics()
|
||||
validator_stats = self.validator.get_statistics()
|
||||
|
||||
return {
|
||||
'integration': self.stats,
|
||||
'tracker': tracker_stats,
|
||||
'validator': validator_stats,
|
||||
'active_sessions': len(self.active_sessions),
|
||||
'cleared_sessions': len(self.cleared_sessions)
|
||||
}
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources."""
|
||||
self.executor.shutdown(wait=False)
|
||||
self.reset_tracking()
|
||||
logger.info("Tracking pipeline integration cleaned up")
|
Loading…
Add table
Add a link
Reference in a new issue