Refactor: done phase 4

This commit is contained in:
ziesorx 2025-09-23 17:56:40 +07:00
parent 7e8034c6e5
commit 9e4c23c75c
8 changed files with 1533 additions and 37 deletions

View file

@ -0,0 +1,369 @@
"""
Tracking-Pipeline Integration Module.
Connects the tracking system with the main detection pipeline and manages the flow.
"""
import logging
import time
import uuid
from typing import Dict, Optional, Any, List, Tuple
import asyncio
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from .tracker import VehicleTracker, TrackedVehicle
from .validator import StableCarValidator, ValidationResult, VehicleState
from ..models.inference import YOLOWrapper
from ..models.pipeline import PipelineParser
logger = logging.getLogger(__name__)
class TrackingPipelineIntegration:
"""
Integrates vehicle tracking with the detection pipeline.
Manages tracking state transitions and pipeline execution triggers.
"""
def __init__(self, pipeline_parser: PipelineParser, model_manager: Any):
"""
Initialize tracking-pipeline integration.
Args:
pipeline_parser: Pipeline parser with loaded configuration
model_manager: Model manager for loading models
"""
self.pipeline_parser = pipeline_parser
self.model_manager = model_manager
# Initialize tracking components
tracking_config = pipeline_parser.tracking_config.__dict__ if pipeline_parser.tracking_config else {}
self.tracker = VehicleTracker(tracking_config)
self.validator = StableCarValidator()
# Tracking model
self.tracking_model: Optional[YOLOWrapper] = None
self.tracking_model_id = None
# Session management
self.active_sessions: Dict[str, str] = {} # display_id -> session_id
self.session_vehicles: Dict[str, int] = {} # session_id -> track_id
self.cleared_sessions: Dict[str, float] = {} # session_id -> clear_time
# Thread pool for pipeline execution
self.executor = ThreadPoolExecutor(max_workers=2)
# Statistics
self.stats = {
'frames_processed': 0,
'vehicles_detected': 0,
'vehicles_validated': 0,
'pipelines_executed': 0
}
logger.info("TrackingPipelineIntegration initialized")
async def initialize_tracking_model(self) -> bool:
"""
Load and initialize the tracking model.
Returns:
True if successful, False otherwise
"""
try:
if not self.pipeline_parser.tracking_config:
logger.warning("No tracking configuration found in pipeline")
return False
model_file = self.pipeline_parser.tracking_config.model_file
model_id = self.pipeline_parser.tracking_config.model_id
if not model_file:
logger.warning("No tracking model file specified")
return False
# Load tracking model
logger.info(f"Loading tracking model: {model_id} ({model_file})")
# Get the model ID from the ModelManager context
# We need the actual model ID, not the model string identifier
# For now, let's extract it from the model manager
pipeline_models = list(self.model_manager.get_all_downloaded_models())
if pipeline_models:
actual_model_id = pipeline_models[0] # Use the first available model
self.tracking_model = self.model_manager.get_yolo_model(actual_model_id, model_file)
else:
logger.error("No models available in ModelManager")
return False
self.tracking_model_id = model_id
if self.tracking_model:
logger.info(f"Tracking model {model_id} loaded successfully")
return True
else:
logger.error(f"Failed to load tracking model {model_id}")
return False
except Exception as e:
logger.error(f"Error initializing tracking model: {e}", exc_info=True)
return False
async def process_frame(self,
frame: np.ndarray,
display_id: str,
subscription_id: str,
session_id: Optional[str] = None) -> Dict[str, Any]:
"""
Process a frame through tracking and potentially the detection pipeline.
Args:
frame: Input frame to process
display_id: Display identifier
subscription_id: Full subscription identifier
session_id: Optional session ID from backend
Returns:
Dictionary with processing results
"""
start_time = time.time()
result = {
'tracked_vehicles': [],
'validated_vehicle': None,
'pipeline_result': None,
'session_id': session_id,
'processing_time': 0.0
}
try:
# Update stats
self.stats['frames_processed'] += 1
# Run tracking model
if self.tracking_model:
# Run inference with tracking
tracking_results = self.tracking_model.track(
frame,
confidence_threshold=self.tracker.min_confidence,
trigger_classes=self.tracker.trigger_classes,
persist=True
)
# Process tracking results
tracked_vehicles = self.tracker.process_detections(
tracking_results,
display_id,
frame
)
result['tracked_vehicles'] = [
{
'track_id': v.track_id,
'bbox': v.bbox,
'confidence': v.confidence,
'is_stable': v.is_stable,
'session_id': v.session_id
}
for v in tracked_vehicles
]
# Log tracking info periodically
if self.stats['frames_processed'] % 30 == 0: # Every 30 frames
logger.debug(f"Tracking: {len(tracked_vehicles)} vehicles, "
f"display={display_id}")
# Get stable vehicles for validation
stable_vehicles = self.tracker.get_stable_vehicles(display_id)
# Validate and potentially process stable vehicles
for vehicle in stable_vehicles:
# Check if vehicle is already processed or has session
if vehicle.processed_pipeline:
continue
# Check for session cleared (post-fueling)
if session_id and vehicle.session_id == session_id:
# Same vehicle with same session, skip
continue
# Check if this was a recently cleared session
session_cleared = False
if vehicle.session_id in self.cleared_sessions:
clear_time = self.cleared_sessions[vehicle.session_id]
if (time.time() - clear_time) < 30: # 30 second cooldown
session_cleared = True
# Skip same car after session clear
if self.validator.should_skip_same_car(vehicle, session_cleared):
continue
# Validate vehicle
validation_result = self.validator.validate_vehicle(vehicle, frame.shape)
if validation_result.is_valid and validation_result.should_process:
logger.info(f"Vehicle {vehicle.track_id} validated for processing: "
f"{validation_result.reason}")
result['validated_vehicle'] = {
'track_id': vehicle.track_id,
'state': validation_result.state.value,
'confidence': validation_result.confidence
}
# Generate session ID if not provided
if not session_id:
session_id = str(uuid.uuid4())
logger.info(f"Generated session ID: {session_id}")
# Mark vehicle as processed
self.tracker.mark_processed(vehicle.track_id, session_id)
self.session_vehicles[session_id] = vehicle.track_id
self.active_sessions[display_id] = session_id
# Execute detection pipeline (placeholder for Phase 5)
pipeline_result = await self._execute_pipeline(
frame,
vehicle,
display_id,
session_id,
subscription_id
)
result['pipeline_result'] = pipeline_result
result['session_id'] = session_id
self.stats['pipelines_executed'] += 1
# Only process one vehicle per frame
break
self.stats['vehicles_detected'] = len(tracked_vehicles)
self.stats['vehicles_validated'] = len(stable_vehicles)
else:
logger.warning("No tracking model available")
except Exception as e:
logger.error(f"Error in tracking pipeline: {e}", exc_info=True)
result['processing_time'] = time.time() - start_time
return result
async def _execute_pipeline(self,
frame: np.ndarray,
vehicle: TrackedVehicle,
display_id: str,
session_id: str,
subscription_id: str) -> Dict[str, Any]:
"""
Execute the main detection pipeline for a validated vehicle.
This is a placeholder for Phase 5 implementation.
Args:
frame: Input frame
vehicle: Validated tracked vehicle
display_id: Display identifier
session_id: Session identifier
subscription_id: Full subscription identifier
Returns:
Pipeline execution results
"""
logger.info(f"Executing pipeline for vehicle {vehicle.track_id}, "
f"session={session_id}, display={display_id}")
# Placeholder for Phase 5 pipeline execution
# This will be implemented when we create the detection module
pipeline_result = {
'status': 'pending',
'message': 'Pipeline execution will be implemented in Phase 5',
'vehicle_id': vehicle.track_id,
'session_id': session_id,
'bbox': vehicle.bbox,
'confidence': vehicle.confidence
}
# Simulate pipeline execution
await asyncio.sleep(0.1)
return pipeline_result
def set_session_id(self, display_id: str, session_id: str):
"""
Set session ID for a display (from backend).
Args:
display_id: Display identifier
session_id: Session identifier
"""
self.active_sessions[display_id] = session_id
logger.info(f"Set session {session_id} for display {display_id}")
# Find vehicle with this session
vehicle = self.tracker.get_vehicle_by_session(session_id)
if vehicle:
self.session_vehicles[session_id] = vehicle.track_id
def clear_session_id(self, session_id: str):
"""
Clear session ID (post-fueling).
Args:
session_id: Session identifier to clear
"""
# Mark session as cleared
self.cleared_sessions[session_id] = time.time()
# Clear from tracker
self.tracker.clear_session(session_id)
# Remove from active sessions
display_to_remove = None
for display_id, sess_id in self.active_sessions.items():
if sess_id == session_id:
display_to_remove = display_id
break
if display_to_remove:
del self.active_sessions[display_to_remove]
if session_id in self.session_vehicles:
del self.session_vehicles[session_id]
logger.info(f"Cleared session {session_id}")
# Clean old cleared sessions (older than 5 minutes)
current_time = time.time()
old_sessions = [
sid for sid, clear_time in self.cleared_sessions.items()
if (current_time - clear_time) > 300
]
for sid in old_sessions:
del self.cleared_sessions[sid]
def get_session_for_display(self, display_id: str) -> Optional[str]:
"""Get active session for a display."""
return self.active_sessions.get(display_id)
def reset_tracking(self):
"""Reset all tracking state."""
self.tracker.reset_tracking()
self.active_sessions.clear()
self.session_vehicles.clear()
self.cleared_sessions.clear()
logger.info("Tracking pipeline integration reset")
def get_statistics(self) -> Dict[str, Any]:
"""Get comprehensive statistics."""
tracker_stats = self.tracker.get_statistics()
validator_stats = self.validator.get_statistics()
return {
'integration': self.stats,
'tracker': tracker_stats,
'validator': validator_stats,
'active_sessions': len(self.active_sessions),
'cleared_sessions': len(self.cleared_sessions)
}
def cleanup(self):
"""Cleanup resources."""
self.executor.shutdown(wait=False)
self.reset_tracking()
logger.info("Tracking pipeline integration cleaned up")