447 lines
No EOL
17 KiB
Python
447 lines
No EOL
17 KiB
Python
"""
|
|
Tracking-Pipeline Integration Module.
|
|
Connects the tracking system with the main detection pipeline and manages the flow.
|
|
"""
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from typing import Dict, Optional, Any, List, Tuple
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
import numpy as np
|
|
|
|
from .tracker import VehicleTracker, TrackedVehicle
|
|
from .validator import StableCarValidator, ValidationResult, VehicleState
|
|
from ..models.inference import YOLOWrapper
|
|
from ..models.pipeline import PipelineParser
|
|
|
|
# Module-level logger used by TrackingPipelineIntegration below.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TrackingPipelineIntegration:
    """
    Integrates vehicle tracking with the detection pipeline.

    Each frame is run through the tracking model and the ``VehicleTracker``.
    Stable vehicles are checked by the ``StableCarValidator``. At most one
    validated vehicle per frame triggers a detection message and pipeline
    execution. The backend later binds that vehicle to a session via
    ``set_session_id`` and releases the session via ``clear_session_id``.
    """

    # Seconds after clear_session_id() during which a vehicle still carrying
    # that session ID is reported to the validator as "session cleared".
    SESSION_CLEAR_COOLDOWN_SECONDS = 30
    # Seconds a cleared session is remembered before it is pruned.
    CLEARED_SESSION_RETENTION_SECONDS = 300
    # Frame interval for the periodic tracking debug log.
    LOG_INTERVAL_FRAMES = 30
    # modelId reported in mock detections (from models/52/bangchak_poc2).
    MOCK_DETECTION_MODEL_ID = 52

    def __init__(self, pipeline_parser: PipelineParser, model_manager: Any, message_sender=None):
        """
        Initialize tracking-pipeline integration.

        Args:
            pipeline_parser: Pipeline parser with loaded configuration; its
                ``tracking_config`` (if any) configures the VehicleTracker.
            model_manager: Model manager used to load the tracking model.
            message_sender: Optional async callable; it is awaited with each
                outgoing message object (see ``_send_mock_detection``).
        """
        self.pipeline_parser = pipeline_parser
        self.model_manager = model_manager
        self.message_sender = message_sender

        # Tracking components; an absent tracking config yields an empty dict.
        tracking_config = pipeline_parser.tracking_config.__dict__ if pipeline_parser.tracking_config else {}
        self.tracker = VehicleTracker(tracking_config)
        self.validator = StableCarValidator()

        # Tracking model, loaded by initialize_tracking_model().
        self.tracking_model: Optional[YOLOWrapper] = None
        self.tracking_model_id = None

        # Session management
        self.active_sessions: Dict[str, str] = {}  # display_id -> session_id
        self.session_vehicles: Dict[str, int] = {}  # session_id -> track_id
        self.cleared_sessions: Dict[str, float] = {}  # session_id -> clear_time
        self.pending_vehicles: Dict[str, int] = {}  # display_id -> track_id (waiting for session ID)

        # Thread pool for pipeline execution (shut down in cleanup()).
        self.executor = ThreadPoolExecutor(max_workers=2)

        # Statistics
        self.stats = {
            'frames_processed': 0,
            'vehicles_detected': 0,
            'vehicles_validated': 0,
            'pipelines_executed': 0
        }

        # Test mode: send a mock imageDetection for each validated vehicle.
        self.test_mode = True

        logger.info("TrackingPipelineIntegration initialized")

    async def initialize_tracking_model(self) -> bool:
        """
        Load and initialize the tracking model.

        The model file comes from the pipeline's tracking config. The model is
        obtained from ``model_manager.get_yolo_model`` using the first ID
        returned by ``model_manager.get_all_downloaded_models()``.

        Returns:
            True if the tracking model was loaded. False if the tracking
            config, the model file or any downloaded model is missing, if
            loading returned nothing, or if an exception occurred (logged).
        """
        try:
            tracking_config = self.pipeline_parser.tracking_config
            if not tracking_config:
                logger.warning("No tracking configuration found in pipeline")
                return False

            model_file = tracking_config.model_file
            model_id = tracking_config.model_id

            if not model_file:
                logger.warning("No tracking model file specified")
                return False

            logger.info(f"Loading tracking model: {model_id} ({model_file})")

            # The ModelManager needs the actual model ID, not the tracking
            # config's model string identifier.
            # NOTE(review): the first downloaded model is used; this is only
            # right when exactly one pipeline model is present -- confirm.
            pipeline_models = list(self.model_manager.get_all_downloaded_models())
            if not pipeline_models:
                logger.error("No models available in ModelManager")
                return False

            actual_model_id = pipeline_models[0]
            self.tracking_model = self.model_manager.get_yolo_model(actual_model_id, model_file)
            self.tracking_model_id = model_id

            if self.tracking_model:
                logger.info(f"Tracking model {model_id} loaded successfully")
                return True

            logger.error(f"Failed to load tracking model {model_id}")
            return False

        except Exception as e:
            logger.error(f"Error initializing tracking model: {e}", exc_info=True)
            return False

    async def process_frame(self,
                            frame: np.ndarray,
                            display_id: str,
                            subscription_id: str,
                            session_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Process a frame through tracking and potentially the detection pipeline.

        Args:
            frame: Input frame to process (its ``shape`` is passed to the validator).
            display_id: Display identifier.
            subscription_id: Full subscription identifier.
            session_id: Optional session ID from backend; a stable vehicle that
                already carries this session ID is skipped.

        Returns:
            Dictionary with keys ``tracked_vehicles`` (list of dicts),
            ``validated_vehicle`` (dict or None), ``pipeline_result``
            (dict or None), ``session_id`` (the argument, unchanged) and
            ``processing_time`` (seconds). Errors are logged, not raised.
        """
        start_time = time.time()
        result: Dict[str, Any] = {
            'tracked_vehicles': [],
            'validated_vehicle': None,
            'pipeline_result': None,
            'session_id': session_id,
            'processing_time': 0.0
        }

        try:
            self.stats['frames_processed'] += 1

            if not self.tracking_model:
                logger.warning("No tracking model available")
            else:
                tracking_results = self.tracking_model.track(
                    frame,
                    confidence_threshold=self.tracker.min_confidence,
                    trigger_classes=self.tracker.trigger_classes,
                    persist=True
                )
                self._log_raw_detections(tracking_results)

                tracked_vehicles = self.tracker.process_detections(
                    tracking_results,
                    display_id,
                    frame
                )

                result['tracked_vehicles'] = [
                    {
                        'track_id': v.track_id,
                        'bbox': v.bbox,
                        'confidence': v.confidence,
                        'is_stable': v.is_stable,
                        'session_id': v.session_id
                    }
                    for v in tracked_vehicles
                ]

                if self.stats['frames_processed'] % self.LOG_INTERVAL_FRAMES == 0:
                    logger.debug("Tracking: %d vehicles, display=%s",
                                 len(tracked_vehicles), display_id)

                stable_vehicles = self.tracker.get_stable_vehicles(display_id)

                for vehicle in stable_vehicles:
                    # Already handed to the pipeline (marked in set_session_id).
                    if vehicle.processed_pipeline:
                        continue

                    # Same vehicle with the session reported for this frame.
                    if session_id and vehicle.session_id == session_id:
                        continue

                    # Skip the same car shortly after its session was cleared.
                    session_cleared = self._is_recently_cleared(vehicle.session_id)
                    if self.validator.should_skip_same_car(vehicle, session_cleared):
                        continue

                    validation_result = self.validator.validate_vehicle(vehicle, frame.shape)
                    if not (validation_result.is_valid and validation_result.should_process):
                        continue

                    # NOTE(review): the vehicle is only marked processed once the
                    # backend calls set_session_id, so until then it may be
                    # validated (and announced) again on later frames unless the
                    # validator prevents it -- confirm intended.
                    await self._handle_validated_vehicle(
                        frame, vehicle, validation_result, display_id, subscription_id, result
                    )

                    # Only process one vehicle per frame
                    break

                # Per-frame snapshots; 'vehicles_validated' counts stable vehicles.
                self.stats['vehicles_detected'] = len(tracked_vehicles)
                self.stats['vehicles_validated'] = len(stable_vehicles)

        except Exception as e:
            logger.error(f"Error in tracking pipeline: {e}", exc_info=True)

        result['processing_time'] = time.time() - start_time
        return result

    @staticmethod
    def _log_raw_detections(tracking_results: Any) -> None:
        """Debug-log the count and class names of the raw tracking detections."""
        if not (tracking_results and hasattr(tracking_results, 'detections')):
            logger.debug("[DEBUG] No tracking results or detections attribute")
            return

        raw_detections = len(tracking_results.detections)
        if raw_detections > 0:
            class_names = [detection.class_name for detection in tracking_results.detections]
            logger.debug("[DEBUG] Raw detections: %d, classes: %s", raw_detections, class_names)
        else:
            logger.debug("[DEBUG] No raw detections found")

    def _is_recently_cleared(self, session_id: Optional[str]) -> bool:
        """Return True if ``session_id`` was cleared within the cooldown window."""
        clear_time = self.cleared_sessions.get(session_id)
        return clear_time is not None and (time.time() - clear_time) < self.SESSION_CLEAR_COOLDOWN_SECONDS

    async def _handle_validated_vehicle(self,
                                        frame: np.ndarray,
                                        vehicle: TrackedVehicle,
                                        validation_result: Any,
                                        display_id: str,
                                        subscription_id: str,
                                        result: Dict[str, Any]) -> None:
        """Announce a validated vehicle, mark it pending and run the pipeline, filling ``result``."""
        logger.info(f"Vehicle {vehicle.track_id} validated for processing: "
                    f"{validation_result.reason}")

        result['validated_vehicle'] = {
            'track_id': vehicle.track_id,
            'state': validation_result.state.value,
            'confidence': validation_result.confidence
        }

        # Register the vehicle as pending BEFORE sending the detection. If the
        # backend's setSessionId reply is dispatched to set_session_id while the
        # send below is still being awaited, the pending entry must already
        # exist, otherwise the session would be dropped for this vehicle.
        self.pending_vehicles[display_id] = vehicle.track_id

        # Send mock image detection message in test mode.
        # Note: Backend will generate and send back session ID via setSessionId.
        if self.test_mode:
            await self._send_mock_detection(subscription_id, None)

        logger.info(f"Vehicle {vehicle.track_id} waiting for session ID from backend")

        # Execute detection pipeline (placeholder for Phase 5); no session ID yet.
        pipeline_result = await self._execute_pipeline(
            frame,
            vehicle,
            display_id,
            None,
            subscription_id
        )

        result['pipeline_result'] = pipeline_result
        self.stats['pipelines_executed'] += 1

    async def _execute_pipeline(self,
                                frame: np.ndarray,
                                vehicle: TrackedVehicle,
                                display_id: str,
                                session_id: Optional[str],
                                subscription_id: str) -> Dict[str, Any]:
        """
        Execute the main detection pipeline for a validated vehicle.

        This is a placeholder for Phase 5 implementation: it waits 0.1 s and
        returns a 'pending' status dict.

        Args:
            frame: Input frame.
            vehicle: Validated tracked vehicle.
            display_id: Display identifier.
            session_id: Session identifier, or None if not yet assigned.
            subscription_id: Full subscription identifier.

        Returns:
            Pipeline execution results.
        """
        logger.info(f"Executing pipeline for vehicle {vehicle.track_id}, "
                    f"session={session_id}, display={display_id}")

        # Placeholder for Phase 5 pipeline execution
        pipeline_result = {
            'status': 'pending',
            'message': 'Pipeline execution will be implemented in Phase 5',
            'vehicle_id': vehicle.track_id,
            'session_id': session_id,
            'bbox': vehicle.bbox,
            'confidence': vehicle.confidence
        }

        # Simulate pipeline execution
        await asyncio.sleep(0.1)

        return pipeline_result

    async def _send_mock_detection(self, subscription_id: str, session_id: Optional[str]):
        """
        Send mock image detection message to backend following worker.md specification.

        Args:
            subscription_id: Full subscription identifier (display-id;camera-id).
            session_id: Session identifier for linking detection to user session.
                Currently unused by the message; callers pass None because the
                backend assigns the session afterwards.
        """
        try:
            # Import here to avoid circular imports
            from ..communication.messages import create_image_detection

            # Flat detection data as required by the message model; all
            # fields are None in the mock.
            detection_data = {
                "carModel": None,
                "carBrand": None,
                "carYear": None,
                "bodyType": None,
                "licensePlateText": None,
                "licensePlateConfidence": None
            }

            # modelId is fixed; modelName is the tracking modelId from the
            # pipeline's tracking configuration when available.
            tracking_model_id = self.MOCK_DETECTION_MODEL_ID
            tracking_model_name = "front_rear_detection_v1"  # Default

            if self.pipeline_parser and self.pipeline_parser.tracking_config:
                tracking_model_name = self.pipeline_parser.tracking_config.model_id

            detection_message = create_image_detection(
                subscription_identifier=subscription_id,
                detection_data=detection_data,
                model_id=tracking_model_id,
                model_name=tracking_model_name
            )

            if self.message_sender:
                await self.message_sender(detection_message)
                logger.info(f"[MOCK DETECTION] Sent to backend: {detection_data}")
            else:
                logger.info(f"[MOCK DETECTION] No message sender available, would send: {detection_message}")

        except Exception as e:
            logger.error(f"Error sending mock detection: {e}", exc_info=True)

    def set_session_id(self, display_id: str, session_id: str):
        """
        Set session ID for a display (from backend).

        Called when the backend sends setSessionId after receiving
        imageDetection. If a vehicle is pending on this display, it is marked
        processed in the tracker and linked to the session.

        Args:
            display_id: Display identifier.
            session_id: Session identifier.
        """
        self.active_sessions[display_id] = session_id
        logger.info(f"Set session {session_id} for display {display_id}")

        if display_id not in self.pending_vehicles:
            logger.warning(f"No pending vehicle found for display {display_id} when setting session {session_id}")
            return

        track_id = self.pending_vehicles[display_id]

        # Mark vehicle as processed with the session ID
        self.tracker.mark_processed(track_id, session_id)
        self.session_vehicles[session_id] = track_id

        # Remove from pending only after the tracker accepted the assignment.
        del self.pending_vehicles[display_id]

        logger.info(f"Assigned session {session_id} to vehicle {track_id}, marked as processed")

    def clear_session_id(self, session_id: str):
        """
        Clear session ID (post-fueling).

        Records the clear time (used for the same-car cooldown), clears the
        session in the tracker, unmaps it from every display and vehicle, and
        prunes cleared sessions older than the retention window.

        Args:
            session_id: Session identifier to clear.
        """
        self.cleared_sessions[session_id] = time.time()

        self.tracker.clear_session(session_id)

        # Remove every display mapped to this session. Comparing values (not
        # relying on truthiness) also handles an empty-string display id.
        displays_to_remove = [
            display_id for display_id, sess_id in self.active_sessions.items()
            if sess_id == session_id
        ]
        for display_id in displays_to_remove:
            del self.active_sessions[display_id]

        self.session_vehicles.pop(session_id, None)

        logger.info(f"Cleared session {session_id}")

        self._prune_cleared_sessions()

    def _prune_cleared_sessions(self) -> None:
        """Forget cleared sessions older than the retention window."""
        current_time = time.time()
        old_sessions = [
            sid for sid, clear_time in self.cleared_sessions.items()
            if (current_time - clear_time) > self.CLEARED_SESSION_RETENTION_SECONDS
        ]
        for sid in old_sessions:
            del self.cleared_sessions[sid]

    def get_session_for_display(self, display_id: str) -> Optional[str]:
        """Get active session for a display, or None if there is none."""
        return self.active_sessions.get(display_id)

    def reset_tracking(self):
        """Reset tracker state and all session bookkeeping."""
        self.tracker.reset_tracking()
        self.active_sessions.clear()
        self.session_vehicles.clear()
        self.cleared_sessions.clear()
        self.pending_vehicles.clear()
        logger.info("Tracking pipeline integration reset")

    def get_statistics(self) -> Dict[str, Any]:
        """Get integration, tracker and validator statistics plus session counts."""
        tracker_stats = self.tracker.get_statistics()
        validator_stats = self.validator.get_statistics()

        return {
            'integration': self.stats,
            'tracker': tracker_stats,
            'validator': validator_stats,
            'active_sessions': len(self.active_sessions),
            'cleared_sessions': len(self.cleared_sessions)
        }

    def cleanup(self):
        """Shut down the executor (without waiting) and reset all tracking state."""
        self.executor.shutdown(wait=False)
        self.reset_tracking()
        logger.info("Tracking pipeline integration cleaned up")