diff --git a/app.py b/app.py index 8383208..757a6df 100644 --- a/app.py +++ b/app.py @@ -118,15 +118,34 @@ def download_mpta(url: str, dest_path: str) -> str: def fetch_snapshot(url: str): try: from requests.auth import HTTPBasicAuth, HTTPDigestAuth + import requests.adapters + import urllib3 # Parse URL to extract credentials parsed = urlparse(url) - # Prepare headers - some cameras require User-Agent + # Prepare headers - some cameras require User-Agent and specific headers headers = { - 'User-Agent': 'Mozilla/5.0 (compatible; DetectorWorker/1.0)' + 'User-Agent': 'Mozilla/5.0 (compatible; DetectorWorker/1.0)', + 'Accept': 'image/jpeg,image/*,*/*', + 'Connection': 'close', + 'Cache-Control': 'no-cache' } + # Create a session with custom adapter for better connection handling + session = requests.Session() + adapter = requests.adapters.HTTPAdapter( + pool_connections=1, + pool_maxsize=1, + max_retries=urllib3.util.retry.Retry( + total=2, + backoff_factor=0.1, + status_forcelist=[500, 502, 503, 504] + ) + ) + session.mount('http://', adapter) + session.mount('https://', adapter) + # Reconstruct URL without credentials clean_url = f"{parsed.scheme}://{parsed.hostname}" if parsed.port: @@ -136,44 +155,68 @@ def fetch_snapshot(url: str): clean_url += f"?{parsed.query}" auth = None + response = None + if parsed.username and parsed.password: # Try HTTP Digest authentication first (common for IP cameras) try: auth = HTTPDigestAuth(parsed.username, parsed.password) - response = requests.get(clean_url, auth=auth, headers=headers, timeout=10) + response = session.get(clean_url, auth=auth, headers=headers, timeout=(5, 15), stream=True) if response.status_code == 200: logger.debug(f"Successfully authenticated using HTTP Digest for {clean_url}") elif response.status_code == 401: # If Digest fails, try Basic auth logger.debug(f"HTTP Digest failed, trying Basic auth for {clean_url}") auth = HTTPBasicAuth(parsed.username, parsed.password) - response = requests.get(clean_url, auth=auth, headers=headers, timeout=10) + response = session.get(clean_url, auth=auth, headers=headers, timeout=(5, 15), stream=True) if response.status_code == 200: logger.debug(f"Successfully authenticated using HTTP Basic for {clean_url}") except Exception as auth_error: logger.debug(f"Authentication setup error: {auth_error}") # Fallback to original URL with embedded credentials - response = requests.get(url, headers=headers, timeout=10) + response = session.get(url, headers=headers, timeout=(5, 15), stream=True) else: # No credentials in URL, make request as-is - response = requests.get(url, headers=headers, timeout=10) + response = session.get(url, headers=headers, timeout=(5, 15), stream=True) - if response.status_code == 200: + if response and response.status_code == 200: + # Read content with size limit to prevent memory issues + content = b'' + max_size = 10 * 1024 * 1024 # 10MB limit + for chunk in response.iter_content(chunk_size=8192): + content += chunk + if len(content) > max_size: + logger.error(f"Snapshot too large (>{max_size} bytes) from {clean_url}") + return None + # Convert response content to numpy array - nparr = np.frombuffer(response.content, np.uint8) + nparr = np.frombuffer(content, np.uint8) # Decode image frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if frame is not None: - logger.debug(f"Successfully fetched snapshot from {clean_url}, shape: {frame.shape}") + logger.debug(f"Successfully fetched snapshot from {clean_url}, shape: {frame.shape}, size: {len(content)} bytes") return frame else: - logger.error(f"Failed to decode image from snapshot URL: {clean_url}") + logger.error(f"Failed to decode image from snapshot URL: {clean_url} (content size: {len(content)} bytes)") return None - else: + elif response: logger.error(f"Failed to fetch snapshot (status code {response.status_code}): {clean_url}") + # Log response headers and first part of content for debugging + logger.debug(f"Response headers: {dict(response.headers)}") + if len(response.content) < 1000: + logger.debug(f"Response content: {response.content[:500]}") return None + else: + logger.error(f"No response received from snapshot URL: {clean_url}") + return None + except requests.exceptions.Timeout as e: + logger.error(f"Timeout fetching snapshot from {url}: {str(e)}") + return None + except requests.exceptions.ConnectionError as e: + logger.error(f"Connection error fetching snapshot from {url}: {str(e)}") + return None except Exception as e: - logger.error(f"Exception fetching snapshot from {url}: {str(e)}") + logger.error(f"Exception fetching snapshot from {url}: {str(e)}", exc_info=True) return None # Helper to get crop coordinates from stream @@ -324,7 +367,7 @@ async def detect(websocket: WebSocket): detection_data = { "type": "imageDetection", "subscriptionIdentifier": stream["subscriptionIdentifier"], - "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.%fZ", time.gmtime()), + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), # "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()) + f".{int(time.time() * 1000) % 1000:03d}Z", "data": { "detection": detection_dict, @@ -452,6 +495,7 @@ async def detect(websocket: WebSocket): def snapshot_reader(camera_id, snapshot_url, snapshot_interval, buffer, stop_event): """Frame reader that fetches snapshots from HTTP/HTTPS URL at specified intervals""" retries = 0 + consecutive_failures = 0 # Track consecutive failures for backoff logger.info(f"Starting snapshot reader thread for camera {camera_id} from {snapshot_url}") frame_count = 0 last_log_time = time.time() @@ -466,15 +510,31 @@ async def detect(websocket: WebSocket): frame = fetch_snapshot(snapshot_url) if frame is None: - logger.warning(f"Failed to fetch snapshot for camera: {camera_id}, retry {retries+1}/{max_retries}") + consecutive_failures += 1 + logger.warning(f"Failed to fetch snapshot for camera: {camera_id}, consecutive failures: {consecutive_failures}") retries += 1 + + # Check network connectivity with a simple ping-like test + if consecutive_failures % 5 == 1: # Every 5th failure, test connectivity + try: + test_response = requests.get(snapshot_url, timeout=(2, 5), stream=False) + logger.info(f"Camera {camera_id}: Connectivity test result: {test_response.status_code}") + except Exception as test_error: + logger.warning(f"Camera {camera_id}: Connectivity test failed: {test_error}") + if retries > max_retries and max_retries != -1: logger.error(f"Max retries reached for snapshot camera: {camera_id}, stopping reader") break - time.sleep(min(interval_seconds, reconnect_interval)) + + # Exponential backoff based on consecutive failures + backoff_delay = min(30, max(1, min(2 ** min(consecutive_failures - 1, 6), interval_seconds * 2))) # Start with 1s, max 30s + logger.debug(f"Camera {camera_id}: Backing off for {backoff_delay:.1f}s (consecutive failures: {consecutive_failures})") + if stop_event.wait(backoff_delay): # Use wait with timeout instead of sleep + break # Exit if stop_event is set during backoff continue - # Successfully fetched a frame + # Successfully fetched a frame - reset consecutive failures + consecutive_failures = 0 # Reset backoff on success frame_count += 1 current_time = time.time() # Log frame stats every 5 seconds @@ -503,12 +563,18 @@ async def detect(websocket: WebSocket): time.sleep(sleep_time) except Exception as e: + consecutive_failures += 1 logger.error(f"Unexpected error fetching snapshot for camera {camera_id}: {str(e)}", exc_info=True) retries += 1 if retries > max_retries and max_retries != -1: logger.error(f"Max retries reached after error for snapshot camera {camera_id}") break - time.sleep(min(interval_seconds, reconnect_interval)) + + # Exponential backoff for exceptions too + backoff_delay = min(30, max(1, min(2 ** min(consecutive_failures - 1, 6), interval_seconds * 2))) # Start with 1s, max 30s + logger.debug(f"Camera {camera_id}: Exception backoff for {backoff_delay:.1f}s (consecutive failures: {consecutive_failures})") + if stop_event.wait(backoff_delay): # Use wait with timeout instead of sleep + break # Exit if stop_event is set during backoff except Exception as e: logger.error(f"Error in snapshot_reader thread for camera {camera_id}: {str(e)}", exc_info=True) finally: diff --git a/siwatsystem/pympta.py b/siwatsystem/pympta.py index 2ba7453..64218d6 100644 --- a/siwatsystem/pympta.py +++ b/siwatsystem/pympta.py @@ -13,6 +13,7 @@ import concurrent.futures from ultralytics import YOLO from urllib.parse import urlparse from .database import DatabaseManager +from datetime import datetime # Create a logger specifically for this module logger = logging.getLogger("detector_worker.pympta") @@ -108,6 +109,7 @@ def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client, db_manage "modelFile": node_config["modelFile"], "triggerClasses": trigger_classes, "triggerClassIndices": trigger_class_indices, + "classMapping": node_config.get("classMapping", {}), "crop": node_config.get("crop", False), "cropClass": node_config.get("cropClass"), "minConfidence": node_config.get("minConfidence", None), @@ -608,8 +610,7 @@ def run_detection_with_tracking(frame, node, context=None): )[0] # Process detection results - all_detections = [] - regions_dict = {} + candidate_detections = [] min_confidence = node.get("minConfidence", 0.0) if res.boxes is None or len(res.boxes) == 0: @@ -618,6 +619,7 @@ def run_detection_with_tracking(frame, node, context=None): logger.debug(f"Processing {len(res.boxes)} raw detections") + # First pass: collect all valid detections for i, box in enumerate(res.boxes): # Extract detection data conf = float(box.cpu().conf[0]) @@ -658,17 +660,39 @@ def run_detection_with_tracking(frame, node, context=None): "class_id": cls_id } - all_detections.append(detection) - logger.debug(f"Detection {i} accepted: {class_name} (conf={conf:.3f}, id={track_id}, bbox={bbox})") - - # Update regions_dict with highest confidence detection per class - if class_name not in regions_dict or conf > regions_dict[class_name]["confidence"]: - regions_dict[class_name] = { - "bbox": bbox, - "confidence": conf, - "detection": detection, - "track_id": track_id - } + candidate_detections.append(detection) + logger.debug(f"Detection {i} candidate: {class_name} (conf={conf:.3f}, id={track_id}, bbox={bbox})") + + # Second pass: select only the highest confidence detection overall + if not candidate_detections: + logger.debug("No valid candidate detections found") + return [], {} + + # Find the single highest confidence detection across all detected classes + best_detection = max(candidate_detections, key=lambda x: x["confidence"]) + original_class = best_detection["class"] + logger.info(f"Selected highest confidence detection: {original_class} (conf={best_detection['confidence']:.3f})") + + # Apply class mapping if configured + mapped_class = original_class + class_mapping = node.get("classMapping", {}) + if original_class in class_mapping: + mapped_class = class_mapping[original_class] + logger.info(f"Class mapping applied: {original_class} → {mapped_class}") + # Update the detection object with mapped class + best_detection["class"] = mapped_class + best_detection["original_class"] = original_class # Keep original for reference + + # Keep only the best detection with mapped class + all_detections = [best_detection] + regions_dict = { + mapped_class: { + "bbox": best_detection["bbox"], + "confidence": best_detection["confidence"], + "detection": best_detection, + "track_id": best_detection["id"] + } + } # Multi-class validation if node.get("multiClass", False) and node.get("expectedClasses"): @@ -964,7 +988,7 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None): elif "color" in model_id: det["color"] = class_name - execute_actions(node, frame, det) + execute_actions(node, frame, det, context.get("regions_dict") if context else None) return (det, None) if return_bbox else det # ─── Session management check ─────────────────────────────────────── @@ -1019,13 +1043,14 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None): **(context or {}) } - # ─── Create initial database record when Car+Frontal detected ──── - if node.get("db_manager") and node.get("multiClass", False): - # Only create database record if we have both Car and Frontal - has_car = "Car" in regions_dict - has_frontal = "Frontal" in regions_dict + # ─── Create initial database record when valid detection found ──── + if node.get("db_manager") and regions_dict: + # Create database record if we have any valid detection (after class mapping and filtering) + detected_classes = list(regions_dict.keys()) + logger.debug(f"Valid detections found for database record: {detected_classes}") - if has_car and has_frontal: + # Always create record if we have valid detections that passed all filters + if detected_classes: # Generate UUID session_id since client session is None for now import uuid as uuid_lib from datetime import datetime @@ -1047,9 +1072,12 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None): detection_result["timestamp"] = timestamp # Update with proper timestamp logger.info(f"Created initial database record with session_id: {inserted_session_id}") else: - logger.debug(f"Database record not created - missing required classes. Has Car: {has_car}, Has Frontal: {has_frontal}") + logger.debug("Database record not created - no valid detections found after filtering") - execute_actions(node, frame, detection_result, regions_dict) + # Execute actions for root node only if it doesn't have branches + # Branch nodes with actions will execute them after branch processing + if not node.get("branches") or node.get("modelId") == "yolo11n": + execute_actions(node, frame, detection_result, regions_dict) # ─── Branch processing (no stability check here) ───────────────────────────── if node["branches"]: @@ -1089,21 +1117,28 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None): futures = {} for br in active_branches: - crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None) sub_frame = frame + crop_class = br.get("cropClass") - logger.info(f"Starting parallel branch: {br['modelId']}, crop_class: {crop_class}") + logger.info(f"Starting parallel branch: {br['modelId']}, cropClass: {crop_class}") if br.get("crop", False) and crop_class: - cropped = crop_region_by_class(frame, regions_dict, crop_class) - if cropped is not None: - sub_frame = cv2.resize(cropped, (224, 224)) - logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}") + if crop_class in regions_dict: + cropped = crop_region_by_class(frame, regions_dict, crop_class) + if cropped is not None: + sub_frame = cropped # Use cropped image without manual resizing + logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']} - model will handle resizing") + else: + logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch") + continue else: - logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch") + logger.warning(f"Crop class {crop_class} not found in detected regions for {br['modelId']}, skipping branch") continue - future = executor.submit(run_pipeline, sub_frame, br, True, context) + # Add regions_dict to context for child branches + branch_context = dict(context) if context else {} + branch_context["regions_dict"] = regions_dict + future = executor.submit(run_pipeline, sub_frame, br, True, branch_context) futures[future] = br # Collect results @@ -1119,22 +1154,29 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None): else: # Run branches sequentially for br in active_branches: - crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None) sub_frame = frame + crop_class = br.get("cropClass") - logger.info(f"Starting sequential branch: {br['modelId']}, crop_class: {crop_class}") + logger.info(f"Starting sequential branch: {br['modelId']}, cropClass: {crop_class}") if br.get("crop", False) and crop_class: - cropped = crop_region_by_class(frame, regions_dict, crop_class) - if cropped is not None: - sub_frame = cv2.resize(cropped, (224, 224)) - logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}") + if crop_class in regions_dict: + cropped = crop_region_by_class(frame, regions_dict, crop_class) + if cropped is not None: + sub_frame = cropped # Use cropped image without manual resizing + logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']} - model will handle resizing") + else: + logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch") + continue else: - logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch") + logger.warning(f"Crop class {crop_class} not found in detected regions for {br['modelId']}, skipping branch") continue try: - result, _ = run_pipeline(sub_frame, br, True, context) + # Add regions_dict to context for child branches + branch_context = dict(context) if context else {} + branch_context["regions_dict"] = regions_dict + result, _ = run_pipeline(sub_frame, br, True, branch_context) if result: branch_results[br["modelId"]] = result logger.info(f"Branch {br['modelId']} completed: {result}") @@ -1156,6 +1198,14 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None): start_cooldown_timer(camera_id, model_id) logger.info(f"Camera {camera_id}: Pipeline completed successfully, starting 30s cooldown") + # ─── Execute actions after successful detection AND branch processing ────────── + # This ensures detection nodes (like frontal_detection_v1) execute their actions + # after completing both detection and branch processing + if node.get("actions") and regions_dict and node.get("modelId") != "yolo11n": + # Execute actions for branch detection nodes, skip root to avoid duplication + logger.debug(f"Executing post-detection actions for branch node {node.get('modelId')}") + execute_actions(node, frame, detection_result, regions_dict) + # ─── Return detection result ──────────────────────────────── primary_detection = max(all_detections, key=lambda x: x["confidence"]) primary_bbox = primary_detection["bbox"] diff --git a/test/sample.png b/test/sample.png new file mode 100644 index 0000000..568e38f Binary files /dev/null and b/test/sample.png differ diff --git a/test/sample2.png b/test/sample2.png new file mode 100644 index 0000000..c1e8485 Binary files /dev/null and b/test/sample2.png differ diff --git a/test/test.py b/test/test.py new file mode 100644 index 0000000..ff073c4 --- /dev/null +++ b/test/test.py @@ -0,0 +1,60 @@ +from ultralytics import YOLO +import cv2 +import os + +# Load the model +# model = YOLO('../models/webcam-local-01/4/bangchak_poc/yolo11n.pt') +model = YOLO('yolo11m.pt') + +def test_image(image_path): + """Test a single image with YOLO model""" + if not os.path.exists(image_path): + print(f"Image not found: {image_path}") + return + + # Run inference - filter for car class only (class 2 in COCO) + results = model(image_path, classes=[2, 5, 7]) # 2, 5, 7 = car, bus, truck in COCO dataset + + # Display results + for r in results: + im_array = r.plot() # plot a BGR numpy array of predictions + + # Resize image for display (max width/height 800px) + height, width = im_array.shape[:2] + max_dimension = 800 + if width > max_dimension or height > max_dimension: + if width > height: + new_width = max_dimension + new_height = int(height * (max_dimension / width)) + else: + new_height = max_dimension + new_width = int(width * (max_dimension / height)) + im_array = cv2.resize(im_array, (new_width, new_height)) + + # Show image with predictions + cv2.imshow('YOLO Test - Car Detection Only', im_array) + cv2.waitKey(0) + cv2.destroyAllWindows() + + # Print detection info + print(f"\nDetections for {image_path}:") + if r.boxes is not None and len(r.boxes) > 0: + for i, box in enumerate(r.boxes): + cls = int(box.cls[0]) + conf = float(box.conf[0]) + original_class = model.names[cls] # Original class name (car/bus/truck) + # Get bounding box coordinates + x1, y1, x2, y2 = box.xyxy[0].tolist() + # Rename all vehicle types to "car" + print(f"Detection {i+1}: car (was: {original_class}) - Confidence: {conf:.3f} - BBox: ({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f})") + print(f"Total cars detected: {len(r.boxes)}") + else: + print("No cars detected in the image") + +if __name__ == "__main__": + # Test with an image file + image_path = input("Enter image path (or press Enter for default test): ") + if not image_path: + image_path = "sample.png" # Default test image + + test_image(image_path) \ No newline at end of file diff --git a/test_botsort_zone_track.py b/test/test_botsort_zone_track.py similarity index 100% rename from test_botsort_zone_track.py rename to test/test_botsort_zone_track.py diff --git a/test_detection_tracking.py b/test_detection_tracking.py deleted file mode 100644 index ce38d8e..0000000 --- a/test_detection_tracking.py +++ /dev/null @@ -1,190 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for the refactored detection and tracking functionality. -""" -import os -import sys -import cv2 -import numpy as np -import logging -from pathlib import Path - -# Add the project root to Python path -sys.path.insert(0, str(Path(__file__).parent)) - -from siwatsystem.pympta import run_detection_with_tracking, load_pipeline_from_zip - -# Set up logging -logging.basicConfig( - level=logging.DEBUG, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -def create_test_frame(): - """Create a simple test frame for detection testing.""" - frame = np.zeros((480, 640, 3), dtype=np.uint8) - # Add some simple shapes to simulate objects - cv2.rectangle(frame, (50, 50), (200, 150), (255, 0, 0), -1) # Blue rectangle - cv2.rectangle(frame, (300, 200), (450, 350), (0, 255, 0), -1) # Green rectangle - cv2.putText(frame, "Test Frame", (250, 400), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2) - return frame - -def test_detection_function(): - """Test the structured detection function with mock data.""" - logger.info("Testing run_detection_with_tracking function...") - - # Create test frame - test_frame = create_test_frame() - - # Mock node configuration (simulating what would come from an MPTA file) - mock_node = { - "modelId": "test_detection_v1", - "triggerClasses": ["car", "person"], - "triggerClassIndices": [0, 1], - "minConfidence": 0.5, - "multiClass": False, - "expectedClasses": [], - "tracking": { - "enabled": True, - "reidConfigPath": "botsort.yaml" - } - } - - # Mock context - test_context = { - "display_id": "test-display-001", - "camera_id": "test-cam-001" - } - - logger.info("Mock node configuration:") - for key, value in mock_node.items(): - logger.info(f" {key}: {value}") - - # Note: This test will fail without a real YOLO model, but demonstrates the structure - try: - detections, regions = run_detection_with_tracking(test_frame, mock_node, test_context) - logger.info(f"Function executed successfully!") - logger.info(f"Returned detections: {len(detections)}") - logger.info(f"Returned regions: {list(regions.keys())}") - return True - except Exception as e: - logger.error(f"Function failed (expected without real model): {e}") - return False - -def test_mpta_loading(): - """Test loading an MPTA file with tracking configuration.""" - logger.info("Testing MPTA loading with tracking configuration...") - - # Check if models directory exists - models_dir = Path("models") - if not models_dir.exists(): - logger.warning("No models directory found - skipping MPTA test") - return False - - # Look for any .mpta files - mpta_files = list(models_dir.glob("**/*.mpta")) - if not mpta_files: - logger.warning("No .mpta files found in models directory - skipping MPTA test") - return False - - mpta_file = mpta_files[0] - logger.info(f"Testing with MPTA file: {mpta_file}") - - try: - # Attempt to load pipeline - target_dir = f"temp_test_{os.getpid()}" - pipeline = load_pipeline_from_zip(str(mpta_file), target_dir) - - if pipeline: - logger.info("MPTA loaded successfully!") - logger.info(f"Pipeline model ID: {pipeline.get('modelId')}") - logger.info(f"Tracking config: {pipeline.get('tracking')}") - - # Clean up - import shutil - if os.path.exists(target_dir): - shutil.rmtree(target_dir) - - return True - else: - logger.error("Failed to load MPTA pipeline") - return False - - except Exception as e: - logger.error(f"MPTA loading failed: {e}") - return False - -def print_usage_example(): - """Print example usage of the new structured functions.""" - logger.info("\n" + "="*60) - logger.info("USAGE EXAMPLE - Structured Detection & Tracking") - logger.info("="*60) - - example_config = ''' - Example pipeline.json configuration: - - { - "pipeline": { - "modelId": "car_frontal_detection_v1", - "modelFile": "yolo11n.pt", - "triggerClasses": ["Car", "Frontal"], - "minConfidence": 0.7, - "multiClass": true, - "expectedClasses": ["Car", "Frontal"], - "tracking": { - "enabled": true, - "reidConfigPath": "botsort_reid.yaml" - }, - "actions": [...], - "branches": [...] - } - } - ''' - - logger.info(example_config) - - code_example = ''' - Usage in code: - - # Load pipeline from MPTA file - pipeline = load_pipeline_from_zip("model.mpta", "temp_dir") - - # Run detection with tracking - detections, regions = run_detection_with_tracking(frame, pipeline, context) - - # Process results - for detection in detections: - class_name = detection["class"] - confidence = detection["confidence"] - track_id = detection["id"] # Available when tracking enabled - bbox = detection["bbox"] - print(f"Detected: {class_name} (ID: {track_id}, conf: {confidence:.2f})") - ''' - - logger.info(code_example) - -def main(): - """Main test function.""" - logger.info("Starting detection & tracking refactoring tests...") - - # Test 1: Function structure - test1_passed = test_detection_function() - - # Test 2: MPTA loading - test2_passed = test_mpta_loading() - - # Print usage examples - print_usage_example() - - # Summary - logger.info("\n" + "="*60) - logger.info("TEST SUMMARY") - logger.info("="*60) - logger.info(f"Function structure test: {'PASS' if test1_passed else 'EXPECTED FAIL (no model)'}") - logger.info(f"MPTA loading test: {'PASS' if test2_passed else 'SKIP (no files)'}") - logger.info("\nRefactoring completed successfully!") - logger.info("The detection and tracking code is now structured and easy to configure.") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/view_redis_images.py b/view_redis_images.py new file mode 100644 index 0000000..b1b3c63 --- /dev/null +++ b/view_redis_images.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 +""" +Script to view frontal images saved in Redis +""" +import redis +import cv2 +import numpy as np +import sys +from datetime import datetime + +# Redis connection config (from pipeline.json) +REDIS_CONFIG = { + "host": "10.100.1.3", + "port": 6379, + "password": "FBQgi0i5RevAAMO5Hh66", + "db": 0 +} + +def connect_redis(): + """Connect to Redis server.""" + try: + client = redis.Redis( + host=REDIS_CONFIG["host"], + port=REDIS_CONFIG["port"], + password=REDIS_CONFIG["password"], + db=REDIS_CONFIG["db"], + decode_responses=False # Keep bytes for images + ) + client.ping() + print(f"✅ Connected to Redis at {REDIS_CONFIG['host']}:{REDIS_CONFIG['port']}") + return client + except redis.exceptions.ConnectionError as e: + print(f"❌ Failed to connect to Redis: {e}") + return None + +def list_image_keys(client): + """List all image keys in Redis.""" + try: + # Look for keys matching the inference pattern + keys = client.keys("inference:*") + print(f"\n📋 Found {len(keys)} image keys:") + for i, key in enumerate(keys): + key_str = key.decode() if isinstance(key, bytes) else key + print(f"{i+1}. {key_str}") + return keys + except Exception as e: + print(f"❌ Error listing keys: {e}") + return [] + +def view_image(client, key): + """View a specific image from Redis.""" + try: + # Get image data from Redis + image_data = client.get(key) + if image_data is None: + print(f"❌ No data found for key: {key}") + return + + print(f"📸 Image size: {len(image_data)} bytes") + + # Convert bytes to numpy array + nparr = np.frombuffer(image_data, np.uint8) + + # Decode image + img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + if img is None: + print("❌ Failed to decode image data") + return + + print(f"🖼️ Image dimensions: {img.shape[1]}x{img.shape[0]} pixels") + + # Display image + key_str = key.decode() if isinstance(key, bytes) else key + cv2.imshow(f'Redis Image: {key_str}', img) + print("👁️ Image displayed. Press any key to close...") + cv2.waitKey(0) + cv2.destroyAllWindows() + + # Ask if user wants to save the image + save = input("💾 Save image to file? (y/n): ").lower().strip() + if save == 'y': + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"redis_image_{timestamp}.jpg" + cv2.imwrite(filename, img) + print(f"💾 Image saved as: {filename}") + + except Exception as e: + print(f"❌ Error viewing image: {e}") + +def monitor_new_images(client): + """Monitor for new images being added to Redis.""" + print("👀 Monitoring for new images... (Press Ctrl+C to stop)") + try: + # Subscribe to Redis pub/sub for car detections + pubsub = client.pubsub() + pubsub.subscribe('car_detections') + + for message in pubsub.listen(): + if message['type'] == 'message': + data = message['data'].decode() + print(f"🚨 New detection: {data}") + + # Try to extract image key from message + import json + try: + detection_data = json.loads(data) + image_key = detection_data.get('image_key') + if image_key: + print(f"🖼️ New image available: {image_key}") + view_choice = input("View this image now? (y/n): ").lower().strip() + if view_choice == 'y': + view_image(client, image_key) + except json.JSONDecodeError: + pass + + except KeyboardInterrupt: + print("\n👋 Stopping monitor...") + except Exception as e: + print(f"❌ Monitor error: {e}") + +def main(): + """Main function.""" + print("🔍 Redis Image Viewer") + print("=" * 50) + + # Connect to Redis + client = connect_redis() + if not client: + return + + while True: + print("\n📋 Options:") + print("1. List all image keys") + print("2. View specific image") + print("3. Monitor for new images") + print("4. Exit") + + choice = input("\nEnter choice (1-4): ").strip() + + if choice == '1': + keys = list_image_keys(client) + elif choice == '2': + keys = list_image_keys(client) + if keys: + try: + idx = int(input(f"\nEnter image number (1-{len(keys)}): ")) - 1 + if 0 <= idx < len(keys): + view_image(client, keys[idx]) + else: + print("❌ Invalid selection") + except ValueError: + print("❌ Please enter a valid number") + elif choice == '3': + monitor_new_images(client) + elif choice == '4': + print("👋 Goodbye!") + break + else: + print("❌ Invalid choice") + +if __name__ == "__main__": + main() \ No newline at end of file