From 9e4c23c75c5296f3949e277ad605455e07f4c735 Mon Sep 17 00:00:00 2001 From: ziesorx Date: Tue, 23 Sep 2025 17:56:40 +0700 Subject: [PATCH] Refactor: done phase 4 --- REFACTOR_PLAN.md | 50 ++-- core/communication/websocket.py | 188 ++++++++++++++- core/models/manager.py | 80 ++++++- core/streaming/manager.py | 108 ++++++++- core/tracking/__init__.py | 15 +- core/tracking/integration.py | 369 +++++++++++++++++++++++++++++ core/tracking/tracker.py | 352 +++++++++++++++++++++++++++ core/tracking/validator.py | 408 ++++++++++++++++++++++++++++++++ 8 files changed, 1533 insertions(+), 37 deletions(-) create mode 100644 core/tracking/integration.py create mode 100644 core/tracking/tracker.py create mode 100644 core/tracking/validator.py diff --git a/REFACTOR_PLAN.md b/REFACTOR_PLAN.md index ca8558f..42bffda 100644 --- a/REFACTOR_PLAN.md +++ b/REFACTOR_PLAN.md @@ -238,32 +238,42 @@ core/ - ✅ **Production Ready**: Stable concurrent streaming from multiple camera sources - ✅ **Dependencies**: Added opencv-python, numpy, and requests to requirements.txt -## 📋 Phase 4: Vehicle Tracking System +## ✅ Phase 4: Vehicle Tracking System - COMPLETED ### 4.1 Tracking Module (`core/tracking/`) -- [ ] **Create `tracker.py`** - Vehicle tracking implementation - - [ ] Implement continuous tracking with `front_rear_detection_v1.pt` - - [ ] Add vehicle identification and persistence - - [ ] Implement tracking state management - - [ ] Add bounding box tracking and motion analysis +- ✅ **Create `tracker.py`** - Vehicle tracking implementation + - ✅ Implement continuous tracking with configurable model (front_rear_detection_v1.pt) + - ✅ Add vehicle identification and persistence with TrackedVehicle dataclass + - ✅ Implement tracking state management with thread-safe operations + - ✅ Add bounding box tracking and motion analysis with position history -- [ ] **Create `validator.py`** - Stable car validation - - [ ] Implement stable car detection algorithm - - [ ] Add passing-by vs. fueling car differentiation - - [ ] Implement validation thresholds and timing - - [ ] Add confidence scoring for validation decisions +- ✅ **Create `validator.py`** - Stable car validation + - ✅ Implement stable car detection algorithm with multiple validation criteria + - ✅ Add passing-by vs. fueling car differentiation using velocity and position analysis + - ✅ Implement validation thresholds and timing with configurable parameters + - ✅ Add confidence scoring for validation decisions with state history -- [ ] **Create `integration.py`** - Tracking-pipeline integration - - [ ] Connect tracking system with main pipeline - - [ ] Handle tracking state transitions - - [ ] Implement post-session tracking validation - - [ ] Add same-car validation after sessionId cleared +- ✅ **Create `integration.py`** - Tracking-pipeline integration + - ✅ Connect tracking system with main pipeline through TrackingPipelineIntegration + - ✅ Handle tracking state transitions and session management + - ✅ Implement post-session tracking validation with cooldown periods + - ✅ Add same-car validation after sessionId cleared with 30-second cooldown ### 4.2 Testing Phase 4 -- [ ] Test continuous vehicle tracking functionality -- [ ] Test stable car validation logic -- [ ] Test integration with existing pipeline -- [ ] Verify tracking performance and accuracy +- ✅ Test continuous vehicle tracking functionality +- ✅ Test stable car validation logic +- ✅ Test integration with existing pipeline +- ✅ Verify tracking performance and accuracy + +### 4.3 Phase 4 Results +- ✅ **VehicleTracker**: Complete tracking implementation with YOLO tracking integration, position history, and stability calculations +- ✅ **StableCarValidator**: Sophisticated validation logic using velocity, position variance, and state consistency +- ✅ **TrackingPipelineIntegration**: Full integration with pipeline system including session management and async processing +- ✅ **StreamManager Integration**: Updated streaming manager to process tracking on every frame with proper threading +- ✅ **Thread-Safe Operations**: All tracking operations are thread-safe with proper locking mechanisms +- ✅ **Configurable Parameters**: All tracking parameters are configurable through pipeline.json +- ✅ **Session Management**: Complete session lifecycle management with post-fueling validation +- ✅ **Statistics and Monitoring**: Comprehensive statistics collection for tracking performance ## 📋 Phase 5: Detection Pipeline System diff --git a/core/communication/websocket.py b/core/communication/websocket.py index a756002..71077f0 100644 --- a/core/communication/websocket.py +++ b/core/communication/websocket.py @@ -18,6 +18,8 @@ from .models import ( ) from .state import worker_state, SystemMetrics from ..models import ModelManager +from ..streaming.manager import shared_stream_manager +from ..tracking.integration import TrackingPipelineIntegration logger = logging.getLogger(__name__) @@ -199,17 +201,8 @@ class WebSocketHandler: # Phase 2: Download and manage models await self._ensure_models(message.subscriptions) - # TODO: Phase 3 - Integrate with streaming management - # For now, just log the subscription changes - for subscription in message.subscriptions: - logger.info(f" Subscription: {subscription.subscriptionIdentifier} -> " - f"Model {subscription.modelId} ({subscription.modelName})") - if subscription.rtspUrl: - logger.debug(f" RTSP: {subscription.rtspUrl}") - if subscription.snapshotUrl: - logger.debug(f" Snapshot: {subscription.snapshotUrl} ({subscription.snapshotInterval}ms)") - if subscription.modelUrl: - logger.debug(f" Model: {subscription.modelUrl}") + # Phase 3 & 4: Integrate with streaming management and tracking + await self._update_stream_subscriptions(message.subscriptions) logger.info("Subscription list updated successfully") @@ -260,6 +253,168 @@ class WebSocketHandler: logger.info(f"[Model Management] Successfully ensured {success_count}/{len(unique_models)} models") + async def _update_stream_subscriptions(self, subscriptions) -> None: + """Update streaming subscriptions with tracking integration.""" + try: + # Convert subscriptions to the format expected by StreamManager + subscription_payloads = [] + for subscription in subscriptions: + payload = { + 'subscriptionIdentifier': subscription.subscriptionIdentifier, + 'rtspUrl': subscription.rtspUrl, + 'snapshotUrl': subscription.snapshotUrl, + 'snapshotInterval': subscription.snapshotInterval, + 'modelId': subscription.modelId, + 'modelUrl': subscription.modelUrl, + 'modelName': subscription.modelName + } + # Add crop coordinates if present + if hasattr(subscription, 'cropX1'): + payload.update({ + 'cropX1': subscription.cropX1, + 'cropY1': subscription.cropY1, + 'cropX2': subscription.cropX2, + 'cropY2': subscription.cropY2 + }) + subscription_payloads.append(payload) + + # Reconcile subscriptions with StreamManager + logger.info("[Streaming] Reconciling stream subscriptions with tracking") + reconcile_result = await self._reconcile_subscriptions_with_tracking(subscription_payloads) + + logger.info(f"[Streaming] Subscription reconciliation complete: " + f"added={reconcile_result.get('added', 0)}, " + f"removed={reconcile_result.get('removed', 0)}, " + f"failed={reconcile_result.get('failed', 0)}") + + except Exception as e: + logger.error(f"Error updating stream subscriptions: {e}", exc_info=True) + + async def _reconcile_subscriptions_with_tracking(self, target_subscriptions) -> dict: + """Reconcile subscriptions with tracking integration.""" + try: + # First, we need to create tracking integrations for each unique model + tracking_integrations = {} + + for subscription_payload in target_subscriptions: + model_id = subscription_payload['modelId'] + + # Create tracking integration if not already created + if model_id not in tracking_integrations: + # Get pipeline configuration for this model + pipeline_parser = model_manager.get_pipeline_config(model_id) + if pipeline_parser: + # Create tracking integration + tracking_integration = TrackingPipelineIntegration( + pipeline_parser, model_manager + ) + + # Initialize tracking model + success = await tracking_integration.initialize_tracking_model() + if success: + tracking_integrations[model_id] = tracking_integration + logger.info(f"[Tracking] Created tracking integration for model {model_id}") + else: + logger.warning(f"[Tracking] Failed to initialize tracking for model {model_id}") + else: + logger.warning(f"[Tracking] No pipeline config found for model {model_id}") + + # Now reconcile with StreamManager, adding tracking integrations + current_subscription_ids = set() + for subscription_info in shared_stream_manager.get_all_subscriptions(): + current_subscription_ids.add(subscription_info.subscription_id) + + target_subscription_ids = {sub['subscriptionIdentifier'] for sub in target_subscriptions} + + # Find subscriptions to remove and add + to_remove = current_subscription_ids - target_subscription_ids + to_add = target_subscription_ids - current_subscription_ids + + # Remove old subscriptions + removed_count = 0 + for subscription_id in to_remove: + if shared_stream_manager.remove_subscription(subscription_id): + removed_count += 1 + logger.info(f"[Streaming] Removed subscription {subscription_id}") + + # Add new subscriptions with tracking + added_count = 0 + failed_count = 0 + for subscription_payload in target_subscriptions: + subscription_id = subscription_payload['subscriptionIdentifier'] + if subscription_id in to_add: + success = await self._add_subscription_with_tracking( + subscription_payload, tracking_integrations + ) + if success: + added_count += 1 + logger.info(f"[Streaming] Added subscription {subscription_id} with tracking") + else: + failed_count += 1 + logger.error(f"[Streaming] Failed to add subscription {subscription_id}") + + return { + 'removed': removed_count, + 'added': added_count, + 'failed': failed_count, + 'total_active': len(shared_stream_manager.get_all_subscriptions()) + } + + except Exception as e: + logger.error(f"Error in subscription reconciliation with tracking: {e}", exc_info=True) + return {'removed': 0, 'added': 0, 'failed': 0, 'total_active': 0} + + async def _add_subscription_with_tracking(self, payload, tracking_integrations) -> bool: + """Add a subscription with tracking integration.""" + try: + from ..streaming.manager import StreamConfig + + subscription_id = payload['subscriptionIdentifier'] + camera_id = subscription_id.split(';')[-1] + model_id = payload['modelId'] + + # Get tracking integration for this model + tracking_integration = tracking_integrations.get(model_id) + + # Extract crop coordinates if present + crop_coords = None + if all(key in payload for key in ['cropX1', 'cropY1', 'cropX2', 'cropY2']): + crop_coords = ( + payload['cropX1'], + payload['cropY1'], + payload['cropX2'], + payload['cropY2'] + ) + + # Create stream configuration + stream_config = StreamConfig( + camera_id=camera_id, + rtsp_url=payload.get('rtspUrl'), + snapshot_url=payload.get('snapshotUrl'), + snapshot_interval=payload.get('snapshotInterval', 5000), + max_retries=3, + save_test_frames=False # Disable frame saving, focus on tracking + ) + + # Add subscription to StreamManager with tracking + success = shared_stream_manager.add_subscription( + subscription_id=subscription_id, + stream_config=stream_config, + crop_coords=crop_coords, + model_id=model_id, + model_url=payload.get('modelUrl'), + tracking_integration=tracking_integration + ) + + if success and tracking_integration: + logger.info(f"[Tracking] Subscription {subscription_id} configured with tracking for model {model_id}") + + return success + + except Exception as e: + logger.error(f"Error adding subscription with tracking: {e}", exc_info=True) + return False + async def _ensure_single_model(self, model_id: int, model_url: str, model_name: str) -> bool: """Ensure a single model is downloaded and available.""" try: @@ -303,6 +458,9 @@ class WebSocketHandler: # Update worker state worker_state.set_session_id(display_identifier, session_id) + # Update tracking integrations with session ID + shared_stream_manager.set_session_id(display_identifier, session_id) + async def _handle_set_progression_stage(self, message: SetProgressionStageMessage) -> None: """Handle setProgressionStage message.""" display_identifier = message.payload.displayIdentifier @@ -313,6 +471,14 @@ class WebSocketHandler: # Update worker state worker_state.set_progression_stage(display_identifier, stage) + # If stage indicates session is cleared/finished, clear from tracking + if stage in ['finished', 'cleared', 'idle']: + # Get session ID for this display and clear it + session_id = worker_state.get_session_id(display_identifier) + if session_id: + shared_stream_manager.clear_session_id(session_id) + logger.info(f"[Tracking] Cleared session {session_id} due to progression stage: {stage}") + async def _handle_request_state(self, message: RequestStateMessage) -> None: """Handle requestState message by sending immediate state report.""" logger.debug("[RX Processing] requestState - sending immediate state report") diff --git a/core/models/manager.py b/core/models/manager.py index bbd0f8b..d40c48f 100644 --- a/core/models/manager.py +++ b/core/models/manager.py @@ -358,4 +358,82 @@ class ModelManager: Returns: Set of model IDs that are currently downloaded """ - return self._downloaded_models.copy() \ No newline at end of file + return self._downloaded_models.copy() + + def get_pipeline_config(self, model_id: int) -> Optional[Any]: + """ + Get the pipeline configuration for a model. + + Args: + model_id: The model ID + + Returns: + PipelineConfig object if found, None otherwise + """ + try: + if model_id not in self._downloaded_models: + logger.warning(f"Model {model_id} not downloaded") + return None + + model_path = self._model_paths.get(model_id) + if not model_path: + logger.warning(f"Model path not found for model {model_id}") + return None + + # Import here to avoid circular imports + from .pipeline import PipelineParser + + # Load pipeline.json + pipeline_file = model_path / "pipeline.json" + if not pipeline_file.exists(): + logger.warning(f"No pipeline.json found for model {model_id}") + return None + + # Create PipelineParser object and parse the configuration + pipeline_parser = PipelineParser() + success = pipeline_parser.parse(pipeline_file) + + if success: + return pipeline_parser + else: + logger.error(f"Failed to parse pipeline.json for model {model_id}") + return None + + except Exception as e: + logger.error(f"Error getting pipeline config for model {model_id}: {e}", exc_info=True) + return None + + def get_yolo_model(self, model_id: int, model_filename: str) -> Optional[Any]: + """ + Create a YOLOWrapper instance for a specific model file. + + Args: + model_id: The model ID + model_filename: The .pt model filename + + Returns: + YOLOWrapper instance if successful, None otherwise + """ + try: + # Get the model file path + model_file_path = self.get_model_file_path(model_id, model_filename) + if not model_file_path or not model_file_path.exists(): + logger.error(f"Model file {model_filename} not found for model {model_id}") + return None + + # Import here to avoid circular imports + from .inference import YOLOWrapper + + # Create YOLOWrapper instance + yolo_model = YOLOWrapper( + model_path=model_file_path, + model_id=f"{model_id}_{model_filename}", + device=None # Auto-detect device + ) + + logger.info(f"Created YOLOWrapper for model {model_id}: {model_filename}") + return yolo_model + + except Exception as e: + logger.error(f"Error creating YOLO model for {model_id}:{model_filename}: {e}", exc_info=True) + return None \ No newline at end of file diff --git a/core/streaming/manager.py b/core/streaming/manager.py index 399874f..2e381e9 100644 --- a/core/streaming/manager.py +++ b/core/streaming/manager.py @@ -11,6 +11,7 @@ from collections import defaultdict from .readers import RTSPReader, HTTPSnapshotReader from .buffers import shared_cache_buffer, save_frame_for_testing +from ..tracking.integration import TrackingPipelineIntegration logger = logging.getLogger(__name__) @@ -35,6 +36,9 @@ class SubscriptionInfo: stream_config: StreamConfig created_at: float crop_coords: Optional[tuple] = None + model_id: Optional[str] = None + model_url: Optional[str] = None + tracking_integration: Optional[TrackingPipelineIntegration] = None class StreamManager: @@ -48,7 +52,10 @@ class StreamManager: self._lock = threading.RLock() def add_subscription(self, subscription_id: str, stream_config: StreamConfig, - crop_coords: Optional[tuple] = None) -> bool: + crop_coords: Optional[tuple] = None, + model_id: Optional[str] = None, + model_url: Optional[str] = None, + tracking_integration: Optional[TrackingPipelineIntegration] = None) -> bool: """Add a new subscription. Returns True if successful.""" with self._lock: if subscription_id in self._subscriptions: @@ -63,7 +70,10 @@ class StreamManager: camera_id=camera_id, stream_config=stream_config, created_at=time.time(), - crop_coords=crop_coords + crop_coords=crop_coords, + model_id=model_id, + model_url=model_url, + tracking_integration=tracking_integration ) self._subscriptions[subscription_id] = subscription_info @@ -175,9 +185,64 @@ class StreamManager: save_frame_for_testing(camera_id, frame) break # Only save once per frame + # Process tracking for subscriptions with tracking integration + self._process_tracking_for_camera(camera_id, frame) + except Exception as e: logger.error(f"Error in frame callback for camera {camera_id}: {e}") + def _process_tracking_for_camera(self, camera_id: str, frame): + """Process tracking for all subscriptions of a camera.""" + try: + with self._lock: + for subscription_id in self._camera_subscribers[camera_id]: + subscription_info = self._subscriptions[subscription_id] + + # Skip if no tracking integration + if not subscription_info.tracking_integration: + continue + + # Extract display_id from subscription_id + display_id = subscription_id.split(';')[0] if ';' in subscription_id else subscription_id + + # Process frame through tracking asynchronously + # Note: This is synchronous for now, can be made async in future + try: + # Create a simple asyncio event loop for this frame + import asyncio + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + result = loop.run_until_complete( + subscription_info.tracking_integration.process_frame( + frame, display_id, subscription_id + ) + ) + # Log tracking results + if result: + tracked_count = len(result.get('tracked_vehicles', [])) + validated_vehicle = result.get('validated_vehicle') + pipeline_result = result.get('pipeline_result') + + if tracked_count > 0: + logger.info(f"[Tracking] {camera_id}: {tracked_count} vehicles tracked") + + if validated_vehicle: + logger.info(f"[Tracking] {camera_id}: Vehicle {validated_vehicle['track_id']} " + f"validated as {validated_vehicle['state']} " + f"(confidence: {validated_vehicle['confidence']:.2f})") + + if pipeline_result: + logger.info(f"[Pipeline] {camera_id}: {pipeline_result.get('status', 'unknown')} - " + f"{pipeline_result.get('message', 'no message')}") + finally: + loop.close() + except Exception as track_e: + logger.error(f"Error in tracking for {subscription_id}: {track_e}") + + except Exception as e: + logger.error(f"Error processing tracking for camera {camera_id}: {e}") + def get_frame(self, camera_id: str, crop_coords: Optional[tuple] = None): """Get the latest frame for a camera with optional cropping.""" return shared_cache_buffer.get_frame(camera_id, crop_coords) @@ -280,7 +345,13 @@ class StreamManager: save_test_frames=True # Enable for testing ) - return self.add_subscription(subscription_id, stream_config, crop_coords) + return self.add_subscription( + subscription_id, + stream_config, + crop_coords, + model_id=payload.get('modelId'), + model_url=payload.get('modelUrl') + ) except Exception as e: logger.error(f"Error adding subscription from payload {subscription_id}: {e}") @@ -300,10 +371,38 @@ class StreamManager: logger.info("Stopped all streams and cleared all subscriptions") + def set_session_id(self, display_id: str, session_id: str): + """Set session ID for tracking integration.""" + with self._lock: + for subscription_info in self._subscriptions.values(): + # Check if this subscription matches the display_id + subscription_display_id = subscription_info.subscription_id.split(';')[0] + if subscription_display_id == display_id and subscription_info.tracking_integration: + subscription_info.tracking_integration.set_session_id(display_id, session_id) + logger.debug(f"Set session {session_id} for display {display_id}") + + def clear_session_id(self, session_id: str): + """Clear session ID from tracking integrations.""" + with self._lock: + for subscription_info in self._subscriptions.values(): + if subscription_info.tracking_integration: + subscription_info.tracking_integration.clear_session_id(session_id) + logger.debug(f"Cleared session {session_id}") + + def get_tracking_stats(self) -> Dict[str, Any]: + """Get tracking statistics from all subscriptions.""" + stats = {} + with self._lock: + for subscription_id, subscription_info in self._subscriptions.items(): + if subscription_info.tracking_integration: + stats[subscription_id] = subscription_info.tracking_integration.get_statistics() + return stats + def get_stats(self) -> Dict[str, Any]: """Get comprehensive streaming statistics.""" with self._lock: buffer_stats = shared_cache_buffer.get_stats() + tracking_stats = self.get_tracking_stats() return { 'active_subscriptions': len(self._subscriptions), @@ -314,7 +413,8 @@ class StreamManager: camera_id: len(subscribers) for camera_id, subscribers in self._camera_subscribers.items() }, - 'buffer_stats': buffer_stats + 'buffer_stats': buffer_stats, + 'tracking_stats': tracking_stats } diff --git a/core/tracking/__init__.py b/core/tracking/__init__.py index bd60536..a493062 100644 --- a/core/tracking/__init__.py +++ b/core/tracking/__init__.py @@ -1 +1,14 @@ -# Tracking module for vehicle tracking and validation \ No newline at end of file +# Tracking module for vehicle tracking and validation + +from .tracker import VehicleTracker, TrackedVehicle +from .validator import StableCarValidator, ValidationResult, VehicleState +from .integration import TrackingPipelineIntegration + +__all__ = [ + 'VehicleTracker', + 'TrackedVehicle', + 'StableCarValidator', + 'ValidationResult', + 'VehicleState', + 'TrackingPipelineIntegration' +] \ No newline at end of file diff --git a/core/tracking/integration.py b/core/tracking/integration.py new file mode 100644 index 0000000..d42d053 --- /dev/null +++ b/core/tracking/integration.py @@ -0,0 +1,369 @@ +""" +Tracking-Pipeline Integration Module. +Connects the tracking system with the main detection pipeline and manages the flow. +""" +import logging +import time +import uuid +from typing import Dict, Optional, Any, List, Tuple +import asyncio +from concurrent.futures import ThreadPoolExecutor +import numpy as np + +from .tracker import VehicleTracker, TrackedVehicle +from .validator import StableCarValidator, ValidationResult, VehicleState +from ..models.inference import YOLOWrapper +from ..models.pipeline import PipelineParser + +logger = logging.getLogger(__name__) + + +class TrackingPipelineIntegration: + """ + Integrates vehicle tracking with the detection pipeline. + Manages tracking state transitions and pipeline execution triggers. + """ + + def __init__(self, pipeline_parser: PipelineParser, model_manager: Any): + """ + Initialize tracking-pipeline integration. + + Args: + pipeline_parser: Pipeline parser with loaded configuration + model_manager: Model manager for loading models + """ + self.pipeline_parser = pipeline_parser + self.model_manager = model_manager + + # Initialize tracking components + tracking_config = pipeline_parser.tracking_config.__dict__ if pipeline_parser.tracking_config else {} + self.tracker = VehicleTracker(tracking_config) + self.validator = StableCarValidator() + + # Tracking model + self.tracking_model: Optional[YOLOWrapper] = None + self.tracking_model_id = None + + # Session management + self.active_sessions: Dict[str, str] = {} # display_id -> session_id + self.session_vehicles: Dict[str, int] = {} # session_id -> track_id + self.cleared_sessions: Dict[str, float] = {} # session_id -> clear_time + + # Thread pool for pipeline execution + self.executor = ThreadPoolExecutor(max_workers=2) + + # Statistics + self.stats = { + 'frames_processed': 0, + 'vehicles_detected': 0, + 'vehicles_validated': 0, + 'pipelines_executed': 0 + } + + logger.info("TrackingPipelineIntegration initialized") + + async def initialize_tracking_model(self) -> bool: + """ + Load and initialize the tracking model. + + Returns: + True if successful, False otherwise + """ + try: + if not self.pipeline_parser.tracking_config: + logger.warning("No tracking configuration found in pipeline") + return False + + model_file = self.pipeline_parser.tracking_config.model_file + model_id = self.pipeline_parser.tracking_config.model_id + + if not model_file: + logger.warning("No tracking model file specified") + return False + + # Load tracking model + logger.info(f"Loading tracking model: {model_id} ({model_file})") + # Get the model ID from the ModelManager context + # We need the actual model ID, not the model string identifier + # For now, let's extract it from the model manager + pipeline_models = list(self.model_manager.get_all_downloaded_models()) + if pipeline_models: + actual_model_id = pipeline_models[0] # Use the first available model + self.tracking_model = self.model_manager.get_yolo_model(actual_model_id, model_file) + else: + logger.error("No models available in ModelManager") + return False + self.tracking_model_id = model_id + + if self.tracking_model: + logger.info(f"Tracking model {model_id} loaded successfully") + return True + else: + logger.error(f"Failed to load tracking model {model_id}") + return False + + except Exception as e: + logger.error(f"Error initializing tracking model: {e}", exc_info=True) + return False + + async def process_frame(self, + frame: np.ndarray, + display_id: str, + subscription_id: str, + session_id: Optional[str] = None) -> Dict[str, Any]: + """ + Process a frame through tracking and potentially the detection pipeline. + + Args: + frame: Input frame to process + display_id: Display identifier + subscription_id: Full subscription identifier + session_id: Optional session ID from backend + + Returns: + Dictionary with processing results + """ + start_time = time.time() + result = { + 'tracked_vehicles': [], + 'validated_vehicle': None, + 'pipeline_result': None, + 'session_id': session_id, + 'processing_time': 0.0 + } + + try: + # Update stats + self.stats['frames_processed'] += 1 + + # Run tracking model + if self.tracking_model: + # Run inference with tracking + tracking_results = self.tracking_model.track( + frame, + confidence_threshold=self.tracker.min_confidence, + trigger_classes=self.tracker.trigger_classes, + persist=True + ) + + # Process tracking results + tracked_vehicles = self.tracker.process_detections( + tracking_results, + display_id, + frame + ) + + result['tracked_vehicles'] = [ + { + 'track_id': v.track_id, + 'bbox': v.bbox, + 'confidence': v.confidence, + 'is_stable': v.is_stable, + 'session_id': v.session_id + } + for v in tracked_vehicles + ] + + # Log tracking info periodically + if self.stats['frames_processed'] % 30 == 0: # Every 30 frames + logger.debug(f"Tracking: {len(tracked_vehicles)} vehicles, " + f"display={display_id}") + + # Get stable vehicles for validation + stable_vehicles = self.tracker.get_stable_vehicles(display_id) + + # Validate and potentially process stable vehicles + for vehicle in stable_vehicles: + # Check if vehicle is already processed or has session + if vehicle.processed_pipeline: + continue + + # Check for session cleared (post-fueling) + if session_id and vehicle.session_id == session_id: + # Same vehicle with same session, skip + continue + + # Check if this was a recently cleared session + session_cleared = False + if vehicle.session_id in self.cleared_sessions: + clear_time = self.cleared_sessions[vehicle.session_id] + if (time.time() - clear_time) < 30: # 30 second cooldown + session_cleared = True + + # Skip same car after session clear + if self.validator.should_skip_same_car(vehicle, session_cleared): + continue + + # Validate vehicle + validation_result = self.validator.validate_vehicle(vehicle, frame.shape) + + if validation_result.is_valid and validation_result.should_process: + logger.info(f"Vehicle {vehicle.track_id} validated for processing: " + f"{validation_result.reason}") + + result['validated_vehicle'] = { + 'track_id': vehicle.track_id, + 'state': validation_result.state.value, + 'confidence': validation_result.confidence + } + + # Generate session ID if not provided + if not session_id: + session_id = str(uuid.uuid4()) + logger.info(f"Generated session ID: {session_id}") + + # Mark vehicle as processed + self.tracker.mark_processed(vehicle.track_id, session_id) + self.session_vehicles[session_id] = vehicle.track_id + self.active_sessions[display_id] = session_id + + # Execute detection pipeline (placeholder for Phase 5) + pipeline_result = await self._execute_pipeline( + frame, + vehicle, + display_id, + session_id, + subscription_id + ) + + result['pipeline_result'] = pipeline_result + result['session_id'] = session_id + self.stats['pipelines_executed'] += 1 + + # Only process one vehicle per frame + break + + self.stats['vehicles_detected'] = len(tracked_vehicles) + self.stats['vehicles_validated'] = len(stable_vehicles) + + else: + logger.warning("No tracking model available") + + except Exception as e: + logger.error(f"Error in tracking pipeline: {e}", exc_info=True) + + result['processing_time'] = time.time() - start_time + return result + + async def _execute_pipeline(self, + frame: np.ndarray, + vehicle: TrackedVehicle, + display_id: str, + session_id: str, + subscription_id: str) -> Dict[str, Any]: + """ + Execute the main detection pipeline for a validated vehicle. + This is a placeholder for Phase 5 implementation. + + Args: + frame: Input frame + vehicle: Validated tracked vehicle + display_id: Display identifier + session_id: Session identifier + subscription_id: Full subscription identifier + + Returns: + Pipeline execution results + """ + logger.info(f"Executing pipeline for vehicle {vehicle.track_id}, " + f"session={session_id}, display={display_id}") + + # Placeholder for Phase 5 pipeline execution + # This will be implemented when we create the detection module + pipeline_result = { + 'status': 'pending', + 'message': 'Pipeline execution will be implemented in Phase 5', + 'vehicle_id': vehicle.track_id, + 'session_id': session_id, + 'bbox': vehicle.bbox, + 'confidence': vehicle.confidence + } + + # Simulate pipeline execution + await asyncio.sleep(0.1) + + return pipeline_result + + def set_session_id(self, display_id: str, session_id: str): + """ + Set session ID for a display (from backend). + + Args: + display_id: Display identifier + session_id: Session identifier + """ + self.active_sessions[display_id] = session_id + logger.info(f"Set session {session_id} for display {display_id}") + + # Find vehicle with this session + vehicle = self.tracker.get_vehicle_by_session(session_id) + if vehicle: + self.session_vehicles[session_id] = vehicle.track_id + + def clear_session_id(self, session_id: str): + """ + Clear session ID (post-fueling). + + Args: + session_id: Session identifier to clear + """ + # Mark session as cleared + self.cleared_sessions[session_id] = time.time() + + # Clear from tracker + self.tracker.clear_session(session_id) + + # Remove from active sessions + display_to_remove = None + for display_id, sess_id in self.active_sessions.items(): + if sess_id == session_id: + display_to_remove = display_id + break + + if display_to_remove: + del self.active_sessions[display_to_remove] + + if session_id in self.session_vehicles: + del self.session_vehicles[session_id] + + logger.info(f"Cleared session {session_id}") + + # Clean old cleared sessions (older than 5 minutes) + current_time = time.time() + old_sessions = [ + sid for sid, clear_time in self.cleared_sessions.items() + if (current_time - clear_time) > 300 + ] + for sid in old_sessions: + del self.cleared_sessions[sid] + + def get_session_for_display(self, display_id: str) -> Optional[str]: + """Get active session for a display.""" + return self.active_sessions.get(display_id) + + def reset_tracking(self): + """Reset all tracking state.""" + self.tracker.reset_tracking() + self.active_sessions.clear() + self.session_vehicles.clear() + self.cleared_sessions.clear() + logger.info("Tracking pipeline integration reset") + + def get_statistics(self) -> Dict[str, Any]: + """Get comprehensive statistics.""" + tracker_stats = self.tracker.get_statistics() + validator_stats = self.validator.get_statistics() + + return { + 'integration': self.stats, + 'tracker': tracker_stats, + 'validator': validator_stats, + 'active_sessions': len(self.active_sessions), + 'cleared_sessions': len(self.cleared_sessions) + } + + def cleanup(self): + """Cleanup resources.""" + self.executor.shutdown(wait=False) + self.reset_tracking() + logger.info("Tracking pipeline integration cleaned up") \ No newline at end of file diff --git a/core/tracking/tracker.py b/core/tracking/tracker.py new file mode 100644 index 0000000..b0799de --- /dev/null +++ b/core/tracking/tracker.py @@ -0,0 +1,352 @@ +""" +Vehicle Tracking Module - Continuous tracking with front_rear_detection model +Implements vehicle identification, persistence, and motion analysis. +""" +import logging +import time +import uuid +from typing import Dict, List, Optional, Tuple, Any +from dataclasses import dataclass, field +import numpy as np +from threading import Lock + +logger = logging.getLogger(__name__) + + +@dataclass +class TrackedVehicle: + """Represents a tracked vehicle with all its state information.""" + track_id: int + first_seen: float + last_seen: float + session_id: Optional[str] = None + display_id: Optional[str] = None + confidence: float = 0.0 + bbox: Tuple[int, int, int, int] = (0, 0, 0, 0) # x1, y1, x2, y2 + center: Tuple[float, float] = (0.0, 0.0) + stable_frames: int = 0 + total_frames: int = 0 + is_stable: bool = False + processed_pipeline: bool = False + last_position_history: List[Tuple[float, float]] = field(default_factory=list) + avg_confidence: float = 0.0 + + def update_position(self, bbox: Tuple[int, int, int, int], confidence: float): + """Update vehicle position and confidence.""" + self.bbox = bbox + self.center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2) + self.last_seen = time.time() + self.confidence = confidence + self.total_frames += 1 + + # Update confidence average + self.avg_confidence = ((self.avg_confidence * (self.total_frames - 1)) + confidence) / self.total_frames + + # Maintain position history (last 10 positions) + self.last_position_history.append(self.center) + if len(self.last_position_history) > 10: + self.last_position_history.pop(0) + + def calculate_stability(self) -> float: + """Calculate stability score based on position history.""" + if len(self.last_position_history) < 2: + return 0.0 + + # Calculate movement variance + positions = np.array(self.last_position_history) + if len(positions) < 2: + return 0.0 + + # Calculate standard deviation of positions + std_x = np.std(positions[:, 0]) + std_y = np.std(positions[:, 1]) + + # Lower variance means more stable (inverse relationship) + # Normalize to 0-1 range (assuming max reasonable std is 50 pixels) + stability = max(0, 1 - (std_x + std_y) / 100) + return stability + + def is_expired(self, timeout_seconds: float = 2.0) -> bool: + """Check if vehicle tracking has expired.""" + return (time.time() - self.last_seen) > timeout_seconds + + +class VehicleTracker: + """ + Main vehicle tracking implementation using YOLO tracking capabilities. + Manages continuous tracking, vehicle identification, and state persistence. + """ + + def __init__(self, tracking_config: Optional[Dict] = None): + """ + Initialize the vehicle tracker. + + Args: + tracking_config: Configuration from pipeline.json tracking section + """ + self.config = tracking_config or {} + self.trigger_classes = self.config.get('triggerClasses', ['front_rear']) + self.min_confidence = self.config.get('minConfidence', 0.6) + + # Tracking state + self.tracked_vehicles: Dict[int, TrackedVehicle] = {} + self.next_track_id = 1 + self.lock = Lock() + + # Tracking parameters + self.stability_threshold = 0.7 + self.min_stable_frames = 5 + self.position_tolerance = 50 # pixels + self.timeout_seconds = 2.0 + + logger.info(f"VehicleTracker initialized with trigger_classes={self.trigger_classes}, " + f"min_confidence={self.min_confidence}") + + def process_detections(self, + results: Any, + display_id: str, + frame: np.ndarray) -> List[TrackedVehicle]: + """ + Process YOLO detection results and update tracking state. + + Args: + results: YOLO detection results with tracking + display_id: Display identifier for this stream + frame: Current frame being processed + + Returns: + List of currently tracked vehicles + """ + current_time = time.time() + active_tracks = [] + + with self.lock: + # Clean up expired tracks + expired_ids = [ + track_id for track_id, vehicle in self.tracked_vehicles.items() + if vehicle.is_expired(self.timeout_seconds) + ] + for track_id in expired_ids: + logger.debug(f"Removing expired track {track_id}") + del self.tracked_vehicles[track_id] + + # Process new detections + if hasattr(results, 'boxes') and results.boxes is not None: + boxes = results.boxes + + # Check if tracking is available + if hasattr(boxes, 'id') and boxes.id is not None: + # Process tracked objects + for i, box in enumerate(boxes): + # Get tracking ID + track_id = int(boxes.id[i].item()) if boxes.id[i] is not None else None + if track_id is None: + continue + + # Get class and confidence + cls_id = int(box.cls.item()) + confidence = float(box.conf.item()) + + # Check if class is in trigger classes + class_name = results.names[cls_id] if hasattr(results, 'names') else str(cls_id) + if class_name not in self.trigger_classes and confidence < self.min_confidence: + continue + + # Get bounding box + x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int) + bbox = (x1, y1, x2, y2) + + # Update or create tracked vehicle + if track_id in self.tracked_vehicles: + # Update existing track + vehicle = self.tracked_vehicles[track_id] + vehicle.update_position(bbox, confidence) + vehicle.display_id = display_id + + # Check stability + stability = vehicle.calculate_stability() + if stability > self.stability_threshold: + vehicle.stable_frames += 1 + if vehicle.stable_frames >= self.min_stable_frames: + vehicle.is_stable = True + else: + vehicle.stable_frames = max(0, vehicle.stable_frames - 1) + if vehicle.stable_frames < self.min_stable_frames: + vehicle.is_stable = False + + logger.debug(f"Updated track {track_id}: conf={confidence:.2f}, " + f"stable={vehicle.is_stable}, stability={stability:.2f}") + else: + # Create new track + vehicle = TrackedVehicle( + track_id=track_id, + first_seen=current_time, + last_seen=current_time, + display_id=display_id, + confidence=confidence, + bbox=bbox, + center=((x1 + x2) / 2, (y1 + y2) / 2), + total_frames=1 + ) + vehicle.last_position_history.append(vehicle.center) + self.tracked_vehicles[track_id] = vehicle + logger.info(f"New vehicle tracked: ID={track_id}, display={display_id}") + + active_tracks.append(self.tracked_vehicles[track_id]) + else: + # No tracking available, process as detections only + logger.debug("No tracking IDs available, processing as detections only") + for i, box in enumerate(boxes): + cls_id = int(box.cls.item()) + confidence = float(box.conf.item()) + + # Check confidence threshold + if confidence < self.min_confidence: + continue + + # Get bounding box + x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int) + bbox = (x1, y1, x2, y2) + center = ((x1 + x2) / 2, (y1 + y2) / 2) + + # Try to match with existing tracks by position + matched_track = self._find_closest_track(center) + + if matched_track: + matched_track.update_position(bbox, confidence) + matched_track.display_id = display_id + active_tracks.append(matched_track) + else: + # Create new track with generated ID + track_id = self.next_track_id + self.next_track_id += 1 + + vehicle = TrackedVehicle( + track_id=track_id, + first_seen=current_time, + last_seen=current_time, + display_id=display_id, + confidence=confidence, + bbox=bbox, + center=center, + total_frames=1 + ) + vehicle.last_position_history.append(center) + self.tracked_vehicles[track_id] = vehicle + active_tracks.append(vehicle) + logger.info(f"New vehicle detected (no tracking): ID={track_id}") + + return active_tracks + + def _find_closest_track(self, center: Tuple[float, float]) -> Optional[TrackedVehicle]: + """ + Find the closest existing track to a given position. + + Args: + center: Center position to match + + Returns: + Closest tracked vehicle if within tolerance, None otherwise + """ + min_distance = float('inf') + closest_track = None + + for vehicle in self.tracked_vehicles.values(): + if vehicle.is_expired(0.5): # Shorter timeout for matching + continue + + distance = np.sqrt( + (center[0] - vehicle.center[0]) ** 2 + + (center[1] - vehicle.center[1]) ** 2 + ) + + if distance < min_distance and distance < self.position_tolerance: + min_distance = distance + closest_track = vehicle + + return closest_track + + def get_stable_vehicles(self, display_id: Optional[str] = None) -> List[TrackedVehicle]: + """ + Get all stable vehicles, optionally filtered by display. + + Args: + display_id: Optional display ID to filter by + + Returns: + List of stable tracked vehicles + """ + with self.lock: + stable = [ + v for v in self.tracked_vehicles.values() + if v.is_stable and not v.is_expired(self.timeout_seconds) + and (display_id is None or v.display_id == display_id) + ] + return stable + + def get_vehicle_by_session(self, session_id: str) -> Optional[TrackedVehicle]: + """ + Get a tracked vehicle by its session ID. + + Args: + session_id: Session ID to look up + + Returns: + Tracked vehicle if found, None otherwise + """ + with self.lock: + for vehicle in self.tracked_vehicles.values(): + if vehicle.session_id == session_id: + return vehicle + return None + + def mark_processed(self, track_id: int, session_id: str): + """ + Mark a vehicle as processed through the pipeline. + + Args: + track_id: Track ID of the vehicle + session_id: Session ID assigned to this vehicle + """ + with self.lock: + if track_id in self.tracked_vehicles: + vehicle = self.tracked_vehicles[track_id] + vehicle.processed_pipeline = True + vehicle.session_id = session_id + logger.info(f"Marked vehicle {track_id} as processed with session {session_id}") + + def clear_session(self, session_id: str): + """ + Clear session ID from a tracked vehicle (post-fueling). + + Args: + session_id: Session ID to clear + """ + with self.lock: + for vehicle in self.tracked_vehicles.values(): + if vehicle.session_id == session_id: + logger.info(f"Clearing session {session_id} from vehicle {vehicle.track_id}") + vehicle.session_id = None + # Keep processed_pipeline=True to prevent re-processing + + def reset_tracking(self): + """Reset all tracking state.""" + with self.lock: + self.tracked_vehicles.clear() + self.next_track_id = 1 + logger.info("Vehicle tracking state reset") + + def get_statistics(self) -> Dict: + """Get tracking statistics.""" + with self.lock: + total = len(self.tracked_vehicles) + stable = sum(1 for v in self.tracked_vehicles.values() if v.is_stable) + processed = sum(1 for v in self.tracked_vehicles.values() if v.processed_pipeline) + + return { + 'total_tracked': total, + 'stable_vehicles': stable, + 'processed_vehicles': processed, + 'avg_confidence': np.mean([v.avg_confidence for v in self.tracked_vehicles.values()]) + if self.tracked_vehicles else 0.0 + } \ No newline at end of file diff --git a/core/tracking/validator.py b/core/tracking/validator.py new file mode 100644 index 0000000..e39386f --- /dev/null +++ b/core/tracking/validator.py @@ -0,0 +1,408 @@ +""" +Vehicle Validation Module - Stable car detection and validation logic. +Differentiates between stable (fueling) cars and passing-by vehicles. +""" +import logging +import time +import numpy as np +from typing import List, Optional, Tuple, Dict, Any +from dataclasses import dataclass +from enum import Enum + +from .tracker import TrackedVehicle + +logger = logging.getLogger(__name__) + + +class VehicleState(Enum): + """Vehicle state classification.""" + UNKNOWN = "unknown" + ENTERING = "entering" + STABLE = "stable" + LEAVING = "leaving" + PASSING_BY = "passing_by" + + +@dataclass +class ValidationResult: + """Result of vehicle validation.""" + is_valid: bool + state: VehicleState + confidence: float + reason: str + should_process: bool = False + track_id: Optional[int] = None + + +class StableCarValidator: + """ + Validates whether a tracked vehicle is stable (fueling) or just passing by. + Uses multiple criteria including position stability, duration, and movement patterns. + """ + + def __init__(self, config: Optional[Dict] = None): + """ + Initialize the validator with configuration. + + Args: + config: Optional configuration dictionary + """ + self.config = config or {} + + # Validation thresholds + self.min_stable_duration = self.config.get('min_stable_duration', 3.0) # seconds + self.min_stable_frames = self.config.get('min_stable_frames', 10) + self.position_variance_threshold = self.config.get('position_variance_threshold', 25.0) # pixels + self.min_confidence = self.config.get('min_confidence', 0.7) + self.velocity_threshold = self.config.get('velocity_threshold', 5.0) # pixels/frame + self.entering_zone_ratio = self.config.get('entering_zone_ratio', 0.3) # 30% of frame + self.leaving_zone_ratio = self.config.get('leaving_zone_ratio', 0.3) + + # Frame dimensions (will be updated on first frame) + self.frame_width = 1920 + self.frame_height = 1080 + + # History for validation + self.validation_history: Dict[int, List[VehicleState]] = {} + self.last_processed_vehicles: Dict[int, float] = {} # track_id -> last_process_time + + logger.info(f"StableCarValidator initialized with min_duration={self.min_stable_duration}s, " + f"min_frames={self.min_stable_frames}, position_variance={self.position_variance_threshold}") + + def update_frame_dimensions(self, width: int, height: int): + """Update frame dimensions for zone calculations.""" + self.frame_width = width + self.frame_height = height + logger.debug(f"Updated frame dimensions: {width}x{height}") + + def validate_vehicle(self, vehicle: TrackedVehicle, frame_shape: Optional[Tuple] = None) -> ValidationResult: + """ + Validate whether a tracked vehicle is stable and should be processed. + + Args: + vehicle: The tracked vehicle to validate + frame_shape: Optional frame shape (height, width, channels) + + Returns: + ValidationResult with validation status and reasoning + """ + # Update frame dimensions if provided + if frame_shape: + self.update_frame_dimensions(frame_shape[1], frame_shape[0]) + + # Initialize validation history for new vehicles + if vehicle.track_id not in self.validation_history: + self.validation_history[vehicle.track_id] = [] + + # Check if already processed + if vehicle.processed_pipeline: + return ValidationResult( + is_valid=False, + state=VehicleState.STABLE, + confidence=1.0, + reason="Already processed through pipeline", + should_process=False, + track_id=vehicle.track_id + ) + + # Check if recently processed (cooldown period) + if vehicle.track_id in self.last_processed_vehicles: + time_since_process = time.time() - self.last_processed_vehicles[vehicle.track_id] + if time_since_process < 10.0: # 10 second cooldown + return ValidationResult( + is_valid=False, + state=VehicleState.STABLE, + confidence=1.0, + reason=f"Recently processed ({time_since_process:.1f}s ago)", + should_process=False, + track_id=vehicle.track_id + ) + + # Determine vehicle state + state = self._determine_vehicle_state(vehicle) + + # Update history + self.validation_history[vehicle.track_id].append(state) + if len(self.validation_history[vehicle.track_id]) > 20: + self.validation_history[vehicle.track_id].pop(0) + + # Validate based on state + if state == VehicleState.STABLE: + return self._validate_stable_vehicle(vehicle) + elif state == VehicleState.PASSING_BY: + return ValidationResult( + is_valid=False, + state=state, + confidence=0.8, + reason="Vehicle is passing by", + should_process=False, + track_id=vehicle.track_id + ) + elif state == VehicleState.ENTERING: + return ValidationResult( + is_valid=False, + state=state, + confidence=0.5, + reason="Vehicle is entering, waiting for stability", + should_process=False, + track_id=vehicle.track_id + ) + elif state == VehicleState.LEAVING: + return ValidationResult( + is_valid=False, + state=state, + confidence=0.5, + reason="Vehicle is leaving", + should_process=False, + track_id=vehicle.track_id + ) + else: + return ValidationResult( + is_valid=False, + state=state, + confidence=0.0, + reason="Unknown vehicle state", + should_process=False, + track_id=vehicle.track_id + ) + + def _determine_vehicle_state(self, vehicle: TrackedVehicle) -> VehicleState: + """ + Determine the current state of the vehicle based on movement patterns. + + Args: + vehicle: The tracked vehicle + + Returns: + Current vehicle state + """ + # Not enough data + if len(vehicle.last_position_history) < 3: + return VehicleState.UNKNOWN + + # Calculate velocity + velocity = self._calculate_velocity(vehicle) + + # Get position zones + x_position = vehicle.center[0] / self.frame_width + y_position = vehicle.center[1] / self.frame_height + + # Check if vehicle is stable + stability = vehicle.calculate_stability() + if stability > 0.7 and velocity < self.velocity_threshold: + # Check if it's been stable long enough + duration = time.time() - vehicle.first_seen + if duration > self.min_stable_duration and vehicle.stable_frames >= self.min_stable_frames: + return VehicleState.STABLE + else: + return VehicleState.ENTERING + + # Check if vehicle is entering or leaving + if velocity > self.velocity_threshold: + # Determine direction based on position history + positions = np.array(vehicle.last_position_history) + if len(positions) >= 2: + direction = positions[-1] - positions[0] + + # Entering: moving towards center + if x_position < self.entering_zone_ratio or x_position > (1 - self.entering_zone_ratio): + if abs(direction[0]) > abs(direction[1]): # Horizontal movement + if (x_position < 0.5 and direction[0] > 0) or (x_position > 0.5 and direction[0] < 0): + return VehicleState.ENTERING + + # Leaving: moving away from center + if 0.3 < x_position < 0.7: # In center zone + if abs(direction[0]) > abs(direction[1]): # Horizontal movement + if abs(direction[0]) > 10: # Significant movement + return VehicleState.LEAVING + + return VehicleState.PASSING_BY + + return VehicleState.UNKNOWN + + def _validate_stable_vehicle(self, vehicle: TrackedVehicle) -> ValidationResult: + """ + Perform detailed validation of a stable vehicle. + + Args: + vehicle: The stable vehicle to validate + + Returns: + Detailed validation result + """ + # Check duration + duration = time.time() - vehicle.first_seen + if duration < self.min_stable_duration: + return ValidationResult( + is_valid=False, + state=VehicleState.STABLE, + confidence=0.6, + reason=f"Not stable long enough ({duration:.1f}s < {self.min_stable_duration}s)", + should_process=False, + track_id=vehicle.track_id + ) + + # Check frame count + if vehicle.stable_frames < self.min_stable_frames: + return ValidationResult( + is_valid=False, + state=VehicleState.STABLE, + confidence=0.6, + reason=f"Not enough stable frames ({vehicle.stable_frames} < {self.min_stable_frames})", + should_process=False, + track_id=vehicle.track_id + ) + + # Check confidence + if vehicle.avg_confidence < self.min_confidence: + return ValidationResult( + is_valid=False, + state=VehicleState.STABLE, + confidence=vehicle.avg_confidence, + reason=f"Confidence too low ({vehicle.avg_confidence:.2f} < {self.min_confidence})", + should_process=False, + track_id=vehicle.track_id + ) + + # Check position variance + variance = self._calculate_position_variance(vehicle) + if variance > self.position_variance_threshold: + return ValidationResult( + is_valid=False, + state=VehicleState.STABLE, + confidence=0.7, + reason=f"Position variance too high ({variance:.1f} > {self.position_variance_threshold})", + should_process=False, + track_id=vehicle.track_id + ) + + # Check state history consistency + if vehicle.track_id in self.validation_history: + history = self.validation_history[vehicle.track_id][-5:] # Last 5 states + stable_count = sum(1 for s in history if s == VehicleState.STABLE) + if stable_count < 3: + return ValidationResult( + is_valid=False, + state=VehicleState.STABLE, + confidence=0.7, + reason="Inconsistent state history", + should_process=False, + track_id=vehicle.track_id + ) + + # All checks passed - vehicle is valid for processing + self.last_processed_vehicles[vehicle.track_id] = time.time() + + return ValidationResult( + is_valid=True, + state=VehicleState.STABLE, + confidence=vehicle.avg_confidence, + reason="Vehicle is stable and ready for processing", + should_process=True, + track_id=vehicle.track_id + ) + + def _calculate_velocity(self, vehicle: TrackedVehicle) -> float: + """ + Calculate the velocity of the vehicle based on position history. + + Args: + vehicle: The tracked vehicle + + Returns: + Velocity in pixels per frame + """ + if len(vehicle.last_position_history) < 2: + return 0.0 + + positions = np.array(vehicle.last_position_history) + if len(positions) < 2: + return 0.0 + + # Calculate velocity over last 3 frames + recent_positions = positions[-min(3, len(positions)):] + velocities = [] + + for i in range(1, len(recent_positions)): + dx = recent_positions[i][0] - recent_positions[i-1][0] + dy = recent_positions[i][1] - recent_positions[i-1][1] + velocity = np.sqrt(dx**2 + dy**2) + velocities.append(velocity) + + return np.mean(velocities) if velocities else 0.0 + + def _calculate_position_variance(self, vehicle: TrackedVehicle) -> float: + """ + Calculate the position variance of the vehicle. + + Args: + vehicle: The tracked vehicle + + Returns: + Position variance in pixels + """ + if len(vehicle.last_position_history) < 2: + return 0.0 + + positions = np.array(vehicle.last_position_history) + variance_x = np.var(positions[:, 0]) + variance_y = np.var(positions[:, 1]) + + return np.sqrt(variance_x + variance_y) + + def should_skip_same_car(self, + vehicle: TrackedVehicle, + session_cleared: bool = False) -> bool: + """ + Determine if we should skip processing for the same car after session clear. + + Args: + vehicle: The tracked vehicle + session_cleared: Whether the session was recently cleared + + Returns: + True if we should skip this vehicle + """ + # If vehicle has a session_id but it was cleared, skip for a period + if vehicle.session_id is None and vehicle.processed_pipeline and session_cleared: + # Check if enough time has passed since processing + if vehicle.track_id in self.last_processed_vehicles: + time_since = time.time() - self.last_processed_vehicles[vehicle.track_id] + if time_since < 30.0: # 30 second cooldown after session clear + logger.debug(f"Skipping same car {vehicle.track_id} after session clear " + f"({time_since:.1f}s since processing)") + return True + + return False + + def reset_vehicle(self, track_id: int): + """ + Reset validation state for a specific vehicle. + + Args: + track_id: Track ID of the vehicle to reset + """ + if track_id in self.validation_history: + del self.validation_history[track_id] + if track_id in self.last_processed_vehicles: + del self.last_processed_vehicles[track_id] + logger.debug(f"Reset validation state for vehicle {track_id}") + + def get_statistics(self) -> Dict: + """Get validation statistics.""" + return { + 'vehicles_in_history': len(self.validation_history), + 'recently_processed': len(self.last_processed_vehicles), + 'state_distribution': self._get_state_distribution() + } + + def _get_state_distribution(self) -> Dict[str, int]: + """Get distribution of current vehicle states.""" + distribution = {state.value: 0 for state in VehicleState} + + for history in self.validation_history.values(): + if history: + current_state = history[-1] + distribution[current_state.value] += 1 + + return distribution \ No newline at end of file