feat/tracking and save in redis finished

This commit is contained in:
  parent 3a4a27ca68
  commit 5873945115

8 changed files with 393 additions and 245 deletions
100  app.py
@@ -118,15 +118,34 @@ def download_mpta(url: str, dest_path: str) -> str:
 def fetch_snapshot(url: str):
     try:
         from requests.auth import HTTPBasicAuth, HTTPDigestAuth
+        import requests.adapters
+        import urllib3

         # Parse URL to extract credentials
         parsed = urlparse(url)

-        # Prepare headers - some cameras require User-Agent
+        # Prepare headers - some cameras require User-Agent and specific headers
         headers = {
-            'User-Agent': 'Mozilla/5.0 (compatible; DetectorWorker/1.0)'
+            'User-Agent': 'Mozilla/5.0 (compatible; DetectorWorker/1.0)',
+            'Accept': 'image/jpeg,image/*,*/*',
+            'Connection': 'close',
+            'Cache-Control': 'no-cache'
         }

+        # Create a session with custom adapter for better connection handling
+        session = requests.Session()
+        adapter = requests.adapters.HTTPAdapter(
+            pool_connections=1,
+            pool_maxsize=1,
+            max_retries=urllib3.util.retry.Retry(
+                total=2,
+                backoff_factor=0.1,
+                status_forcelist=[500, 502, 503, 504]
+            )
+        )
+        session.mount('http://', adapter)
+        session.mount('https://', adapter)
+
         # Reconstruct URL without credentials
         clean_url = f"{parsed.scheme}://{parsed.hostname}"
         if parsed.port:
@@ -136,44 +155,68 @@ def fetch_snapshot(url: str):
             clean_url += f"?{parsed.query}"

         auth = None
+        response = None
+
         if parsed.username and parsed.password:
             # Try HTTP Digest authentication first (common for IP cameras)
             try:
                 auth = HTTPDigestAuth(parsed.username, parsed.password)
-                response = requests.get(clean_url, auth=auth, headers=headers, timeout=10)
+                response = session.get(clean_url, auth=auth, headers=headers, timeout=(5, 15), stream=True)
                 if response.status_code == 200:
                     logger.debug(f"Successfully authenticated using HTTP Digest for {clean_url}")
                 elif response.status_code == 401:
                     # If Digest fails, try Basic auth
                     logger.debug(f"HTTP Digest failed, trying Basic auth for {clean_url}")
                     auth = HTTPBasicAuth(parsed.username, parsed.password)
-                    response = requests.get(clean_url, auth=auth, headers=headers, timeout=10)
+                    response = session.get(clean_url, auth=auth, headers=headers, timeout=(5, 15), stream=True)
                     if response.status_code == 200:
                         logger.debug(f"Successfully authenticated using HTTP Basic for {clean_url}")
             except Exception as auth_error:
                 logger.debug(f"Authentication setup error: {auth_error}")
                 # Fallback to original URL with embedded credentials
-                response = requests.get(url, headers=headers, timeout=10)
+                response = session.get(url, headers=headers, timeout=(5, 15), stream=True)
         else:
             # No credentials in URL, make request as-is
-            response = requests.get(url, headers=headers, timeout=10)
+            response = session.get(url, headers=headers, timeout=(5, 15), stream=True)
+
+        if response and response.status_code == 200:
+            # Read content with size limit to prevent memory issues
+            content = b''
+            max_size = 10 * 1024 * 1024  # 10MB limit
+            for chunk in response.iter_content(chunk_size=8192):
+                content += chunk
+                if len(content) > max_size:
+                    logger.error(f"Snapshot too large (>{max_size} bytes) from {clean_url}")
+                    return None

-        if response.status_code == 200:
             # Convert response content to numpy array
-            nparr = np.frombuffer(response.content, np.uint8)
+            nparr = np.frombuffer(content, np.uint8)
             # Decode image
             frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
             if frame is not None:
-                logger.debug(f"Successfully fetched snapshot from {clean_url}, shape: {frame.shape}")
+                logger.debug(f"Successfully fetched snapshot from {clean_url}, shape: {frame.shape}, size: {len(content)} bytes")
                 return frame
             else:
-                logger.error(f"Failed to decode image from snapshot URL: {clean_url}")
+                logger.error(f"Failed to decode image from snapshot URL: {clean_url} (content size: {len(content)} bytes)")
                 return None
-        else:
+        elif response:
             logger.error(f"Failed to fetch snapshot (status code {response.status_code}): {clean_url}")
+            # Log response headers and first part of content for debugging
+            logger.debug(f"Response headers: {dict(response.headers)}")
+            if len(response.content) < 1000:
+                logger.debug(f"Response content: {response.content[:500]}")
             return None
+        else:
+            logger.error(f"No response received from snapshot URL: {clean_url}")
+            return None
+    except requests.exceptions.Timeout as e:
+        logger.error(f"Timeout fetching snapshot from {url}: {str(e)}")
+        return None
+    except requests.exceptions.ConnectionError as e:
+        logger.error(f"Connection error fetching snapshot from {url}: {str(e)}")
+        return None
     except Exception as e:
-        logger.error(f"Exception fetching snapshot from {url}: {str(e)}")
+        logger.error(f"Exception fetching snapshot from {url}: {str(e)}", exc_info=True)
         return None

 # Helper to get crop coordinates from stream
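For readers skimming the hunk above, here is a condensed, self-contained sketch of the connection-handling pattern it introduces (session with a single-connection pool, bounded retries, Digest-then-Basic fallback). The URL, credentials, and the helper name make_camera_session are placeholders for illustration; they are not part of the commit.

# Illustrative sketch (not part of the commit): the connection handling that
# fetch_snapshot now uses, condensed into one helper. URL/credentials are placeholders.
import requests
import requests.adapters
import urllib3
from requests.auth import HTTPBasicAuth, HTTPDigestAuth

def make_camera_session() -> requests.Session:
    """Session with a single-connection pool and bounded retries, mirroring the diff."""
    session = requests.Session()
    adapter = requests.adapters.HTTPAdapter(
        pool_connections=1,
        pool_maxsize=1,
        max_retries=urllib3.util.retry.Retry(
            total=2,                               # at most 2 automatic retries
            backoff_factor=0.1,                    # short pause between retries
            status_forcelist=[500, 502, 503, 504],
        ),
    )
    session.mount('http://', adapter)
    session.mount('https://', adapter)
    return session

if __name__ == "__main__":
    session = make_camera_session()
    url = "http://camera.local/snapshot.jpg"       # placeholder snapshot URL
    # Digest auth first (common for IP cameras), fall back to Basic on 401.
    resp = session.get(url, auth=HTTPDigestAuth("user", "pass"), timeout=(5, 15), stream=True)
    if resp.status_code == 401:
        resp = session.get(url, auth=HTTPBasicAuth("user", "pass"), timeout=(5, 15), stream=True)
    print(resp.status_code)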
@@ -324,7 +367,7 @@ async def detect(websocket: WebSocket):
             detection_data = {
                 "type": "imageDetection",
                 "subscriptionIdentifier": stream["subscriptionIdentifier"],
-                "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.%fZ", time.gmtime()),
+                "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
                 # "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()) + f".{int(time.time() * 1000) % 1000:03d}Z",
                 "data": {
                     "detection": detection_dict,
@@ -452,6 +495,7 @@ async def detect(websocket: WebSocket):
     def snapshot_reader(camera_id, snapshot_url, snapshot_interval, buffer, stop_event):
         """Frame reader that fetches snapshots from HTTP/HTTPS URL at specified intervals"""
         retries = 0
+        consecutive_failures = 0  # Track consecutive failures for backoff
         logger.info(f"Starting snapshot reader thread for camera {camera_id} from {snapshot_url}")
         frame_count = 0
         last_log_time = time.time()
@@ -466,15 +510,31 @@ async def detect(websocket: WebSocket):
                     frame = fetch_snapshot(snapshot_url)

                     if frame is None:
-                        logger.warning(f"Failed to fetch snapshot for camera: {camera_id}, retry {retries+1}/{max_retries}")
+                        consecutive_failures += 1
+                        logger.warning(f"Failed to fetch snapshot for camera: {camera_id}, consecutive failures: {consecutive_failures}")
                         retries += 1
+
+                        # Check network connectivity with a simple ping-like test
+                        if consecutive_failures % 5 == 1:  # Every 5th failure, test connectivity
+                            try:
+                                test_response = requests.get(snapshot_url, timeout=(2, 5), stream=False)
+                                logger.info(f"Camera {camera_id}: Connectivity test result: {test_response.status_code}")
+                            except Exception as test_error:
+                                logger.warning(f"Camera {camera_id}: Connectivity test failed: {test_error}")
+
                         if retries > max_retries and max_retries != -1:
                             logger.error(f"Max retries reached for snapshot camera: {camera_id}, stopping reader")
                             break
-                        time.sleep(min(interval_seconds, reconnect_interval))
+
+                        # Exponential backoff based on consecutive failures
+                        backoff_delay = min(30, max(1, min(2 ** min(consecutive_failures - 1, 6), interval_seconds * 2)))  # Start with 1s, max 30s
+                        logger.debug(f"Camera {camera_id}: Backing off for {backoff_delay:.1f}s (consecutive failures: {consecutive_failures})")
+                        if stop_event.wait(backoff_delay):  # Use wait with timeout instead of sleep
+                            break  # Exit if stop_event is set during backoff
                         continue

-                    # Successfully fetched a frame
+                    # Successfully fetched a frame - reset consecutive failures
+                    consecutive_failures = 0  # Reset backoff on success
                     frame_count += 1
                     current_time = time.time()
                     # Log frame stats every 5 seconds
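The nested min/max backoff expression above is easy to misread, so the following short sketch evaluates it for the first few consecutive failures, assuming interval_seconds = 5 (an example value; the real value comes from the configured snapshot interval):

# Backoff schedule sketch; interval_seconds = 5 is an assumed example value.
interval_seconds = 5

for consecutive_failures in range(1, 9):
    backoff_delay = min(30, max(1, min(2 ** min(consecutive_failures - 1, 6),
                                       interval_seconds * 2)))
    print(consecutive_failures, backoff_delay)

# Delays for failures 1..8: 1, 2, 4, 8, 10, 10, 10, 10 seconds. The exponential
# term doubles per failure until the interval_seconds * 2 cap (here 10s) kicks in,
# with a hard ceiling of 30s.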
@@ -503,12 +563,18 @@ async def detect(websocket: WebSocket):
                         time.sleep(sleep_time)

                 except Exception as e:
+                    consecutive_failures += 1
                     logger.error(f"Unexpected error fetching snapshot for camera {camera_id}: {str(e)}", exc_info=True)
                     retries += 1
                     if retries > max_retries and max_retries != -1:
                         logger.error(f"Max retries reached after error for snapshot camera {camera_id}")
                         break
-                    time.sleep(min(interval_seconds, reconnect_interval))
+
+                    # Exponential backoff for exceptions too
+                    backoff_delay = min(30, max(1, min(2 ** min(consecutive_failures - 1, 6), interval_seconds * 2)))  # Start with 1s, max 30s
+                    logger.debug(f"Camera {camera_id}: Exception backoff for {backoff_delay:.1f}s (consecutive failures: {consecutive_failures})")
+                    if stop_event.wait(backoff_delay):  # Use wait with timeout instead of sleep
+                        break  # Exit if stop_event is set during backoff
         except Exception as e:
             logger.error(f"Error in snapshot_reader thread for camera {camera_id}: {str(e)}", exc_info=True)
         finally:
siwatsystem/pympta.py
@@ -13,6 +13,7 @@ import concurrent.futures
 from ultralytics import YOLO
 from urllib.parse import urlparse
 from .database import DatabaseManager
+from datetime import datetime

 # Create a logger specifically for this module
 logger = logging.getLogger("detector_worker.pympta")
@@ -108,6 +109,7 @@ def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client, db_manage
         "modelFile": node_config["modelFile"],
         "triggerClasses": trigger_classes,
         "triggerClassIndices": trigger_class_indices,
+        "classMapping": node_config.get("classMapping", {}),
         "crop": node_config.get("crop", False),
         "cropClass": node_config.get("cropClass"),
         "minConfidence": node_config.get("minConfidence", None),
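To show where the new classMapping field comes from, here is a hypothetical node_config fragment and how the loader picks it up. The bus/truck-to-car mapping values are illustrative and not taken from this commit; the model and class names reuse the example config from the deleted test script.

# Hypothetical pipeline node fragment; mapping values are examples only.
node_config = {
    "modelId": "car_frontal_detection_v1",
    "modelFile": "yolo11n.pt",
    "triggerClasses": ["Car", "Frontal"],
    "classMapping": {"bus": "car", "truck": "car"},   # optional, defaults to {}
}

# load_pipeline_node now carries this through on the node dict:
class_mapping = node_config.get("classMapping", {})
detected_class = "truck"                              # pretend the model said "truck"
print(class_mapping.get(detected_class, detected_class))  # -> "car"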
@@ -608,8 +610,7 @@ def run_detection_with_tracking(frame, node, context=None):
             )[0]

         # Process detection results
-        all_detections = []
-        regions_dict = {}
+        candidate_detections = []
         min_confidence = node.get("minConfidence", 0.0)

         if res.boxes is None or len(res.boxes) == 0:
@@ -618,6 +619,7 @@ def run_detection_with_tracking(frame, node, context=None):

         logger.debug(f"Processing {len(res.boxes)} raw detections")

+        # First pass: collect all valid detections
         for i, box in enumerate(res.boxes):
             # Extract detection data
             conf = float(box.cpu().conf[0])
@@ -658,17 +660,39 @@ def run_detection_with_tracking(frame, node, context=None):
                 "class_id": cls_id
             }

-            all_detections.append(detection)
-            logger.debug(f"Detection {i} accepted: {class_name} (conf={conf:.3f}, id={track_id}, bbox={bbox})")
+            candidate_detections.append(detection)
+            logger.debug(f"Detection {i} candidate: {class_name} (conf={conf:.3f}, id={track_id}, bbox={bbox})")

-            # Update regions_dict with highest confidence detection per class
-            if class_name not in regions_dict or conf > regions_dict[class_name]["confidence"]:
-                regions_dict[class_name] = {
-                    "bbox": bbox,
-                    "confidence": conf,
-                    "detection": detection,
-                    "track_id": track_id
-                }
+        # Second pass: select only the highest confidence detection overall
+        if not candidate_detections:
+            logger.debug("No valid candidate detections found")
+            return [], {}
+
+        # Find the single highest confidence detection across all detected classes
+        best_detection = max(candidate_detections, key=lambda x: x["confidence"])
+        original_class = best_detection["class"]
+        logger.info(f"Selected highest confidence detection: {original_class} (conf={best_detection['confidence']:.3f})")
+
+        # Apply class mapping if configured
+        mapped_class = original_class
+        class_mapping = node.get("classMapping", {})
+        if original_class in class_mapping:
+            mapped_class = class_mapping[original_class]
+            logger.info(f"Class mapping applied: {original_class} → {mapped_class}")
+            # Update the detection object with mapped class
+            best_detection["class"] = mapped_class
+            best_detection["original_class"] = original_class  # Keep original for reference
+
+        # Keep only the best detection with mapped class
+        all_detections = [best_detection]
+        regions_dict = {
+            mapped_class: {
+                "bbox": best_detection["bbox"],
+                "confidence": best_detection["confidence"],
+                "detection": best_detection,
+                "track_id": best_detection["id"]
+            }
+        }

         # Multi-class validation
         if node.get("multiClass", False) and node.get("expectedClasses"):
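A small illustration of the new selection logic on made-up candidates: only the single highest-confidence detection survives, and the optional class mapping is applied before regions_dict is built. All values below are invented for the example.

# Made-up candidates to walk through the "best detection + class mapping" step.
candidate_detections = [
    {"class": "truck", "confidence": 0.62, "id": 3, "bbox": [10, 10, 200, 150]},
    {"class": "car",   "confidence": 0.91, "id": 5, "bbox": [220, 40, 400, 180]},
]
class_mapping = {"truck": "car", "bus": "car"}        # example mapping

best_detection = max(candidate_detections, key=lambda x: x["confidence"])
original_class = best_detection["class"]
mapped_class = class_mapping.get(original_class, original_class)
best_detection["class"] = mapped_class

regions_dict = {
    mapped_class: {
        "bbox": best_detection["bbox"],
        "confidence": best_detection["confidence"],
        "detection": best_detection,
        "track_id": best_detection["id"],
    }
}
print(list(regions_dict))   # ['car'] - one region keyed by the mapped class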
@@ -964,7 +988,7 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
             elif "color" in model_id:
                 det["color"] = class_name

-            execute_actions(node, frame, det)
+            execute_actions(node, frame, det, context.get("regions_dict") if context else None)
             return (det, None) if return_bbox else det

         # ─── Session management check ───────────────────────────────────────
@@ -1019,13 +1043,14 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
             **(context or {})
         }

-        # ─── Create initial database record when Car+Frontal detected ────
-        if node.get("db_manager") and node.get("multiClass", False):
-            # Only create database record if we have both Car and Frontal
-            has_car = "Car" in regions_dict
-            has_frontal = "Frontal" in regions_dict
+        # ─── Create initial database record when valid detection found ────
+        if node.get("db_manager") and regions_dict:
+            # Create database record if we have any valid detection (after class mapping and filtering)
+            detected_classes = list(regions_dict.keys())
+            logger.debug(f"Valid detections found for database record: {detected_classes}")

-            if has_car and has_frontal:
+            # Always create record if we have valid detections that passed all filters
+            if detected_classes:
                 # Generate UUID session_id since client session is None for now
                 import uuid as uuid_lib
                 from datetime import datetime
@@ -1047,9 +1072,12 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
                     detection_result["timestamp"] = timestamp  # Update with proper timestamp
                     logger.info(f"Created initial database record with session_id: {inserted_session_id}")
             else:
-                logger.debug(f"Database record not created - missing required classes. Has Car: {has_car}, Has Frontal: {has_frontal}")
+                logger.debug("Database record not created - no valid detections found after filtering")

-        execute_actions(node, frame, detection_result, regions_dict)
+        # Execute actions for root node only if it doesn't have branches
+        # Branch nodes with actions will execute them after branch processing
+        if not node.get("branches") or node.get("modelId") == "yolo11n":
+            execute_actions(node, frame, detection_result, regions_dict)
+
         # ─── Branch processing (no stability check here) ─────────────────────────────
         if node["branches"]:
@@ -1089,21 +1117,28 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
                         futures = {}

                         for br in active_branches:
-                            crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None)
                             sub_frame = frame
+                            crop_class = br.get("cropClass")

-                            logger.info(f"Starting parallel branch: {br['modelId']}, crop_class: {crop_class}")
+                            logger.info(f"Starting parallel branch: {br['modelId']}, cropClass: {crop_class}")

                             if br.get("crop", False) and crop_class:
-                                cropped = crop_region_by_class(frame, regions_dict, crop_class)
-                                if cropped is not None:
-                                    sub_frame = cv2.resize(cropped, (224, 224))
-                                    logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}")
+                                if crop_class in regions_dict:
+                                    cropped = crop_region_by_class(frame, regions_dict, crop_class)
+                                    if cropped is not None:
+                                        sub_frame = cropped  # Use cropped image without manual resizing
+                                        logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']} - model will handle resizing")
+                                    else:
+                                        logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
+                                        continue
                                 else:
-                                    logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
+                                    logger.warning(f"Crop class {crop_class} not found in detected regions for {br['modelId']}, skipping branch")
                                     continue

-                            future = executor.submit(run_pipeline, sub_frame, br, True, context)
+                            # Add regions_dict to context for child branches
+                            branch_context = dict(context) if context else {}
+                            branch_context["regions_dict"] = regions_dict
+                            future = executor.submit(run_pipeline, sub_frame, br, True, branch_context)
                             futures[future] = br

                         # Collect results
@@ -1119,22 +1154,29 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
                 else:
                     # Run branches sequentially
                     for br in active_branches:
-                        crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None)
                         sub_frame = frame
+                        crop_class = br.get("cropClass")

-                        logger.info(f"Starting sequential branch: {br['modelId']}, crop_class: {crop_class}")
+                        logger.info(f"Starting sequential branch: {br['modelId']}, cropClass: {crop_class}")

                         if br.get("crop", False) and crop_class:
-                            cropped = crop_region_by_class(frame, regions_dict, crop_class)
-                            if cropped is not None:
-                                sub_frame = cv2.resize(cropped, (224, 224))
-                                logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}")
+                            if crop_class in regions_dict:
+                                cropped = crop_region_by_class(frame, regions_dict, crop_class)
+                                if cropped is not None:
+                                    sub_frame = cropped  # Use cropped image without manual resizing
+                                    logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']} - model will handle resizing")
+                                else:
+                                    logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
+                                    continue
                             else:
-                                logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
+                                logger.warning(f"Crop class {crop_class} not found in detected regions for {br['modelId']}, skipping branch")
                                 continue

                         try:
-                            result, _ = run_pipeline(sub_frame, br, True, context)
+                            # Add regions_dict to context for child branches
+                            branch_context = dict(context) if context else {}
+                            branch_context["regions_dict"] = regions_dict
+                            result, _ = run_pipeline(sub_frame, br, True, branch_context)
                             if result:
                                 branch_results[br["modelId"]] = result
                                 logger.info(f"Branch {br['modelId']} completed: {result}")
@@ -1156,6 +1198,14 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
             start_cooldown_timer(camera_id, model_id)
             logger.info(f"Camera {camera_id}: Pipeline completed successfully, starting 30s cooldown")

+        # ─── Execute actions after successful detection AND branch processing ──────────
+        # This ensures detection nodes (like frontal_detection_v1) execute their actions
+        # after completing both detection and branch processing
+        if node.get("actions") and regions_dict and node.get("modelId") != "yolo11n":
+            # Execute actions for branch detection nodes, skip root to avoid duplication
+            logger.debug(f"Executing post-detection actions for branch node {node.get('modelId')}")
+            execute_actions(node, frame, detection_result, regions_dict)
+
         # ─── Return detection result ────────────────────────────────
         primary_detection = max(all_detections, key=lambda x: x["confidence"])
         primary_bbox = primary_detection["bbox"]
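To summarize when execute_actions now fires, here is a sketch of the two gates added in this file. The node shapes below are assumed examples, except the yolo11n and frontal_detection_v1 model IDs, which appear in the diff itself.

# Sketch of the two action-execution gates added above (assumed node shapes).
def runs_before_branches(node):
    # Root or leaf nodes without branches act immediately; the yolo11n root too.
    return (not node.get("branches")) or node.get("modelId") == "yolo11n"

def runs_after_branches(node, regions_dict):
    # Branch detection nodes act only after their own branch processing finishes.
    return bool(node.get("actions")) and bool(regions_dict) and node.get("modelId") != "yolo11n"

print(runs_before_branches({"modelId": "yolo11n", "branches": [{"modelId": "frontal_detection_v1"}]}))  # True
print(runs_after_branches({"modelId": "frontal_detection_v1", "actions": [{}]}, {"Frontal": {}}))        # True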
							
								
								
									
										
BIN  test/sample.png  (new file, 2.8 MiB; binary file not shown)
							
								
								
									
										
BIN  test/sample2.png  (new file, 3.1 MiB; binary file not shown)
							
								
								
									
60  test/test.py  (new file)
@@ -0,0 +1,60 @@
+from ultralytics import YOLO
+import cv2
+import os
+
+# Load the model
+# model = YOLO('../models/webcam-local-01/4/bangchak_poc/yolo11n.pt')
+model = YOLO('yolo11m.pt')
+
+def test_image(image_path):
+    """Test a single image with YOLO model"""
+    if not os.path.exists(image_path):
+        print(f"Image not found: {image_path}")
+        return
+
+    # Run inference - filter for car class only (class 2 in COCO)
+    results = model(image_path, classes=[2, 5, 7])  # 2, 5, 7 = car, bus, truck in COCO dataset
+
+    # Display results
+    for r in results:
+        im_array = r.plot()  # plot a BGR numpy array of predictions
+
+        # Resize image for display (max width/height 800px)
+        height, width = im_array.shape[:2]
+        max_dimension = 800
+        if width > max_dimension or height > max_dimension:
+            if width > height:
+                new_width = max_dimension
+                new_height = int(height * (max_dimension / width))
+            else:
+                new_height = max_dimension
+                new_width = int(width * (max_dimension / height))
+            im_array = cv2.resize(im_array, (new_width, new_height))
+
+        # Show image with predictions
+        cv2.imshow('YOLO Test - Car Detection Only', im_array)
+        cv2.waitKey(0)
+        cv2.destroyAllWindows()
+
+        # Print detection info
+        print(f"\nDetections for {image_path}:")
+        if r.boxes is not None and len(r.boxes) > 0:
+            for i, box in enumerate(r.boxes):
+                cls = int(box.cls[0])
+                conf = float(box.conf[0])
+                original_class = model.names[cls]  # Original class name (car/bus/truck)
+                # Get bounding box coordinates
+                x1, y1, x2, y2 = box.xyxy[0].tolist()
+                # Rename all vehicle types to "car"
+                print(f"Detection {i+1}: car (was: {original_class}) - Confidence: {conf:.3f} - BBox: ({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f})")
+            print(f"Total cars detected: {len(r.boxes)}")
+        else:
+            print("No cars detected in the image")
+
+if __name__ == "__main__":
+    # Test with an image file
+    image_path = input("Enter image path (or press Enter for default test): ")
+    if not image_path:
+        image_path = "sample.png"  # Default test image
+
+    test_image(image_path)
					@ -1,190 +0,0 @@
 | 
				
			||||||
#!/usr/bin/env python3
 | 
					 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
Test script for the refactored detection and tracking functionality.
 | 
					 | 
				
			||||||
"""
 | 
					 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
import sys
 | 
					 | 
				
			||||||
import cv2
 | 
					 | 
				
			||||||
import numpy as np
 | 
					 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
from pathlib import Path
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Add the project root to Python path
 | 
					 | 
				
			||||||
sys.path.insert(0, str(Path(__file__).parent))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from siwatsystem.pympta import run_detection_with_tracking, load_pipeline_from_zip
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Set up logging
 | 
					 | 
				
			||||||
logging.basicConfig(
 | 
					 | 
				
			||||||
    level=logging.DEBUG,
 | 
					 | 
				
			||||||
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
logger = logging.getLogger(__name__)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def create_test_frame():
 | 
					 | 
				
			||||||
    """Create a simple test frame for detection testing."""
 | 
					 | 
				
			||||||
    frame = np.zeros((480, 640, 3), dtype=np.uint8)
 | 
					 | 
				
			||||||
    # Add some simple shapes to simulate objects
 | 
					 | 
				
			||||||
    cv2.rectangle(frame, (50, 50), (200, 150), (255, 0, 0), -1)  # Blue rectangle
 | 
					 | 
				
			||||||
    cv2.rectangle(frame, (300, 200), (450, 350), (0, 255, 0), -1)  # Green rectangle
 | 
					 | 
				
			||||||
    cv2.putText(frame, "Test Frame", (250, 400), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
 | 
					 | 
				
			||||||
    return frame
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def test_detection_function():
 | 
					 | 
				
			||||||
    """Test the structured detection function with mock data."""
 | 
					 | 
				
			||||||
    logger.info("Testing run_detection_with_tracking function...")
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    # Create test frame
 | 
					 | 
				
			||||||
    test_frame = create_test_frame()
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    # Mock node configuration (simulating what would come from an MPTA file)
 | 
					 | 
				
			||||||
    mock_node = {
 | 
					 | 
				
			||||||
        "modelId": "test_detection_v1",
 | 
					 | 
				
			||||||
        "triggerClasses": ["car", "person"],
 | 
					 | 
				
			||||||
        "triggerClassIndices": [0, 1],
 | 
					 | 
				
			||||||
        "minConfidence": 0.5,
 | 
					 | 
				
			||||||
        "multiClass": False,
 | 
					 | 
				
			||||||
        "expectedClasses": [],
 | 
					 | 
				
			||||||
        "tracking": {
 | 
					 | 
				
			||||||
            "enabled": True,
 | 
					 | 
				
			||||||
            "reidConfigPath": "botsort.yaml"
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    # Mock context
 | 
					 | 
				
			||||||
    test_context = {
 | 
					 | 
				
			||||||
        "display_id": "test-display-001",
 | 
					 | 
				
			||||||
        "camera_id": "test-cam-001"
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    logger.info("Mock node configuration:")
 | 
					 | 
				
			||||||
    for key, value in mock_node.items():
 | 
					 | 
				
			||||||
        logger.info(f"  {key}: {value}")
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    # Note: This test will fail without a real YOLO model, but demonstrates the structure
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        detections, regions = run_detection_with_tracking(test_frame, mock_node, test_context)
 | 
					 | 
				
			||||||
        logger.info(f"Function executed successfully!")
 | 
					 | 
				
			||||||
        logger.info(f"Returned detections: {len(detections)}")
 | 
					 | 
				
			||||||
        logger.info(f"Returned regions: {list(regions.keys())}")
 | 
					 | 
				
			||||||
        return True
 | 
					 | 
				
			||||||
    except Exception as e:
 | 
					 | 
				
			||||||
        logger.error(f"Function failed (expected without real model): {e}")
 | 
					 | 
				
			||||||
        return False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def test_mpta_loading():
 | 
					 | 
				
			||||||
    """Test loading an MPTA file with tracking configuration."""
 | 
					 | 
				
			||||||
    logger.info("Testing MPTA loading with tracking configuration...")
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    # Check if models directory exists
 | 
					 | 
				
			||||||
    models_dir = Path("models")
 | 
					 | 
				
			||||||
    if not models_dir.exists():
 | 
					 | 
				
			||||||
        logger.warning("No models directory found - skipping MPTA test")
 | 
					 | 
				
			||||||
        return False
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    # Look for any .mpta files
 | 
					 | 
				
			||||||
    mpta_files = list(models_dir.glob("**/*.mpta"))
 | 
					 | 
				
			||||||
    if not mpta_files:
 | 
					 | 
				
			||||||
        logger.warning("No .mpta files found in models directory - skipping MPTA test")
 | 
					 | 
				
			||||||
        return False
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    mpta_file = mpta_files[0]
 | 
					 | 
				
			||||||
    logger.info(f"Testing with MPTA file: {mpta_file}")
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        # Attempt to load pipeline
 | 
					 | 
				
			||||||
        target_dir = f"temp_test_{os.getpid()}"
 | 
					 | 
				
			||||||
        pipeline = load_pipeline_from_zip(str(mpta_file), target_dir)
 | 
					 | 
				
			||||||
        
 | 
					 | 
				
			||||||
        if pipeline:
 | 
					 | 
				
			||||||
            logger.info("MPTA loaded successfully!")
 | 
					 | 
				
			||||||
            logger.info(f"Pipeline model ID: {pipeline.get('modelId')}")
 | 
					 | 
				
			||||||
            logger.info(f"Tracking config: {pipeline.get('tracking')}")
 | 
					 | 
				
			||||||
            
 | 
					 | 
				
			||||||
            # Clean up
 | 
					 | 
				
			||||||
            import shutil
 | 
					 | 
				
			||||||
            if os.path.exists(target_dir):
 | 
					 | 
				
			||||||
                shutil.rmtree(target_dir)
 | 
					 | 
				
			||||||
            
 | 
					 | 
				
			||||||
            return True
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            logger.error("Failed to load MPTA pipeline")
 | 
					 | 
				
			||||||
            return False
 | 
					 | 
				
			||||||
            
 | 
					 | 
				
			||||||
    except Exception as e:
 | 
					 | 
				
			||||||
        logger.error(f"MPTA loading failed: {e}")
 | 
					 | 
				
			||||||
        return False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def print_usage_example():
    """Print example usage of the new structured functions."""
    logger.info("\n" + "="*60)
    logger.info("USAGE EXAMPLE - Structured Detection & Tracking")
    logger.info("="*60)

    example_config = '''
    Example pipeline.json configuration:

    {
      "pipeline": {
        "modelId": "car_frontal_detection_v1",
        "modelFile": "yolo11n.pt",
        "triggerClasses": ["Car", "Frontal"],
        "minConfidence": 0.7,
        "multiClass": true,
        "expectedClasses": ["Car", "Frontal"],
        "tracking": {
          "enabled": true,
          "reidConfigPath": "botsort_reid.yaml"
        },
        "actions": [...],
        "branches": [...]
      }
    }
    '''

    logger.info(example_config)

    code_example = '''
    Usage in code:

    # Load pipeline from MPTA file
    pipeline = load_pipeline_from_zip("model.mpta", "temp_dir")

    # Run detection with tracking
    detections, regions = run_detection_with_tracking(frame, pipeline, context)

    # Process results
    for detection in detections:
        class_name = detection["class"]
        confidence = detection["confidence"]
        track_id = detection["id"]  # Available when tracking enabled
        bbox = detection["bbox"]
        print(f"Detected: {class_name} (ID: {track_id}, conf: {confidence:.2f})")
    '''

    logger.info(code_example)

def main():
    """Main test function."""
    logger.info("Starting detection & tracking refactoring tests...")

    # Test 1: Function structure
    test1_passed = test_detection_function()

    # Test 2: MPTA loading
    test2_passed = test_mpta_loading()

    # Print usage examples
    print_usage_example()

    # Summary
    logger.info("\n" + "="*60)
    logger.info("TEST SUMMARY")
    logger.info("="*60)
    logger.info(f"Function structure test: {'PASS' if test1_passed else 'EXPECTED FAIL (no model)'}")
    logger.info(f"MPTA loading test: {'PASS' if test2_passed else 'SKIP (no files)'}")
    logger.info("\nRefactoring completed successfully!")
    logger.info("The detection and tracking code is now structured and easy to configure.")


if __name__ == "__main__":
    main()
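Note: the "tracking" block in the example pipeline.json above is optional. A minimal sketch of how a loader might normalize it before wiring up the tracker; the helper name and defaults below are assumptions for illustration, not the pipeline's actual API:

    def resolve_tracking_config(pipeline: dict) -> dict:
        """Sketch (assumed helper): normalize the optional 'tracking' block."""
        tracking = pipeline.get("tracking") or {}
        return {
            "enabled": bool(tracking.get("enabled", False)),
            # Re-ID config path referenced in the example above; the default is an assumption
            "reid_config_path": tracking.get("reidConfigPath", "botsort_reid.yaml"),
        }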
							
								
								
									
view_redis_images.py (new file, 162 lines)
@@ -0,0 +1,162 @@
#!/usr/bin/env python3
"""
Script to view frontal images saved in Redis
"""
import redis
import cv2
import numpy as np
import sys
from datetime import datetime

# Redis connection config (from pipeline.json)
REDIS_CONFIG = {
    "host": "10.100.1.3",
    "port": 6379,
    "password": "FBQgi0i5RevAAMO5Hh66",
    "db": 0
}

def connect_redis():
    """Connect to Redis server."""
    try:
        client = redis.Redis(
            host=REDIS_CONFIG["host"],
            port=REDIS_CONFIG["port"],
            password=REDIS_CONFIG["password"],
            db=REDIS_CONFIG["db"],
            decode_responses=False  # Keep bytes for images
        )
        client.ping()
        print(f"✅ Connected to Redis at {REDIS_CONFIG['host']}:{REDIS_CONFIG['port']}")
        return client
    except redis.exceptions.ConnectionError as e:
        print(f"❌ Failed to connect to Redis: {e}")
        return None

def list_image_keys(client):
    """List all image keys in Redis."""
    try:
        # Look for keys matching the inference pattern
        keys = client.keys("inference:*")
        print(f"\n📋 Found {len(keys)} image keys:")
        for i, key in enumerate(keys):
            key_str = key.decode() if isinstance(key, bytes) else key
            print(f"{i+1}. {key_str}")
        return keys
    except Exception as e:
        print(f"❌ Error listing keys: {e}")
        return []

def view_image(client, key):
    """View a specific image from Redis."""
    try:
        # Get image data from Redis
        image_data = client.get(key)
        if image_data is None:
            print(f"❌ No data found for key: {key}")
            return

        print(f"📸 Image size: {len(image_data)} bytes")

        # Convert bytes to numpy array
        nparr = np.frombuffer(image_data, np.uint8)

        # Decode image
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img is None:
            print("❌ Failed to decode image data")
            return

        print(f"🖼️  Image dimensions: {img.shape[1]}x{img.shape[0]} pixels")

        # Display image
        key_str = key.decode() if isinstance(key, bytes) else key
        cv2.imshow(f'Redis Image: {key_str}', img)
        print("👁️  Image displayed. Press any key to close...")
        cv2.waitKey(0)
        cv2.destroyAllWindows()

        # Ask if user wants to save the image
        save = input("💾 Save image to file? (y/n): ").lower().strip()
        if save == 'y':
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"redis_image_{timestamp}.jpg"
            cv2.imwrite(filename, img)
            print(f"💾 Image saved as: {filename}")

    except Exception as e:
        print(f"❌ Error viewing image: {e}")

def monitor_new_images(client):
    """Monitor for new images being added to Redis."""
    print("👀 Monitoring for new images... (Press Ctrl+C to stop)")
    try:
        # Subscribe to Redis pub/sub for car detections
        pubsub = client.pubsub()
        pubsub.subscribe('car_detections')

        for message in pubsub.listen():
            if message['type'] == 'message':
                data = message['data'].decode()
                print(f"🚨 New detection: {data}")

                # Try to extract image key from message
                import json
                try:
                    detection_data = json.loads(data)
                    image_key = detection_data.get('image_key')
                    if image_key:
                        print(f"🖼️  New image available: {image_key}")
                        view_choice = input("View this image now? (y/n): ").lower().strip()
                        if view_choice == 'y':
                            view_image(client, image_key)
                except json.JSONDecodeError:
                    pass

    except KeyboardInterrupt:
        print("\n👋 Stopping monitor...")
    except Exception as e:
        print(f"❌ Monitor error: {e}")

def main():
    """Main function."""
    print("🔍 Redis Image Viewer")
    print("=" * 50)

    # Connect to Redis
    client = connect_redis()
    if not client:
        return

    while True:
        print("\n📋 Options:")
        print("1. List all image keys")
        print("2. View specific image")
        print("3. Monitor for new images")
        print("4. Exit")

        choice = input("\nEnter choice (1-4): ").strip()

        if choice == '1':
            keys = list_image_keys(client)
        elif choice == '2':
            keys = list_image_keys(client)
            if keys:
                try:
                    idx = int(input(f"\nEnter image number (1-{len(keys)}): ")) - 1
                    if 0 <= idx < len(keys):
                        view_image(client, keys[idx])
                    else:
                        print("❌ Invalid selection")
                except ValueError:
                    print("❌ Please enter a valid number")
        elif choice == '3':
            monitor_new_images(client)
        elif choice == '4':
            print("👋 Goodbye!")
            break
        else:
            print("❌ Invalid choice")

if __name__ == "__main__":
    main()
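For context, a producer-side sketch of what this viewer assumes the detection worker does: store JPEG bytes under an inference:* key and publish that key on the car_detections channel as JSON with an image_key field. The key layout, expiry, and helper name below are assumptions, not the worker's actual code:

    import json
    import time
    import cv2
    import redis

    def publish_frontal_crop(client: redis.Redis, frame, session_id: str) -> None:
        """Sketch (assumed helper): store a JPEG crop in Redis and announce its key."""
        ok, buf = cv2.imencode(".jpg", frame)
        if not ok:
            return
        key = f"inference:{session_id}:{int(time.time())}"  # assumed key layout under inference:*
        client.set(key, buf.tobytes(), ex=600)               # 10-minute expiry is an assumption
        client.publish("car_detections", json.dumps({"image_key": key}))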