Compare commits
1 commit

dev...feat/new-l

Author | SHA1 | Date
--- | --- | ---
 | aa4e0463d4 |

19 changed files with 761 additions and 6185 deletions
CI workflow (112 lines, deleted)

@@ -1,112 +0,0 @@
name: Build Worker Base and Application Images

on:
  push:
    branches:
      - main
      - dev
  workflow_dispatch:
    inputs:
      force_base_build:
        description: 'Force base image build regardless of changes'
        required: false
        default: 'false'
        type: boolean

jobs:
  check-base-changes:
    runs-on: ubuntu-latest
    outputs:
      base-changed: ${{ steps.changes.outputs.base-changed }}
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
        with:
          fetch-depth: 2
      - name: Check for base changes
        id: changes
        run: |
          if git diff HEAD^ HEAD --name-only | grep -E "(Dockerfile\.base|requirements\.base\.txt)" > /dev/null; then
            echo "base-changed=true" >> $GITHUB_OUTPUT
          else
            echo "base-changed=false" >> $GITHUB_OUTPUT
          fi

  build-base:
    needs: check-base-changes
    if: needs.check-base-changes.outputs.base-changed == 'true' || (github.event_name == 'workflow_dispatch' && github.event.inputs.force_base_build == 'true')
    runs-on: ubuntu-latest
    permissions:
      packages: write
    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2

      - name: Login to GitHub Container Registry
        uses: docker/login-action@v3
        with:
          registry: git.siwatsystem.com
          username: ${{ github.actor }}
          password: ${{ secrets.RUNNER_TOKEN }}

      - name: Build and push base Docker image
        uses: docker/build-push-action@v4
        with:
          context: .
          file: ./Dockerfile.base
          push: true
          tags: git.siwatsystem.com/adsist-cms/worker-base:latest

  build-docker:
    needs: [check-base-changes, build-base]
    if: always() && (needs.build-base.result == 'success' || needs.build-base.result == 'skipped')
    runs-on: ubuntu-latest
    permissions:
      packages: write
    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2

      - name: Login to GitHub Container Registry
        uses: docker/login-action@v3
        with:
          registry: git.siwatsystem.com
          username: ${{ github.actor }}
          password: ${{ secrets.RUNNER_TOKEN }}

      - name: Build and push Docker image
        uses: docker/build-push-action@v4
        with:
          context: .
          file: ./Dockerfile
          push: true
          tags: git.siwatsystem.com/adsist-cms/worker:${{ github.ref_name == 'main' && 'latest' || 'dev' }}

  deploy-stack:
    needs: build-docker
    runs-on: adsist
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
      - name: Set up SSH connection
        run: |
          mkdir -p ~/.ssh
          echo "${{ secrets.DEPLOY_KEY_CMS }}" > ~/.ssh/id_rsa
          chmod 600 ~/.ssh/id_rsa
          ssh-keyscan -H ${{ vars.DEPLOY_HOST_CMS }} >> ~/.ssh/known_hosts
      - name: Deploy stack
        run: |
          echo "Pulling and starting containers on server..."
          if [ "${{ github.ref_name }}" = "main" ]; then
            echo "Deploying production stack..."
            ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.production.yml pull && docker compose -f docker-compose.production.yml up -d"
          else
            echo "Deploying staging stack..."
            ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.staging.yml pull && docker compose -f docker-compose.staging.yml up -d"
          fi
.gitignore (vendored, 6 lines changed)

@@ -7,9 +7,3 @@ __pycache__/
 .mptacache

 mptas
-detector_worker.log
-.gitignore
-no_frame_debug.log
-
-feeder/
-.venv/
CLAUDE.md (277 lines, deleted)

@@ -1,277 +0,0 @@
# Python Detector Worker - CLAUDE.md

## Project Overview
This is a FastAPI-based computer vision detection worker that processes video streams from RTSP/HTTP sources and runs advanced YOLO-based machine learning pipelines for multi-class object detection and parallel classification. The system features comprehensive database integration, Redis support, and hierarchical pipeline execution designed to work within a larger CMS (Content Management System) architecture.

### Key Features
- **Multi-Class Detection**: Simultaneous detection of multiple object classes (e.g., Car + Frontal)
- **Parallel Processing**: Concurrent execution of classification branches using ThreadPoolExecutor
- **Database Integration**: Automatic PostgreSQL schema management and record updates
- **Redis Actions**: Image storage with region cropping and pub/sub messaging
- **Pipeline Synchronization**: Branch coordination with `waitForBranches` functionality
- **Dynamic Field Mapping**: Template-based field resolution for database operations

## Architecture & Technology Stack
- **Framework**: FastAPI with WebSocket support
- **ML/CV**: PyTorch, Ultralytics YOLO, OpenCV
- **Containerization**: Docker (Python 3.13-bookworm base)
- **Data Storage**: Redis integration for action handling + PostgreSQL for persistent storage
- **Database**: Automatic schema management with gas_station_1 database
- **Parallel Processing**: ThreadPoolExecutor for concurrent classification
- **Communication**: WebSocket-based real-time protocol

## Core Components

### Main Application (`app.py`)
- **FastAPI WebSocket server** for real-time communication
- **Multi-camera stream management** with shared stream optimization
- **HTTP REST endpoint** for image retrieval (`/camera/{camera_id}/image`)
- **Threading-based frame readers** for RTSP streams and HTTP snapshots
- **Model loading and inference** using MPTA (Machine Learning Pipeline Archive) format
- **Session management** with display identifier mapping
- **Resource monitoring** (CPU, memory, GPU usage via psutil)

### Pipeline System (`siwatsystem/pympta.py`)
- **MPTA file handling** - ZIP archives containing model configurations
- **Hierarchical pipeline execution** with detection → classification branching
- **Multi-class detection** - Simultaneous detection of multiple classes (Car + Frontal)
- **Parallel processing** - Concurrent classification branches with ThreadPoolExecutor
- **Redis action system** - Image saving with region cropping and message publishing
- **PostgreSQL integration** - Automatic table creation and combined updates
- **Dynamic model loading** with GPU optimization
- **Configurable trigger classes and confidence thresholds**
- **Branch synchronization** - waitForBranches coordination for database updates

### Database System (`siwatsystem/database.py`)
- **DatabaseManager class** for PostgreSQL operations
- **Automatic table creation** with gas_station_1.car_frontal_info schema
- **Combined update operations** with field mapping from branch results
- **Session management** with UUID generation
- **Error handling** and connection management

### Testing & Debugging
- **Protocol test script** (`test_protocol.py`) for WebSocket communication validation
- **Pipeline webcam utility** (`pipeline_webcam.py`) for local testing with visual output
- **RTSP streaming debug tool** (`debug/rtsp_webcam.py`) using GStreamer

## Code Conventions & Patterns

### Logging
- **Structured logging** using Python's logging module
- **File + console output** to `detector_worker.log`
- **Debug level separation** for detailed troubleshooting
- **Context-aware messages** with camera IDs and model information

### Error Handling
- **Graceful failure handling** with retry mechanisms (configurable max_retries)
- **Thread-safe operations** using locks for streams and models
- **WebSocket disconnect handling** with proper cleanup
- **Model loading validation** with detailed error reporting

### Configuration
- **JSON configuration** (`config.json`) for runtime parameters:
  - `poll_interval_ms`: Frame processing interval
  - `max_streams`: Concurrent stream limit
  - `target_fps`: Target frame rate
  - `reconnect_interval_sec`: Stream reconnection delay
  - `max_retries`: Maximum retry attempts (-1 for unlimited)
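
A minimal sketch of reading these parameters, with the same defaults the single-file worker in this diff uses (treat the values as illustrative):

```python
# Load config.json and derive runtime parameters; defaults mirror the worker code.
import json

with open("config.json", "r") as f:
    config = json.load(f)

poll_interval_ms = 1000 / config.get("target_fps", 10)      # frame period derived from target FPS
max_streams = config.get("max_streams", 5)                   # concurrent stream limit
reconnect_interval_sec = config.get("reconnect_interval_sec", 5)
max_retries = config.get("max_retries", 3)                    # -1 means retry forever
```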
### Threading Model
- **Frame reader threads** for each camera stream (RTSP/HTTP)
- **Shared stream optimization** - multiple subscriptions can reuse the same camera stream
- **Async WebSocket handling** with concurrent task management
- **Thread-safe data structures** with proper locking mechanisms

## WebSocket Protocol

### Message Types
- **subscribe**: Start camera stream with model pipeline
- **unsubscribe**: Stop camera stream processing
- **requestState**: Request current worker status
- **setSessionId**: Associate display with session identifier
- **patchSession**: Update session data
- **stateReport**: Periodic heartbeat with system metrics
- **imageDetection**: Detection results with timestamp and model info

### Subscription Format
```json
{
  "type": "subscribe",
  "payload": {
    "subscriptionIdentifier": "display-001;cam-001",
    "rtspUrl": "rtsp://...", // OR snapshotUrl
    "snapshotUrl": "http://...",
    "snapshotInterval": 5000,
    "modelUrl": "http://...model.mpta",
    "modelId": 101,
    "modelName": "Vehicle Detection",
    "cropX1": 100, "cropY1": 200,
    "cropX2": 300, "cropY2": 400
  }
}
```
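
A minimal client-side sketch for issuing such a subscribe request, assuming the worker listens on `ws://localhost:8000/` and the `websockets` package is installed; the payload values here are illustrative:

```python
# Send a subscribe message to the detection worker and print the first reply.
import asyncio
import json
import websockets

async def main():
    async with websockets.connect("ws://localhost:8000/") as ws:
        await ws.send(json.dumps({
            "type": "subscribe",
            "payload": {
                "subscriptionIdentifier": "display-001;cam-001",
                "rtspUrl": "rtsp://example.invalid/stream",   # placeholder URL
                "modelUrl": "http://example.invalid/model.mpta",
                "modelId": 101,
                "modelName": "Vehicle Detection",
            },
        }))
        # The worker replies with stateReport heartbeats and imageDetection events.
        print(await ws.recv())

asyncio.run(main())
```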
## Model Pipeline (MPTA) Format

### Enhanced Structure
- **ZIP archive** containing models and configuration
- **pipeline.json** - Main configuration file with Redis + PostgreSQL settings
- **Model files** - YOLO .pt files for detection/classification
- **Multi-model support** - Detection + multiple classification models

### Advanced Pipeline Flow
1. **Multi-class detection stage** - YOLO detection of Car + Frontal simultaneously
2. **Validation stage** - Check for expected classes (flexible matching)
3. **Database initialization** - Create initial record with session_id
4. **Redis actions** - Save cropped frontal images with expiration
5. **Parallel classification** - Concurrent brand and body type classification
6. **Branch synchronization** - Wait for all classification branches to complete
7. **Database update** - Combined update with all classification results

### Enhanced Branch Configuration
```json
{
  "modelId": "car_frontal_detection_v1",
  "modelFile": "car_frontal_detection_v1.pt",
  "multiClass": true,
  "expectedClasses": ["Car", "Frontal"],
  "triggerClasses": ["Car", "Frontal"],
  "minConfidence": 0.8,
  "actions": [
    {
      "type": "redis_save_image",
      "region": "Frontal",
      "key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
      "expire_seconds": 600
    }
  ],
  "branches": [
    {
      "modelId": "car_brand_cls_v1",
      "modelFile": "car_brand_cls_v1.pt",
      "parallel": true,
      "crop": true,
      "cropClass": "Frontal",
      "triggerClasses": ["Frontal"],
      "minConfidence": 0.85
    }
  ],
  "parallelActions": [
    {
      "type": "postgresql_update_combined",
      "table": "car_frontal_info",
      "key_field": "session_id",
      "waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
      "fields": {
        "car_brand": "{car_brand_cls_v1.brand}",
        "car_body_type": "{car_bodytype_cls_v1.body_type}"
      }
    }
  ]
}
```

## Stream Management

### Shared Streams
- Multiple subscriptions can share the same camera URL
- Reference counting prevents premature stream termination
- Automatic cleanup when last subscription ends
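
A rough sketch of the reference-counting idea (class and field names here are illustrative, not the worker's actual data structures):

```python
# Reference-counted sharing of camera streams: the stream is opened once per URL
# and closed only when the last subscriber releases it.
import threading

class SharedStreamRegistry:
    def __init__(self):
        self._lock = threading.Lock()
        self._streams = {}  # camera_url -> {"stream": <handle>, "refs": int}

    def acquire(self, camera_url, open_stream):
        with self._lock:
            entry = self._streams.get(camera_url)
            if entry is None:
                entry = {"stream": open_stream(camera_url), "refs": 0}
                self._streams[camera_url] = entry
            entry["refs"] += 1
            return entry["stream"]

    def release(self, camera_url, close_stream):
        with self._lock:
            entry = self._streams.get(camera_url)
            if entry is None:
                return
            entry["refs"] -= 1
            if entry["refs"] <= 0:          # last subscriber gone: clean up
                close_stream(entry["stream"])
                del self._streams[camera_url]
```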
### Frame Processing
- **Queue-based buffering** with single frame capacity (latest frame only)
- **Configurable polling interval** based on target FPS
- **Automatic reconnection** with exponential backoff
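
The latest-frame-only buffer can be sketched with a `queue.Queue(maxsize=1)`, mirroring what the worker's frame readers do: the producer discards any stale frame before putting the new one.

```python
# Single-slot buffer: consumers always see the most recent frame, never a backlog.
import queue

buffer = queue.Queue(maxsize=1)

def push_latest(frame):
    if not buffer.empty():
        try:
            buffer.get_nowait()  # discard the stale frame
        except queue.Empty:
            pass
    buffer.put(frame)
```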
## Development & Testing

### Local Development
```bash
# Install dependencies
pip install -r requirements.txt

# Run the worker
python app.py

# Test protocol compliance
python test_protocol.py

# Test pipeline with webcam
python pipeline_webcam.py --mpta-file path/to/model.mpta --video 0
```

### Docker Deployment
```bash
# Build container
docker build -t detector-worker .

# Run with volume mounts for models
docker run -p 8000:8000 -v ./models:/app/models detector-worker
```

### Testing Commands
- **Protocol testing**: `python test_protocol.py`
- **Pipeline validation**: `python pipeline_webcam.py --mpta-file <path> --video 0`
- **RTSP debugging**: `python debug/rtsp_webcam.py`

## Dependencies
- **fastapi[standard]**: Web framework with WebSocket support
- **uvicorn**: ASGI server
- **torch, torchvision**: PyTorch for ML inference
- **ultralytics**: YOLO implementation
- **opencv-python**: Computer vision operations
- **websockets**: WebSocket client/server
- **redis**: Redis client for action execution
- **psycopg2-binary**: PostgreSQL database adapter
- **scipy**: Scientific computing for advanced algorithms
- **filterpy**: Kalman filtering and state estimation

## Security Considerations
- Model files are loaded from trusted sources only
- Redis connections use authentication when configured
- WebSocket connections handle disconnects gracefully
- Resource usage is monitored to prevent DoS

## Database Integration

### Schema Management
The system automatically creates and manages PostgreSQL tables:

```sql
CREATE TABLE IF NOT EXISTS gas_station_1.car_frontal_info (
    display_id VARCHAR(255),
    captured_timestamp VARCHAR(255),
    session_id VARCHAR(255) PRIMARY KEY,
    license_character VARCHAR(255) DEFAULT NULL,
    license_type VARCHAR(255) DEFAULT 'No model available',
    car_brand VARCHAR(255) DEFAULT NULL,
    car_model VARCHAR(255) DEFAULT NULL,
    car_body_type VARCHAR(255) DEFAULT NULL,
    created_at TIMESTAMP DEFAULT NOW(),
    updated_at TIMESTAMP DEFAULT NOW()
);
```

### Workflow
1. **Detection**: When both "Car" and "Frontal" are detected, create initial database record with UUID session_id
2. **Redis Storage**: Save cropped frontal image to Redis with session_id in key
3. **Parallel Processing**: Run brand and body type classification concurrently
4. **Synchronization**: Wait for all branches to complete using `waitForBranches`
5. **Database Update**: Update record with combined classification results using field mapping

### Field Mapping
Templates like `{car_brand_cls_v1.brand}` are resolved to actual classification results:
- `car_brand_cls_v1.brand` → "Honda"
- `car_bodytype_cls_v1.body_type` → "Sedan"
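
A rough sketch of how such templates can be resolved against per-branch results (the helper below is illustrative, not the worker's actual API):

```python
# Resolve "{modelId.field}" placeholders in a field-mapping object using the
# collected classification results of each branch.
import re

def resolve_field_templates(fields, branch_results):
    resolved = {}
    for column, template in fields.items():
        def lookup(match):
            model_id, field = match.group(1), match.group(2)
            return str(branch_results.get(model_id, {}).get(field, ""))
        resolved[column] = re.sub(r"\{([\w-]+)\.(\w+)\}", lookup, template)
    return resolved

branch_results = {
    "car_brand_cls_v1": {"brand": "Honda"},
    "car_bodytype_cls_v1": {"body_type": "Sedan"},
}
fields = {
    "car_brand": "{car_brand_cls_v1.brand}",
    "car_body_type": "{car_bodytype_cls_v1.body_type}",
}
print(resolve_field_templates(fields, branch_results))
# {'car_brand': 'Honda', 'car_body_type': 'Sedan'}
```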
## Performance Optimizations
- GPU acceleration when CUDA is available
- Shared camera streams reduce resource usage
- Frame queue optimization (single latest frame)
- Model caching across subscriptions
- Trigger class filtering for faster inference
- Parallel processing with ThreadPoolExecutor for classification branches
- Multi-class detection reduces inference passes
- Region-based cropping minimizes processing overhead
- Database connection pooling and prepared statements
- Redis image storage with automatic expiration
Dockerfile (16 lines changed)

@@ -1,11 +1,19 @@

Old (dev):
# Use our pre-built base image with ML dependencies
FROM git.siwatsystem.com/adsist-cms/worker-base:latest

# Copy and install application requirements (frequently changing dependencies)
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY . .

# Run the application

New (feat/new-l):
# Use the official Python image from the Docker Hub
FROM python:3.13-bookworm

# Set the working directory in the container
WORKDIR /app

# Copy the requirements file into the container at /app
COPY requirements.txt .

# Update apt, install libgl1, and clear apt cache
RUN apt update && apt install -y libgl1 && rm -rf /var/lib/apt/lists/*

# Install any dependencies specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application code into the container at /app
COPY . .

# Run the application
Base image Dockerfile (15 lines, deleted)

@@ -1,15 +0,0 @@
# Base image with all ML dependencies
FROM python:3.13-bookworm

# Install system dependencies
RUN apt update && apt install -y libgl1 && rm -rf /var/lib/apt/lists/*

# Copy and install base requirements (ML dependencies that rarely change)
COPY requirements.base.txt .
RUN pip install --no-cache-dir -r requirements.base.txt

# Set working directory
WORKDIR /app

# This base image will be reused for all worker builds
CMD ["python3", "-m", "fastapi", "run", "--host", "0.0.0.0", "--port", "8000"]
app_single.py (new file, 366 lines)

@@ -0,0 +1,366 @@
from typing import List
from fastapi import FastAPI, WebSocket
from fastapi.websockets import WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError
from ultralytics import YOLO
import torch
import cv2
import base64
import numpy as np
import json
import logging
import threading
import queue
import os
import requests
from urllib.parse import urlparse
import asyncio
import psutil

app = FastAPI()

models = {}

with open("config.json", "r") as f:
    config = json.load(f)

poll_interval = config.get("poll_interval_ms", 100)
reconnect_interval = config.get("reconnect_interval_sec", 5)
TARGET_FPS = config.get("target_fps", 10)
poll_interval = 1000 / TARGET_FPS
logging.info(f"Poll interval: {poll_interval}ms")
max_streams = config.get("max_streams", 5)
max_retries = config.get("max_retries", 3)

# Configure logging
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler()
    ]
)

# Ensure the models directory exists
os.makedirs("models", exist_ok=True)

# Add constants for heartbeat
HEARTBEAT_INTERVAL = 2  # seconds
WORKER_TIMEOUT_MS = 10000

# Add a lock for thread-safe operations on shared resources
streams_lock = threading.Lock()
models_lock = threading.Lock()

@app.websocket("/")
async def detect(websocket: WebSocket):
    import asyncio
    import time

    logging.info("WebSocket connection accepted")

    streams = {}

    # This function is user-modifiable
    # Save data you want to persist across frames in the persistent_data dictionary
    async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data):
        try:
            highest_conf_box = None
            max_conf = -1

            for r in model.track(frame, stream=False, persist=True):
                for box in r.boxes:
                    box_cpu = box.cpu()
                    conf = float(box_cpu.conf[0])
                    if conf > max_conf and hasattr(box, "id") and box.id is not None:
                        max_conf = conf
                        highest_conf_box = {
                            "class": model.names[int(box_cpu.cls[0])],
                            "confidence": conf,
                            "id": box.id.item(),
                        }

            # Broadcast to all subscribers of this URL
            detection_data = {
                "type": "imageDetection",
                "cameraIdentifier": camera_id,
                "timestamp": time.time(),
                "data": {
                    "detections": highest_conf_box if highest_conf_box else None,
                    "modelId": stream['modelId'],
                    "modelName": stream['modelName']
                }
            }
            logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}")
            await websocket.send_json(detection_data)
            return persistent_data
        except Exception as e:
            logging.error(f"Error in handle_detection for camera {camera_id}: {e}")
            return persistent_data

    def frame_reader(camera_id, cap, buffer, stop_event):
        import time
        retries = 0
        try:
            while not stop_event.is_set():
                try:
                    ret, frame = cap.read()
                    if not ret:
                        logging.warning(f"Connection lost for camera: {camera_id}, retry {retries+1}/{max_retries}")
                        cap.release()
                        time.sleep(reconnect_interval)
                        retries += 1
                        if retries > max_retries and max_retries != -1:
                            logging.error(f"Max retries reached for camera: {camera_id}")
                            break
                        # Re-open the VideoCapture
                        cap = cv2.VideoCapture(streams[camera_id]['rtsp_url'])
                        if not cap.isOpened():
                            logging.error(f"Failed to reopen RTSP stream for camera: {camera_id}")
                            continue
                        continue
                    retries = 0  # Reset on success
                    if not buffer.empty():
                        try:
                            buffer.get_nowait()  # Discard the old frame
                        except queue.Empty:
                            pass
                    buffer.put(frame)
                except cv2.error as e:
                    logging.error(f"OpenCV error for camera {camera_id}: {e}")
                    cap.release()
                    time.sleep(reconnect_interval)
                    retries += 1
                    if retries > max_retries and max_retries != -1:
                        logging.error(f"Max retries reached after OpenCV error for camera: {camera_id}")
                        break
                    # Re-open the VideoCapture
                    cap = cv2.VideoCapture(streams[camera_id]['rtsp_url'])
                    if not cap.isOpened():
                        logging.error(f"Failed to reopen RTSP stream for camera {camera_id} after OpenCV error")
                        continue
                except Exception as e:
                    logging.error(f"Unexpected error for camera {camera_id}: {e}")
                    cap.release()
                    break
        except Exception as e:
            logging.error(f"Error in frame_reader thread for camera {camera_id}: {e}")

    async def process_streams():
        global models
        logging.info("Started processing streams")
        persistent_data_dict = {}
        try:
            while True:
                start_time = time.time()
                # Round-robin processing
                with streams_lock:
                    current_streams = list(streams.items())
                for camera_id, stream in current_streams:
                    buffer = stream['buffer']
                    if not buffer.empty():
                        frame = buffer.get()
                        with models_lock:
                            model = models.get(camera_id, {}).get(stream['modelId'])
                        key = (camera_id, stream['modelId'])
                        persistent_data = persistent_data_dict.get(key, {})
                        updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data)
                        persistent_data_dict[key] = updated_persistent_data
                elapsed_time = (time.time() - start_time) * 1000  # in ms
                sleep_time = max(poll_interval - elapsed_time, 0)
                logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms")
                await asyncio.sleep(sleep_time / 1000.0)
        except asyncio.CancelledError:
            logging.info("Stream processing task cancelled")
        except Exception as e:
            logging.error(f"Error in process_streams: {e}")

    async def send_heartbeat():
        while True:
            try:
                cpu_usage = psutil.cpu_percent()
                memory_usage = psutil.virtual_memory().percent
                if torch.cuda.is_available():
                    gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2)  # Convert to MB
                    gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)  # Convert to MB
                else:
                    gpu_usage = None
                    gpu_memory_usage = None

                camera_connections = [
                    {
                        "cameraIdentifier": camera_id,
                        "modelId": stream['modelId'],
                        "modelName": stream['modelName'],
                        "online": True
                    }
                    for camera_id, stream in streams.items()
                ]

                state_report = {
                    "type": "stateReport",
                    "cpuUsage": cpu_usage,
                    "memoryUsage": memory_usage,
                    "gpuUsage": gpu_usage,
                    "gpuMemoryUsage": gpu_memory_usage,
                    "cameraConnections": camera_connections
                }
                await websocket.send_text(json.dumps(state_report))
                logging.debug("Sent stateReport as heartbeat")
                await asyncio.sleep(HEARTBEAT_INTERVAL)
            except Exception as e:
                logging.error(f"Error sending stateReport heartbeat: {e}")
                break

    async def on_message():
        global models
        while True:
            try:
                msg = await websocket.receive_text()
                logging.debug(f"Received message: {msg}")
                print(f"Received message: {msg}")
                data = json.loads(msg)
                msg_type = data.get("type")

                if msg_type == "subscribe":
                    payload = data.get("payload", {})
                    camera_id = payload.get("cameraIdentifier")
                    rtsp_url = payload.get("rtspUrl")
                    model_url = payload.get("modelUrl")
                    modelId = payload.get("modelId")
                    modelName = payload.get("modelName")

                    if model_url:
                        with models_lock:
                            if camera_id not in models:
                                models[camera_id] = {}
                            if modelId not in models[camera_id]:
                                print(f"Downloading model from {model_url}")
                                parsed_url = urlparse(model_url)
                                filename = os.path.basename(parsed_url.path)
                                model_filename = os.path.join("models", filename)
                                # Download the model
                                response = requests.get(model_url, stream=True)
                                if response.status_code == 200:
                                    with open(model_filename, 'wb') as f:
                                        for chunk in response.iter_content(chunk_size=8192):
                                            f.write(chunk)
                                    logging.info(f"Downloaded model from {model_url} to {model_filename}")
                                    model = YOLO(model_filename)
                                    if torch.cuda.is_available():
                                        model.to('cuda')
                                    models[camera_id][modelId] = model
                                    logging.info(f"Loaded model {modelId} for camera {camera_id}")
                                else:
                                    logging.error(f"Failed to download model from {model_url}")
                                    continue
                    if camera_id and rtsp_url:
                        with streams_lock:
                            if camera_id not in streams and len(streams) < max_streams:
                                cap = cv2.VideoCapture(rtsp_url)
                                if not cap.isOpened():
                                    logging.error(f"Failed to open RTSP stream for camera {camera_id}")
                                    continue
                                buffer = queue.Queue(maxsize=1)
                                stop_event = threading.Event()
                                thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
                                thread.daemon = True
                                thread.start()
                                streams[camera_id] = {
                                    'cap': cap,
                                    'buffer': buffer,
                                    'thread': thread,
                                    'rtsp_url': rtsp_url,
                                    'stop_event': stop_event,
                                    'modelId': modelId,
                                    'modelName': modelName
                                }
                                logging.info(f"Subscribed to camera {camera_id} with modelId {modelId}, modelName {modelName} and URL {rtsp_url}")
                            elif camera_id and camera_id in streams:
                                stream = streams.pop(camera_id)
                                stream['cap'].release()
                                logging.info(f"Unsubscribed from camera {camera_id}")
                                if camera_id in models and modelId in models[camera_id]:
                                    del models[camera_id][modelId]
                                    if not models[camera_id]:
                                        del models[camera_id]
                elif msg_type == "unsubscribe":
                    payload = data.get("payload", {})
                    camera_id = payload.get("cameraIdentifier")
                    logging.debug(f"Unsubscribing from camera {camera_id}")
                    with streams_lock:
                        if camera_id and camera_id in streams:
                            stream = streams.pop(camera_id)
                            stream['stop_event'].set()
                            stream['thread'].join()
                            stream['cap'].release()
                            logging.info(f"Unsubscribed from camera {camera_id}")
                            if camera_id in models and modelId in models[camera_id]:
                                del models[camera_id][modelId]
                                if not models[camera_id]:
                                    del models[camera_id]
                elif msg_type == "requestState":
                    # Handle state request
                    cpu_usage = psutil.cpu_percent()
                    memory_usage = psutil.virtual_memory().percent
                    if torch.cuda.is_available():
                        gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2)  # Convert to MB
                        gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)  # Convert to MB
                    else:
                        gpu_usage = None
                        gpu_memory_usage = None

                    camera_connections = [
                        {
                            "cameraIdentifier": camera_id,
                            "modelId": stream['modelId'],
                            "modelName": stream['modelName'],
                            "online": True
                        }
                        for camera_id, stream in streams.items()
                    ]

                    state_report = {
                        "type": "stateReport",
                        "cpuUsage": cpu_usage,
                        "memoryUsage": memory_usage,
                        "gpuUsage": gpu_usage,
                        "gpuMemoryUsage": gpu_memory_usage,
                        "cameraConnections": camera_connections
                    }
                    await websocket.send_text(json.dumps(state_report))
                else:
                    logging.error(f"Unknown message type: {msg_type}")
            except json.JSONDecodeError:
                logging.error("Received invalid JSON message")
            except (WebSocketDisconnect, ConnectionClosedError) as e:
                logging.warning(f"WebSocket disconnected: {e}")
                break
            except Exception as e:
                logging.error(f"Error handling message: {e}")
                break

    try:
        await websocket.accept()
        task = asyncio.create_task(process_streams())
        heartbeat_task = asyncio.create_task(send_heartbeat())
        message_task = asyncio.create_task(on_message())

        await asyncio.gather(heartbeat_task, message_task)
    except Exception as e:
        logging.error(f"Error in detect websocket: {e}")
    finally:
        task.cancel()
        await task
        with streams_lock:
            for camera_id, stream in streams.items():
                stream['stop_event'].set()
                stream['thread'].join()
                stream['cap'].release()
                stream['buffer'].queue.clear()
                logging.info(f"Released camera {camera_id} and cleaned up resources")
            streams.clear()
        with models_lock:
            models.clear()
        logging.info("WebSocket connection closed")
debug.py (new file, 143 lines)

@@ -0,0 +1,143 @@
import argparse
import os
import cv2
import time
import logging
import shutil
import threading  # added threading
import yaml  # for silencing YOLO

from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

# Silence YOLO logging
os.environ["YOLO_VERBOSE"] = "False"
for logger_name in ["ultralytics", "ultralytics.hub", "ultralytics.yolo.utils"]:
    logging.getLogger(logger_name).setLevel(logging.WARNING)

# Global variables for frame sharing
global_frame = None
global_ret = False
capture_running = False

def video_capture_loop(cap):
    global global_frame, global_ret, capture_running
    while capture_running:
        global_ret, global_frame = cap.read()
        time.sleep(0.01)  # slight delay to reduce CPU usage

def clear_cache(cache_dir: str):
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)

def log_pipeline_flow(frame, model_tree, level=0):
    """
    Wrapper around run_pipeline that logs the model flow and detection results.
    Returns the same output as the original run_pipeline function.
    """
    indent = "  " * level
    model_id = model_tree.get("modelId", "unknown")
    logging.info(f"{indent}→ Running model: {model_id}")

    detection, bbox = run_pipeline(frame, model_tree, return_bbox=True)

    if detection:
        confidence = detection.get("confidence", 0) * 100
        class_name = detection.get("class", "unknown")
        object_id = detection.get("id", "N/A")

        logging.info(f"{indent}✓ Detected: {class_name} (ID: {object_id}, confidence: {confidence:.1f}%)")

        # Check if any branches were triggered
        triggered = False
        for branch in model_tree.get("branches", []):
            trigger_classes = branch.get("triggerClasses", [])
            min_conf = branch.get("minConfidence", 0)

            if class_name in trigger_classes and detection.get("confidence", 0) >= min_conf:
                triggered = True
                if branch.get("crop", False) and bbox:
                    x1, y1, x2, y2 = bbox
                    cropped_frame = frame[y1:y2, x1:x2]
                    logging.info(f"{indent}  ⌊ Triggering branch with cropped region {x1},{y1} to {x2},{y2}")
                    branch_result = log_pipeline_flow(cropped_frame, branch, level + 1)
                else:
                    logging.info(f"{indent}  ⌊ Triggering branch with full frame")
                    branch_result = log_pipeline_flow(frame, branch, level + 1)

                if branch_result[0]:  # If branch detection successful, return it
                    return branch_result

        if not triggered and model_tree.get("branches"):
            logging.info(f"{indent}  ⌊ No branches triggered")
    else:
        logging.info(f"{indent}✗ No detection for {model_id}")

    return detection, bbox

def main(mpta_file: str, video_source: str):
    global capture_running
    CACHE_DIR = os.path.join(".", ".mptacache")
    clear_cache(CACHE_DIR)
    logging.info(f"Loading pipeline from local file: {mpta_file}")
    model_tree = load_pipeline_from_zip(mpta_file, CACHE_DIR)
    if model_tree is None:
        logging.error("Failed to load pipeline.")
        return

    cap = cv2.VideoCapture(video_source)
    if not cap.isOpened():
        logging.error(f"Cannot open video source {video_source}")
        return

    # Start video capture in a separate thread
    capture_running = True
    capture_thread = threading.Thread(target=video_capture_loop, args=(cap,))
    capture_thread.start()

    logging.info("Press 'q' to exit.")
    try:
        while True:
            # Use the global frame and ret updated by the thread
            if not global_ret or global_frame is None:
                continue  # wait until a frame is available

            frame = global_frame.copy()  # local copy to work with

            # Replace run_pipeline with our logging version
            detection, bbox = log_pipeline_flow(frame, model_tree)

            # Stop if "toyota" is detected
            if detection and detection.get("class", "").lower() == "toyota":
                logging.info("Detected 'toyota'. Stopping pipeline.")
                break

            if bbox:
                x1, y1, x2, y2 = bbox
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                label = detection["class"] if detection else "Detection"
                cv2.putText(frame, label, (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)

            cv2.imshow("Pipeline Webcam", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    finally:
        # Stop capture thread and cleanup
        capture_running = False
        capture_thread.join()
        cap.release()
        cv2.destroyAllWindows()
        clear_cache(CACHE_DIR)
        logging.info("Cleaned up .mptacache directory on shutdown.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run pipeline webcam utility.")
    parser.add_argument("--mpta-file", type=str, required=True, help="Path to the local pipeline mpta (ZIP) file.")
    parser.add_argument("--video", type=str, default="0", help="Video source (default webcam index 0).")
    args = parser.parse_args()
    video_source = int(args.video) if args.video.isdigit() else args.video
    main(args.mpta_file, video_source)
demoa.mpta (new binary file)
Binary file not shown.

File diff suppressed because it is too large.
pipeline.log (new file, 23 lines)

@@ -0,0 +1,23 @@
2025-05-12 18:10:04,590 [INFO] Loading pipeline from local file: demoa.mpta
2025-05-12 18:10:04,610 [INFO] Copied local .mpta file from demoa.mpta to .\.mptacache\pipeline.mpta
2025-05-12 18:10:04,901 [INFO] Extracted .mpta file to .\.mptacache
2025-05-12 18:10:04,905 [INFO] Loading model for node DetectionDraft from .\.mptacache\demoa\DetectionDraft.pt
2025-05-12 18:10:05,083 [INFO] Loading model for node ClassificationDraft from .\.mptacache\demoa\ClassificationDraft.pt
2025-05-12 18:10:08,035 [INFO] Press 'q' to exit.
2025-05-12 18:10:12,217 [INFO] Cleaned up .mptacache directory on shutdown.
2025-05-12 18:13:08,465 [INFO] Loading pipeline from local file: demoa.mpta
2025-05-12 18:13:08,512 [INFO] Copied local .mpta file from demoa.mpta to .\.mptacache\pipeline.mpta
2025-05-12 18:13:08,769 [INFO] Extracted .mpta file to .\.mptacache
2025-05-12 18:13:08,773 [INFO] Loading model for node DetectionDraft from .\.mptacache\demoa\DetectionDraft.pt
2025-05-12 18:13:09,083 [INFO] Loading model for node ClassificationDraft from .\.mptacache\demoa\ClassificationDraft.pt
2025-05-12 18:13:12,187 [INFO] Press 'q' to exit.
2025-05-12 18:13:14,146 [INFO] → Running model: DetectionDraft
2025-05-12 18:13:17,119 [INFO] Cleaned up .mptacache directory on shutdown.
2025-05-12 18:14:25,665 [INFO] Loading pipeline from local file: demoa.mpta
2025-05-12 18:14:25,687 [INFO] Copied local .mpta file from demoa.mpta to .\.mptacache\pipeline.mpta
2025-05-12 18:14:25,953 [INFO] Extracted .mpta file to .\.mptacache
2025-05-12 18:14:25,957 [INFO] Loading model for node DetectionDraft from .\.mptacache\demoa\DetectionDraft.pt
2025-05-12 18:14:26,138 [INFO] Loading model for node ClassificationDraft from .\.mptacache\demoa\ClassificationDraft.pt
2025-05-12 18:14:29,171 [INFO] Press 'q' to exit.
2025-05-12 18:14:30,146 [INFO] → Running model: DetectionDraft
2025-05-12 18:14:32,080 [INFO] Cleaned up .mptacache directory on shutdown.
pympta.md (327 lines, deleted)

@@ -1,327 +0,0 @@
# pympta: Modular Pipeline Task Executor

`pympta` is a Python module designed to load and execute modular, multi-stage AI pipelines defined in a special package format (`.mpta`). It is primarily used within the detector worker to run complex computer vision tasks where the output of one model can trigger a subsequent model on a specific region of interest.

## Core Concepts

### 1. MPTA Package (`.mpta`)

An `.mpta` file is a standard `.zip` archive with a different extension. It bundles all the necessary components for a pipeline to run.

A typical `.mpta` file has the following structure:

```
my_pipeline.mpta/
├── pipeline.json
├── model1.pt
├── model2.pt
└── ...
```

- **`pipeline.json`**: (Required) The manifest file that defines the structure of the pipeline, the models to use, and the logic connecting them.
- **Model Files (`.pt`, etc.)**: The actual pre-trained model files (e.g., PyTorch, ONNX). The pipeline currently uses `ultralytics.YOLO` models.

### 2. Pipeline Structure

A pipeline is a tree-like structure of "nodes," defined in `pipeline.json`.

- **Root Node**: The entry point of the pipeline. It processes the initial, full-frame image.
- **Branch Nodes**: Child nodes that are triggered by specific detection results from their parent. For example, a root node might detect a "vehicle," which then triggers a branch node to detect a "license plate" within the vehicle's bounding box.

This modular structure allows for creating complex and efficient inference logic, avoiding the need to run every model on every frame.

## `pipeline.json` Specification

This file defines the entire pipeline logic. The root object contains a `pipeline` key for the pipeline definition, an optional `redis` key for Redis configuration, and an optional `postgresql` key for database integration.

### Top-Level Object Structure

| Key          | Type   | Required | Description                                            |
| ------------ | ------ | -------- | ------------------------------------------------------ |
| `pipeline`   | Object | Yes      | The root node object of the pipeline.                  |
| `redis`      | Object | No       | Configuration for connecting to a Redis server.        |
| `postgresql` | Object | No       | Configuration for connecting to a PostgreSQL database. |

### Redis Configuration (`redis`)

| Key        | Type   | Required | Description                                        |
| ---------- | ------ | -------- | -------------------------------------------------- |
| `host`     | String | Yes      | The hostname or IP address of the Redis server.    |
| `port`     | Number | Yes      | The port number of the Redis server.               |
| `password` | String | No       | The password for Redis authentication.             |
| `db`       | Number | No       | The Redis database number to use. Defaults to `0`. |

### PostgreSQL Configuration (`postgresql`)

| Key        | Type   | Required | Description                                          |
| ---------- | ------ | -------- | ---------------------------------------------------- |
| `host`     | String | Yes      | The hostname or IP address of the PostgreSQL server. |
| `port`     | Number | Yes      | The port number of the PostgreSQL server.            |
| `database` | String | Yes      | The database name to connect to.                     |
| `username` | String | Yes      | The username for database authentication.            |
| `password` | String | Yes      | The password for database authentication.            |

### Node Object Structure

| Key               | Type          | Required | Description |
| ----------------- | ------------- | -------- | ----------- |
| `modelId`         | String        | Yes      | A unique identifier for this model node (e.g., "vehicle-detector"). |
| `modelFile`       | String        | Yes      | The path to the model file within the `.mpta` archive (e.g., "yolov8n.pt"). |
| `minConfidence`   | Float         | Yes      | The minimum confidence score (0.0 to 1.0) required for a detection to be considered valid and potentially trigger a branch. |
| `triggerClasses`  | Array<String> | Yes      | A list of class names that, when detected by the parent, can trigger this node. For the root node, this lists all classes of interest. |
| `crop`            | Boolean       | No       | If `true`, the image is cropped to the parent's detection bounding box before being passed to this node's model. Defaults to `false`. |
| `cropClass`       | String        | No       | The specific class to use for cropping (e.g., "Frontal" for frontal view cropping). |
| `multiClass`      | Boolean       | No       | If `true`, enables multi-class detection mode where multiple classes can be detected simultaneously. |
| `expectedClasses` | Array<String> | No       | When `multiClass` is true, defines which classes are expected. At least one must be detected for processing to continue. |
| `parallel`        | Boolean       | No       | If `true`, this branch will be processed in parallel with other parallel branches. |
| `branches`        | Array<Node>   | No       | A list of child node objects that can be triggered by this node's detections. |
| `actions`         | Array<Action> | No       | A list of actions to execute upon a successful detection in this node. |
| `parallelActions` | Array<Action> | No       | A list of actions to execute after all specified branches have completed. |

### Action Object Structure

Actions allow the pipeline to interact with Redis and PostgreSQL databases. They are executed sequentially for a given detection.

#### Action Context & Dynamic Keys

All actions have access to a dynamic context for formatting keys and messages. The context is created for each detection event and includes:

- All key-value pairs from the detection result (e.g., `class`, `confidence`, `id`).
- `{timestamp_ms}`: The current Unix timestamp in milliseconds.
- `{timestamp}`: Formatted timestamp string (YYYY-MM-DDTHH-MM-SS).
- `{uuid}`: A unique identifier (UUID4) for the detection event.
- `{filename}`: Generated filename with UUID.
- `{camera_id}`: Full camera subscription identifier.
- `{display_id}`: Display identifier extracted from subscription.
- `{session_id}`: Session ID for database operations.
- `{image_key}`: If a `redis_save_image` action has already been executed for this event, this placeholder will be replaced with the key where the image was stored.
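
A rough sketch of building this context and expanding a key template (helper and variable names here are illustrative, not pympta's actual API):

```python
# Build the per-event context and format a Redis key template from it.
import time
import uuid

def build_action_context(detection, camera_id, display_id, session_id):
    event_uuid = str(uuid.uuid4())
    context = dict(detection)  # class, confidence, id, ...
    context.update({
        "timestamp_ms": int(time.time() * 1000),
        "timestamp": time.strftime("%Y-%m-%dT%H-%M-%S"),
        "uuid": event_uuid,
        "filename": f"{event_uuid}.jpg",
        "camera_id": camera_id,
        "display_id": display_id,
        "session_id": session_id,
    })
    return context

context = build_action_context(
    {"class": "Frontal", "confidence": 0.91, "id": 7},
    camera_id="display-001;cam-001", display_id="display-001", session_id="abc123",
)
key = "inference:{display_id}:{timestamp}:{session_id}:{filename}".format(**context)
print(key)  # e.g. inference:display-001:2025-05-12T18-10-04:abc123:<uuid>.jpg
```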
|
|
||||||
|
|
||||||
#### `redis_save_image`
|
|
||||||
|
|
||||||
Saves the current image frame (or cropped sub-image) to a Redis key.
|
|
||||||
|
|
||||||
| Key | Type | Required | Description |
|
|
||||||
| ---------------- | ------ | -------- | ------------------------------------------------------------------------------------------------------- |
|
|
||||||
| `type` | String | Yes | Must be `"redis_save_image"`. |
|
|
||||||
| `key` | String | Yes | The Redis key to save the image to. Can contain any of the dynamic placeholders. |
|
|
||||||
| `region` | String | No | Specific detected region to crop and save (e.g., "Frontal"). |
|
|
||||||
| `format` | String | No | Image format: "jpeg" or "png". Defaults to "jpeg". |
|
|
||||||
| `quality` | Number | No | JPEG quality (1-100). Defaults to 90. |
|
|
||||||
| `expire_seconds` | Number | No | If provided, sets an expiration time (in seconds) for the Redis key. |
|
|
||||||
|
|
||||||
#### `redis_publish`
|
|
||||||
|
|
||||||
Publishes a message to a Redis channel.
|
|
||||||
|
|
||||||
| Key | Type | Required | Description |
|
|
||||||
| --------- | ------ | -------- | ------------------------------------------------------------------------------------------------------- |
|
|
||||||
| `type` | String | Yes | Must be `"redis_publish"`. |
|
|
||||||
| `channel` | String | Yes | The Redis channel to publish the message to. |
|
|
||||||
| `message` | String | Yes | The message to publish. Can contain any of the dynamic placeholders, including `{image_key}`. |
|
|
||||||
|
|
||||||
#### `postgresql_update_combined`
|
|
||||||
|
|
||||||
Updates PostgreSQL database with results from multiple branches after they complete.
|
|
||||||
|
|
||||||
| Key | Type | Required | Description |
|
|
||||||
| ------------------ | ------------- | -------- | ------------------------------------------------------------------------------------------------------- |
|
|
||||||
| `type` | String | Yes | Must be `"postgresql_update_combined"`. |
|
|
||||||
| `table` | String | Yes | The database table name (will be prefixed with `gas_station_1.` schema). |
|
|
||||||
| `key_field` | String | Yes | The field to use as the update key (typically "session_id"). |
|
|
||||||
| `key_value` | String | Yes | Template for the key value (e.g., "{session_id}"). |
|
|
||||||
| `waitForBranches` | Array<String> | Yes | List of branch model IDs to wait for completion before executing update. |
|
|
||||||
| `fields` | Object | Yes | Field mapping object where keys are database columns and values are templates (e.g., "{branch.field}").|
|
|
||||||
|
|
||||||
### Complete Example `pipeline.json`
|
|
||||||
|
|
||||||
This example demonstrates a comprehensive pipeline for vehicle detection with parallel classification and database integration:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"redis": {
|
|
||||||
"host": "10.100.1.3",
|
|
||||||
"port": 6379,
|
|
||||||
"password": "your-redis-password",
|
|
||||||
"db": 0
|
|
||||||
},
|
|
||||||
"postgresql": {
|
|
||||||
"host": "10.100.1.3",
|
|
||||||
"port": 5432,
|
|
||||||
"database": "inference",
|
|
||||||
"username": "root",
|
|
||||||
"password": "your-db-password"
|
|
||||||
},
|
|
||||||
"pipeline": {
|
|
||||||
"modelId": "car_frontal_detection_v1",
|
|
||||||
"modelFile": "car_frontal_detection_v1.pt",
|
|
||||||
"crop": false,
|
|
||||||
"triggerClasses": ["Car", "Frontal"],
|
|
||||||
"minConfidence": 0.8,
|
|
||||||
"multiClass": true,
|
|
||||||
"expectedClasses": ["Car", "Frontal"],
|
|
||||||
"actions": [
|
|
||||||
{
|
|
||||||
"type": "redis_save_image",
|
|
||||||
"region": "Frontal",
|
|
||||||
"key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
|
|
||||||
"expire_seconds": 600,
|
|
||||||
"format": "jpeg",
|
|
||||||
"quality": 90
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "redis_publish",
|
|
||||||
"channel": "car_detections",
|
|
||||||
"message": "{\"event\":\"frontal_detected\"}"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"branches": [
|
|
||||||
{
|
|
||||||
"modelId": "car_brand_cls_v1",
|
|
||||||
"modelFile": "car_brand_cls_v1.pt",
|
|
||||||
"crop": true,
|
|
||||||
"cropClass": "Frontal",
|
|
||||||
"resizeTarget": [224, 224],
|
|
||||||
"triggerClasses": ["Frontal"],
|
|
||||||
"minConfidence": 0.85,
|
|
||||||
"parallel": true,
|
|
||||||
"branches": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"modelId": "car_bodytype_cls_v1",
|
|
||||||
"modelFile": "car_bodytype_cls_v1.pt",
|
|
||||||
"crop": true,
|
|
||||||
"cropClass": "Car",
|
|
||||||
"resizeTarget": [224, 224],
|
|
||||||
"triggerClasses": ["Car"],
|
|
||||||
"minConfidence": 0.85,
|
|
||||||
"parallel": true,
|
|
||||||
"branches": []
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"parallelActions": [
|
|
||||||
{
|
|
||||||
"type": "postgresql_update_combined",
|
|
||||||
"table": "car_frontal_info",
|
|
||||||
"key_field": "session_id",
|
|
||||||
"key_value": "{session_id}",
|
|
||||||
"waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
|
|
||||||
"fields": {
|
|
||||||
"car_brand": "{car_brand_cls_v1.brand}",
|
|
||||||
"car_body_type": "{car_bodytype_cls_v1.body_type}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## API Reference
|
|
||||||
|
|
||||||
The `pympta` module exposes two main functions.
|
|
||||||
|
|
||||||
### `load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict`
|
|
||||||
|
|
||||||
Loads, extracts, and parses an `.mpta` file to build a pipeline tree in memory. It also establishes Redis and PostgreSQL connections if configured in `pipeline.json`.
|
|
||||||
|
|
||||||
- **Parameters:**
|
|
||||||
- `zip_source` (str): The file path to the local `.mpta` zip archive.
|
|
||||||
- `target_dir` (str): A directory path where the archive's contents will be extracted.
|
|
||||||
- **Returns:**
|
|
||||||
- A dictionary representing the root node of the pipeline, ready to be used with `run_pipeline`. Returns `None` if loading fails.
|
|
||||||
|
|
||||||
### `run_pipeline(frame, node: dict, return_bbox: bool = False, context: dict = None)`

Executes the inference pipeline on a single image frame.

- **Parameters:**
  - `frame`: The input image frame (e.g., a NumPy array from OpenCV).
  - `node` (dict): The pipeline node to execute (typically the root node returned by `load_pipeline_from_zip`).
  - `return_bbox` (bool): If `True`, the function returns a tuple `(detection, bounding_box)`. Otherwise, it returns only the `detection`.
  - `context` (dict): Optional context dictionary containing `camera_id`, `display_id`, and `session_id`, used when formatting actions.
- **Returns:**
  - The final detection result from the last executed node in the chain. A detection is a dictionary like `{'class': 'car', 'confidence': 0.95, 'id': 1}`. If no detection meets the criteria, it returns `None` (or `(None, None)` if `return_bbox` is `True`).
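
A minimal sketch of how the two functions fit together; the file paths are placeholders, and the full camera loop appears in the Usage Example section below:

```python
import cv2
from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline

# Placeholder paths for illustration.
model_tree = load_pipeline_from_zip("models/pipeline.mpta", ".mptacache")
if model_tree is None:
    raise RuntimeError("Failed to load pipeline")

frame = cv2.imread("sample_frame.jpg")  # any BGR image readable by OpenCV
detection, bbox = run_pipeline(frame, model_tree, return_bbox=True)
if detection is None:
    print("No detection met the configured thresholds")
else:
    print(detection["class"], detection["confidence"], bbox)
```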
## Database Integration

The pipeline system includes automatic PostgreSQL database management:

### Table Schema (`gas_station_1.car_frontal_info`)

The system automatically creates and manages the following table structure:

```sql
CREATE TABLE IF NOT EXISTS gas_station_1.car_frontal_info (
    display_id VARCHAR(255),
    captured_timestamp VARCHAR(255),
    session_id VARCHAR(255) PRIMARY KEY,
    license_character VARCHAR(255) DEFAULT NULL,
    license_type VARCHAR(255) DEFAULT 'No model available',
    car_brand VARCHAR(255) DEFAULT NULL,
    car_model VARCHAR(255) DEFAULT NULL,
    car_body_type VARCHAR(255) DEFAULT NULL,
    created_at TIMESTAMP DEFAULT NOW(),
    updated_at TIMESTAMP DEFAULT NOW()
);
```
### Workflow

1. **Initial Record Creation**: When both "Car" and "Frontal" are detected, an initial database record is created with a UUID session_id.
2. **Redis Storage**: Vehicle images are stored in Redis with keys containing the session_id.
3. **Parallel Classification**: Brand and body type classification run concurrently.
4. **Database Update**: After all branches complete, the database record is updated with classification results (see the sketch below).
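
The sketch below illustrates step 4: how `{model_id.field}` templates from the `parallelActions` configuration are resolved against the collected branch results before the combined upsert. The branch outputs and the resolver shown here are illustrative; the actual logic lives inside the pipeline module:

```python
# Hypothetical branch outputs after both classifiers finish (values invented).
branch_results = {
    "car_brand_cls_v1": {"brand": "Toyota"},
    "car_bodytype_cls_v1": {"body_type": "Sedan"},
}

# Field mapping from the parallelActions "fields" object in the example pipeline.json.
fields = {
    "car_brand": "{car_brand_cls_v1.brand}",
    "car_body_type": "{car_bodytype_cls_v1.body_type}",
}

def resolve(template: str) -> str:
    """Turn "{model_id.field}" into the value reported by that branch."""
    model_id, field = template.strip("{}").split(".", 1)
    return branch_results[model_id][field]

resolved = {column: resolve(tpl) for column, tpl in fields.items()}
print(resolved)  # {'car_brand': 'Toyota', 'car_body_type': 'Sedan'}
# These columns are then upserted into gas_station_1.car_frontal_info keyed on session_id.
```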
## Usage Example

This snippet shows how to use `pympta` with the enhanced features:

```python
import cv2
from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline

# 1. Define paths
MPTA_FILE = "path/to/your/pipeline.mpta"
CACHE_DIR = ".mptacache"

# 2. Load the pipeline from the .mpta file
# This reads pipeline.json and loads the YOLO models into memory.
model_tree = load_pipeline_from_zip(MPTA_FILE, CACHE_DIR)

if not model_tree:
    print("Failed to load pipeline.")
    exit()

# 3. Open a video source
cap = cv2.VideoCapture(0)

while True:
    ret, frame = cap.read()
    if not ret:
        break

    # 4. Run the pipeline on the current frame with context
    context = {
        "camera_id": "display-001;cam-001",
        "display_id": "display-001",
        "session_id": None  # Will be generated automatically
    }

    detection_result, bounding_box = run_pipeline(frame, model_tree, return_bbox=True, context=context)

    # 5. Display the results
    if detection_result:
        print(f"Detected: {detection_result['class']} with confidence {detection_result['confidence']:.2f}")
        if bounding_box:
            x1, y1, x2, y2 = bounding_box
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, detection_result['class'], (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)

    cv2.imshow("Pipeline Output", frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()
```
@@ -1,7 +0,0 @@
torch
torchvision
ultralytics
opencv-python
scipy
filterpy
psycopg2-binary
@@ -1,6 +1,8 @@
 fastapi
 uvicorn
+torch
+torchvision
+ultralytics
+opencv-python
 websockets
 fastapi[standard]
-redis
-urllib3<2.0.0
@@ -1,211 +0,0 @@
import psycopg2
import psycopg2.extras
from typing import Optional, Dict, Any
import logging
import uuid

logger = logging.getLogger(__name__)

class DatabaseManager:
    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.connection: Optional[psycopg2.extensions.connection] = None

    def connect(self) -> bool:
        try:
            self.connection = psycopg2.connect(
                host=self.config['host'],
                port=self.config['port'],
                database=self.config['database'],
                user=self.config['username'],
                password=self.config['password']
            )
            logger.info("PostgreSQL connection established successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to PostgreSQL: {e}")
            return False

    def disconnect(self):
        if self.connection:
            self.connection.close()
            self.connection = None
            logger.info("PostgreSQL connection closed")

    def is_connected(self) -> bool:
        try:
            if self.connection and not self.connection.closed:
                cur = self.connection.cursor()
                cur.execute("SELECT 1")
                cur.fetchone()
                cur.close()
                return True
        except:
            pass
        return False

    def update_car_info(self, session_id: str, brand: str, model: str, body_type: str) -> bool:
        if not self.is_connected():
            if not self.connect():
                return False

        try:
            cur = self.connection.cursor()
            query = """
                INSERT INTO car_frontal_info (session_id, car_brand, car_model, car_body_type, updated_at)
                VALUES (%s, %s, %s, %s, NOW())
                ON CONFLICT (session_id)
                DO UPDATE SET
                    car_brand = EXCLUDED.car_brand,
                    car_model = EXCLUDED.car_model,
                    car_body_type = EXCLUDED.car_body_type,
                    updated_at = NOW()
            """
            cur.execute(query, (session_id, brand, model, body_type))
            self.connection.commit()
            cur.close()
            logger.info(f"Updated car info for session {session_id}: {brand} {model} ({body_type})")
            return True
        except Exception as e:
            logger.error(f"Failed to update car info: {e}")
            if self.connection:
                self.connection.rollback()
            return False

    def execute_update(self, table: str, key_field: str, key_value: str, fields: Dict[str, str]) -> bool:
        if not self.is_connected():
            if not self.connect():
                return False

        try:
            cur = self.connection.cursor()

            # Build the UPDATE query dynamically
            set_clauses = []
            values = []

            for field, value in fields.items():
                if value == "NOW()":
                    set_clauses.append(f"{field} = NOW()")
                else:
                    set_clauses.append(f"{field} = %s")
                    values.append(value)

            # Add schema prefix if table doesn't already have it
            full_table_name = table if '.' in table else f"gas_station_1.{table}"

            query = f"""
                INSERT INTO {full_table_name} ({key_field}, {', '.join(fields.keys())})
                VALUES (%s, {', '.join(['%s'] * len(fields))})
                ON CONFLICT ({key_field})
                DO UPDATE SET {', '.join(set_clauses)}
            """

            # Add key_value to the beginning of values list
            all_values = [key_value] + list(fields.values()) + values

            cur.execute(query, all_values)
            self.connection.commit()
            cur.close()
            logger.info(f"Updated {table} for {key_field}={key_value}")
            return True
        except Exception as e:
            logger.error(f"Failed to execute update on {table}: {e}")
            if self.connection:
                self.connection.rollback()
            return False

    def create_car_frontal_info_table(self) -> bool:
        """Create the car_frontal_info table in gas_station_1 schema if it doesn't exist."""
        if not self.is_connected():
            if not self.connect():
                return False

        try:
            cur = self.connection.cursor()

            # Create schema if it doesn't exist
            cur.execute("CREATE SCHEMA IF NOT EXISTS gas_station_1")

            # Create table if it doesn't exist
            create_table_query = """
                CREATE TABLE IF NOT EXISTS gas_station_1.car_frontal_info (
                    display_id VARCHAR(255),
                    captured_timestamp VARCHAR(255),
                    session_id VARCHAR(255) PRIMARY KEY,
                    license_character VARCHAR(255) DEFAULT NULL,
                    license_type VARCHAR(255) DEFAULT 'No model available',
                    car_brand VARCHAR(255) DEFAULT NULL,
                    car_model VARCHAR(255) DEFAULT NULL,
                    car_body_type VARCHAR(255) DEFAULT NULL,
                    updated_at TIMESTAMP DEFAULT NOW()
                )
            """

            cur.execute(create_table_query)

            # Add columns if they don't exist (for existing tables)
            alter_queries = [
                "ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS car_brand VARCHAR(255) DEFAULT NULL",
                "ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS car_model VARCHAR(255) DEFAULT NULL",
                "ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS car_body_type VARCHAR(255) DEFAULT NULL",
                "ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT NOW()"
            ]

            for alter_query in alter_queries:
                try:
                    cur.execute(alter_query)
                    logger.debug(f"Executed: {alter_query}")
                except Exception as e:
                    # Ignore errors if column already exists (for older PostgreSQL versions)
                    if "already exists" in str(e).lower():
                        logger.debug(f"Column already exists, skipping: {alter_query}")
                    else:
                        logger.warning(f"Error in ALTER TABLE: {e}")

            self.connection.commit()
            cur.close()
            logger.info("Successfully created/verified car_frontal_info table with all required columns")
            return True

        except Exception as e:
            logger.error(f"Failed to create car_frontal_info table: {e}")
            if self.connection:
                self.connection.rollback()
            return False

    def insert_initial_detection(self, display_id: str, captured_timestamp: str, session_id: str = None) -> str:
        """Insert initial detection record and return the session_id."""
        if not self.is_connected():
            if not self.connect():
                return None

        # Generate session_id if not provided
        if not session_id:
            session_id = str(uuid.uuid4())

        try:
            # Ensure table exists
            if not self.create_car_frontal_info_table():
                logger.error("Failed to create/verify table before insertion")
                return None

            cur = self.connection.cursor()
            insert_query = """
                INSERT INTO gas_station_1.car_frontal_info
                (display_id, captured_timestamp, session_id, license_character, license_type, car_brand, car_model, car_body_type)
                VALUES (%s, %s, %s, NULL, 'No model available', NULL, NULL, NULL)
                ON CONFLICT (session_id) DO NOTHING
            """

            cur.execute(insert_query, (display_id, captured_timestamp, session_id))
            self.connection.commit()
            cur.close()
            logger.info(f"Inserted initial detection record with session_id: {session_id}")
            return session_id

        except Exception as e:
            logger.error(f"Failed to insert initial detection record: {e}")
            if self.connection:
                self.connection.rollback()
            return None
@@ -5,624 +5,131 @@ import torch
 import cv2
 import zipfile
 import shutil
-import traceback
-import redis
-import time
-import uuid
-import concurrent.futures
 from ultralytics import YOLO
 from urllib.parse import urlparse
-from .database import DatabaseManager
 
# Create a logger specifically for this module
|
def load_pipeline_node(node_config: dict, mpta_dir: str) -> dict:
|
||||||
logger = logging.getLogger("detector_worker.pympta")
|
|
||||||
|
|
||||||
def validate_redis_config(redis_config: dict) -> bool:
|
|
||||||
"""Validate Redis configuration parameters."""
|
|
||||||
required_fields = ["host", "port"]
|
|
||||||
for field in required_fields:
|
|
||||||
if field not in redis_config:
|
|
||||||
logger.error(f"Missing required Redis config field: {field}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not isinstance(redis_config["port"], int) or redis_config["port"] <= 0:
|
|
||||||
logger.error(f"Invalid Redis port: {redis_config['port']}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def validate_postgresql_config(pg_config: dict) -> bool:
|
|
||||||
"""Validate PostgreSQL configuration parameters."""
|
|
||||||
required_fields = ["host", "port", "database", "username", "password"]
|
|
||||||
for field in required_fields:
|
|
||||||
if field not in pg_config:
|
|
||||||
logger.error(f"Missing required PostgreSQL config field: {field}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not isinstance(pg_config["port"], int) or pg_config["port"] <= 0:
|
|
||||||
logger.error(f"Invalid PostgreSQL port: {pg_config['port']}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def crop_region_by_class(frame, regions_dict, class_name):
|
|
||||||
"""Crop a specific region from frame based on detected class."""
|
|
||||||
if class_name not in regions_dict:
|
|
||||||
logger.warning(f"Class '{class_name}' not found in detected regions")
|
|
||||||
return None
|
|
||||||
|
|
||||||
bbox = regions_dict[class_name]['bbox']
|
|
||||||
x1, y1, x2, y2 = bbox
|
|
||||||
cropped = frame[y1:y2, x1:x2]
|
|
||||||
|
|
||||||
if cropped.size == 0:
|
|
||||||
logger.warning(f"Empty crop for class '{class_name}' with bbox {bbox}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return cropped
|
|
||||||
|
|
||||||
def format_action_context(base_context, additional_context=None):
|
|
||||||
"""Format action context with dynamic values."""
|
|
||||||
context = {**base_context}
|
|
||||||
if additional_context:
|
|
||||||
context.update(additional_context)
|
|
||||||
return context
|
|
||||||
|
|
||||||
def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client, db_manager=None) -> dict:
|
|
||||||
# Recursively load a model node from configuration.
|
|
||||||
model_path = os.path.join(mpta_dir, node_config["modelFile"])
|
model_path = os.path.join(mpta_dir, node_config["modelFile"])
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
logger.error(f"Model file {model_path} not found. Current directory: {os.getcwd()}")
|
logging.error(f"Model file {model_path} not found.")
|
||||||
logger.error(f"Directory content: {os.listdir(os.path.dirname(model_path))}")
|
|
||||||
raise FileNotFoundError(f"Model file {model_path} not found.")
|
raise FileNotFoundError(f"Model file {model_path} not found.")
|
||||||
logger.info(f"Loading model for node {node_config['modelId']} from {model_path}")
|
logging.info(f"Loading model {node_config['modelId']} from {model_path}")
|
||||||
model = YOLO(model_path)
|
model = YOLO(model_path)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
logger.info(f"CUDA available. Moving model {node_config['modelId']} to GPU")
|
|
||||||
model.to("cuda")
|
model.to("cuda")
|
||||||
else:
|
|
||||||
logger.info(f"CUDA not available. Using CPU for model {node_config['modelId']}")
|
|
||||||
|
|
||||||
# Prepare trigger class indices for optimization
|
# map triggerClasses names → indices for YOLO
|
||||||
trigger_classes = node_config.get("triggerClasses", [])
|
names = model.names # idx -> class name
|
||||||
trigger_class_indices = None
|
trigger_names = node_config.get("triggerClasses", [])
|
||||||
if trigger_classes and hasattr(model, "names"):
|
trigger_inds = [i for i, nm in names.items() if nm in trigger_names]
|
||||||
# Convert class names to indices for the model
|
|
||||||
trigger_class_indices = [i for i, name in model.names.items()
|
|
||||||
if name in trigger_classes]
|
|
||||||
logger.debug(f"Converted trigger classes to indices: {trigger_class_indices}")
|
|
||||||
|
|
||||||
node = {
|
return {
|
||||||
"modelId": node_config["modelId"],
|
"modelId": node_config["modelId"],
|
||||||
"modelFile": node_config["modelFile"],
|
"modelFile": node_config["modelFile"],
|
||||||
"triggerClasses": trigger_classes,
|
"triggerClasses": trigger_names,
|
||||||
"triggerClassIndices": trigger_class_indices,
|
"triggerClassIndices": trigger_inds,
|
||||||
"crop": node_config.get("crop", False),
|
"crop": node_config.get("crop", False),
|
||||||
"cropClass": node_config.get("cropClass"),
|
"minConfidence": node_config.get("minConfidence", 0.0),
|
||||||
"minConfidence": node_config.get("minConfidence", None),
|
|
||||||
"multiClass": node_config.get("multiClass", False),
|
|
||||||
"expectedClasses": node_config.get("expectedClasses", []),
|
|
||||||
"parallel": node_config.get("parallel", False),
|
|
||||||
"actions": node_config.get("actions", []),
|
|
||||||
"parallelActions": node_config.get("parallelActions", []),
|
|
||||||
"model": model,
|
"model": model,
|
||||||
"branches": [],
|
"branches": [
|
||||||
"redis_client": redis_client,
|
load_pipeline_node(child, mpta_dir)
|
||||||
"db_manager": db_manager
|
for child in node_config.get("branches", [])
|
||||||
|
]
|
||||||
}
|
}
|
||||||
logger.debug(f"Configured node {node_config['modelId']} with trigger classes: {node['triggerClasses']}")
|
|
||||||
for child in node_config.get("branches", []):
|
|
||||||
logger.debug(f"Loading branch for parent node {node_config['modelId']}")
|
|
||||||
node["branches"].append(load_pipeline_node(child, mpta_dir, redis_client, db_manager))
|
|
||||||
return node
|
|
||||||
|
|
||||||
def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
|
def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
|
||||||
logger.info(f"Attempting to load pipeline from {zip_source} to {target_dir}")
|
|
||||||
os.makedirs(target_dir, exist_ok=True)
|
os.makedirs(target_dir, exist_ok=True)
|
||||||
zip_path = os.path.join(target_dir, "pipeline.mpta")
|
zip_path = os.path.join(target_dir, "pipeline.mpta")
|
||||||
|
|
||||||
# Parse the source; only local files are supported here.
|
|
||||||
parsed = urlparse(zip_source)
|
parsed = urlparse(zip_source)
|
||||||
if parsed.scheme in ("", "file"):
|
if parsed.scheme in ("", "file"):
|
||||||
local_path = parsed.path if parsed.scheme == "file" else zip_source
|
local = parsed.path if parsed.scheme == "file" else zip_source
|
||||||
logger.debug(f"Checking if local file exists: {local_path}")
|
if not os.path.exists(local):
|
||||||
if os.path.exists(local_path):
|
logging.error(f"Local file {local} does not exist.")
|
||||||
try:
|
|
||||||
shutil.copy(local_path, zip_path)
|
|
||||||
logger.info(f"Copied local .mpta file from {local_path} to {zip_path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to copy local .mpta file from {local_path}: {str(e)}", exc_info=True)
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
logger.error(f"Local file {local_path} does not exist. Current directory: {os.getcwd()}")
|
|
||||||
# List all subdirectories of models directory to help debugging
|
|
||||||
if os.path.exists("models"):
|
|
||||||
logger.error(f"Content of models directory: {os.listdir('models')}")
|
|
||||||
for root, dirs, files in os.walk("models"):
|
|
||||||
logger.error(f"Directory {root} contains subdirs: {dirs} and files: {files}")
|
|
||||||
else:
|
|
||||||
logger.error("The models directory doesn't exist")
|
|
||||||
return None
|
return None
|
||||||
|
shutil.copy(local, zip_path)
|
||||||
else:
|
else:
|
||||||
logger.error(f"HTTP download functionality has been moved. Use a local file path here. Received: {zip_source}")
|
logging.error("HTTP download not supported; use local file.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
with zipfile.ZipFile(zip_path, "r") as z:
|
||||||
if not os.path.exists(zip_path):
|
z.extractall(target_dir)
|
||||||
logger.error(f"Zip file not found at expected location: {zip_path}")
|
os.remove(zip_path)
|
||||||
return None
|
|
||||||
|
|
||||||
logger.debug(f"Extracting .mpta file from {zip_path} to {target_dir}")
|
base = os.path.splitext(os.path.basename(zip_source))[0]
|
||||||
# Extract contents and track the directories created
|
mpta_dir = os.path.join(target_dir, base)
|
||||||
extracted_dirs = []
|
cfg = os.path.join(mpta_dir, "pipeline.json")
|
||||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
if not os.path.exists(cfg):
|
||||||
file_list = zip_ref.namelist()
|
logging.error("pipeline.json not found in archive.")
|
||||||
logger.debug(f"Files in .mpta archive: {file_list}")
|
|
||||||
|
|
||||||
# Extract and track the top-level directories
|
|
||||||
for file_path in file_list:
|
|
||||||
parts = file_path.split('/')
|
|
||||||
if len(parts) > 1:
|
|
||||||
top_dir = parts[0]
|
|
||||||
if top_dir and top_dir not in extracted_dirs:
|
|
||||||
extracted_dirs.append(top_dir)
|
|
||||||
|
|
||||||
# Now extract the files
|
|
||||||
zip_ref.extractall(target_dir)
|
|
||||||
|
|
||||||
logger.info(f"Successfully extracted .mpta file to {target_dir}")
|
|
||||||
logger.debug(f"Extracted directories: {extracted_dirs}")
|
|
||||||
|
|
||||||
# Check what was actually created after extraction
|
|
||||||
actual_dirs = [d for d in os.listdir(target_dir) if os.path.isdir(os.path.join(target_dir, d))]
|
|
||||||
logger.debug(f"Actual directories created: {actual_dirs}")
|
|
||||||
except zipfile.BadZipFile as e:
|
|
||||||
logger.error(f"Bad zip file {zip_path}: {str(e)}", exc_info=True)
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to extract .mpta file {zip_path}: {str(e)}", exc_info=True)
|
|
||||||
return None
|
|
||||||
finally:
|
|
||||||
if os.path.exists(zip_path):
|
|
||||||
os.remove(zip_path)
|
|
||||||
logger.debug(f"Removed temporary zip file: {zip_path}")
|
|
||||||
|
|
||||||
# Use the first extracted directory if it exists, otherwise use the expected name
|
|
||||||
pipeline_name = os.path.basename(zip_source)
|
|
||||||
pipeline_name = os.path.splitext(pipeline_name)[0]
|
|
||||||
|
|
||||||
# Find the directory with pipeline.json
|
|
||||||
mpta_dir = None
|
|
||||||
# First try the expected directory name
|
|
||||||
expected_dir = os.path.join(target_dir, pipeline_name)
|
|
||||||
if os.path.exists(expected_dir) and os.path.exists(os.path.join(expected_dir, "pipeline.json")):
|
|
||||||
mpta_dir = expected_dir
|
|
||||||
logger.debug(f"Found pipeline.json in the expected directory: {mpta_dir}")
|
|
||||||
else:
|
|
||||||
# Look through all subdirectories for pipeline.json
|
|
||||||
for subdir in actual_dirs:
|
|
||||||
potential_dir = os.path.join(target_dir, subdir)
|
|
||||||
if os.path.exists(os.path.join(potential_dir, "pipeline.json")):
|
|
||||||
mpta_dir = potential_dir
|
|
||||||
logger.info(f"Found pipeline.json in directory: {mpta_dir} (different from expected: {expected_dir})")
|
|
||||||
break
|
|
||||||
|
|
||||||
if not mpta_dir:
|
|
||||||
logger.error(f"Could not find pipeline.json in any extracted directory. Directory content: {os.listdir(target_dir)}")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
pipeline_json_path = os.path.join(mpta_dir, "pipeline.json")
|
with open(cfg) as f:
|
||||||
if not os.path.exists(pipeline_json_path):
|
pipeline_config = json.load(f)
|
||||||
logger.error(f"pipeline.json not found in the .mpta file. Files in directory: {os.listdir(mpta_dir)}")
|
return load_pipeline_node(pipeline_config["pipeline"], mpta_dir)
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(pipeline_json_path, "r") as f:
|
|
||||||
pipeline_config = json.load(f)
|
|
||||||
logger.info(f"Successfully loaded pipeline configuration from {pipeline_json_path}")
|
|
||||||
logger.debug(f"Pipeline config: {json.dumps(pipeline_config, indent=2)}")
|
|
||||||
|
|
||||||
# Establish Redis connection if configured
|
def run_pipeline(frame, node: dict, return_bbox: bool=False):
|
||||||
redis_client = None
|
|
||||||
if "redis" in pipeline_config:
|
|
||||||
redis_config = pipeline_config["redis"]
|
|
||||||
if not validate_redis_config(redis_config):
|
|
||||||
logger.error("Invalid Redis configuration, skipping Redis connection")
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
redis_client = redis.Redis(
|
|
||||||
host=redis_config["host"],
|
|
||||||
port=redis_config["port"],
|
|
||||||
password=redis_config.get("password"),
|
|
||||||
db=redis_config.get("db", 0),
|
|
||||||
decode_responses=True
|
|
||||||
)
|
|
||||||
redis_client.ping()
|
|
||||||
logger.info(f"Successfully connected to Redis at {redis_config['host']}:{redis_config['port']}")
|
|
||||||
except redis.exceptions.ConnectionError as e:
|
|
||||||
logger.error(f"Failed to connect to Redis: {e}")
|
|
||||||
redis_client = None
|
|
||||||
|
|
||||||
# Establish PostgreSQL connection if configured
|
|
||||||
db_manager = None
|
|
||||||
if "postgresql" in pipeline_config:
|
|
||||||
pg_config = pipeline_config["postgresql"]
|
|
||||||
if not validate_postgresql_config(pg_config):
|
|
||||||
logger.error("Invalid PostgreSQL configuration, skipping database connection")
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
db_manager = DatabaseManager(pg_config)
|
|
||||||
if db_manager.connect():
|
|
||||||
logger.info(f"Successfully connected to PostgreSQL at {pg_config['host']}:{pg_config['port']}")
|
|
||||||
else:
|
|
||||||
logger.error("Failed to connect to PostgreSQL")
|
|
||||||
db_manager = None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error initializing PostgreSQL connection: {e}")
|
|
||||||
db_manager = None
|
|
||||||
|
|
||||||
return load_pipeline_node(pipeline_config["pipeline"], mpta_dir, redis_client, db_manager)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(f"Error parsing pipeline.json: {str(e)}", exc_info=True)
|
|
||||||
return None
|
|
||||||
except KeyError as e:
|
|
||||||
logger.error(f"Missing key in pipeline.json: {str(e)}", exc_info=True)
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def execute_actions(node, frame, detection_result, regions_dict=None):
|
|
||||||
if not node["redis_client"] or not node["actions"]:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Create a dynamic context for this detection event
|
|
||||||
from datetime import datetime
|
|
||||||
action_context = {
|
|
||||||
**detection_result,
|
|
||||||
"timestamp_ms": int(time.time() * 1000),
|
|
||||||
"uuid": str(uuid.uuid4()),
|
|
||||||
"timestamp": datetime.now().strftime("%Y-%m-%dT%H-%M-%S"),
|
|
||||||
"filename": f"{uuid.uuid4()}.jpg"
|
|
||||||
}
|
|
||||||
|
|
||||||
for action in node["actions"]:
|
|
||||||
try:
|
|
||||||
if action["type"] == "redis_save_image":
|
|
||||||
key = action["key"].format(**action_context)
|
|
||||||
|
|
||||||
# Check if we need to crop a specific region
|
|
||||||
region_name = action.get("region")
|
|
||||||
image_to_save = frame
|
|
||||||
|
|
||||||
if region_name and regions_dict:
|
|
||||||
cropped_image = crop_region_by_class(frame, regions_dict, region_name)
|
|
||||||
if cropped_image is not None:
|
|
||||||
image_to_save = cropped_image
|
|
||||||
logger.debug(f"Cropped region '{region_name}' for redis_save_image")
|
|
||||||
else:
|
|
||||||
logger.warning(f"Could not crop region '{region_name}', saving full frame instead")
|
|
||||||
|
|
||||||
# Encode image with specified format and quality (default to JPEG)
|
|
||||||
img_format = action.get("format", "jpeg").lower()
|
|
||||||
quality = action.get("quality", 90)
|
|
||||||
|
|
||||||
if img_format == "jpeg":
|
|
||||||
encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
|
|
||||||
success, buffer = cv2.imencode('.jpg', image_to_save, encode_params)
|
|
||||||
elif img_format == "png":
|
|
||||||
success, buffer = cv2.imencode('.png', image_to_save)
|
|
||||||
else:
|
|
||||||
success, buffer = cv2.imencode('.jpg', image_to_save, [cv2.IMWRITE_JPEG_QUALITY, quality])
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
logger.error(f"Failed to encode image for redis_save_image")
|
|
||||||
continue
|
|
||||||
|
|
||||||
expire_seconds = action.get("expire_seconds")
|
|
||||||
if expire_seconds:
|
|
||||||
node["redis_client"].setex(key, expire_seconds, buffer.tobytes())
|
|
||||||
logger.info(f"Saved image to Redis with key: {key} (expires in {expire_seconds}s)")
|
|
||||||
else:
|
|
||||||
node["redis_client"].set(key, buffer.tobytes())
|
|
||||||
logger.info(f"Saved image to Redis with key: {key}")
|
|
||||||
action_context["image_key"] = key
|
|
||||||
elif action["type"] == "redis_publish":
|
|
||||||
channel = action["channel"]
|
|
||||||
try:
|
|
||||||
# Handle JSON message format by creating it programmatically
|
|
||||||
message_template = action["message"]
|
|
||||||
|
|
||||||
# Check if the message is JSON-like (starts and ends with braces)
|
|
||||||
if message_template.strip().startswith('{') and message_template.strip().endswith('}'):
|
|
||||||
# Create JSON data programmatically to avoid formatting issues
|
|
||||||
json_data = {}
|
|
||||||
|
|
||||||
# Add common fields
|
|
||||||
json_data["event"] = "frontal_detected"
|
|
||||||
json_data["display_id"] = action_context.get("display_id", "unknown")
|
|
||||||
json_data["session_id"] = action_context.get("session_id")
|
|
||||||
json_data["timestamp"] = action_context.get("timestamp", "")
|
|
||||||
json_data["image_key"] = action_context.get("image_key", "")
|
|
||||||
|
|
||||||
# Convert to JSON string
|
|
||||||
message = json.dumps(json_data)
|
|
||||||
else:
|
|
||||||
# Use regular string formatting for non-JSON messages
|
|
||||||
message = message_template.format(**action_context)
|
|
||||||
|
|
||||||
# Publish to Redis
|
|
||||||
if not node["redis_client"]:
|
|
||||||
logger.error("Redis client is None, cannot publish message")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Test Redis connection
|
|
||||||
try:
|
|
||||||
node["redis_client"].ping()
|
|
||||||
logger.debug("Redis connection is active")
|
|
||||||
except Exception as ping_error:
|
|
||||||
logger.error(f"Redis connection test failed: {ping_error}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
result = node["redis_client"].publish(channel, message)
|
|
||||||
logger.info(f"Published message to Redis channel '{channel}': {message}")
|
|
||||||
logger.info(f"Redis publish result (subscribers count): {result}")
|
|
||||||
|
|
||||||
# Additional debug info
|
|
||||||
if result == 0:
|
|
||||||
logger.warning(f"No subscribers listening to channel '{channel}'")
|
|
||||||
else:
|
|
||||||
logger.info(f"Message delivered to {result} subscriber(s)")
|
|
||||||
|
|
||||||
except KeyError as e:
|
|
||||||
logger.error(f"Missing key in redis_publish message template: {e}")
|
|
||||||
logger.debug(f"Available context keys: {list(action_context.keys())}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in redis_publish action: {e}")
|
|
||||||
logger.debug(f"Message template: {action['message']}")
|
|
||||||
logger.debug(f"Available context keys: {list(action_context.keys())}")
|
|
||||||
import traceback
|
|
||||||
logger.debug(f"Full traceback: {traceback.format_exc()}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error executing action {action['type']}: {e}")
|
|
||||||
|
|
||||||
def execute_parallel_actions(node, frame, detection_result, regions_dict):
|
|
||||||
"""Execute parallel actions after all required branches have completed."""
|
|
||||||
if not node.get("parallelActions"):
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug("Executing parallel actions...")
|
|
||||||
branch_results = detection_result.get("branch_results", {})
|
|
||||||
|
|
||||||
for action in node["parallelActions"]:
|
|
||||||
try:
|
|
||||||
action_type = action.get("type")
|
|
||||||
logger.debug(f"Processing parallel action: {action_type}")
|
|
||||||
|
|
||||||
if action_type == "postgresql_update_combined":
|
|
||||||
# Check if all required branches have completed
|
|
||||||
wait_for_branches = action.get("waitForBranches", [])
|
|
||||||
missing_branches = [branch for branch in wait_for_branches if branch not in branch_results]
|
|
||||||
|
|
||||||
if missing_branches:
|
|
||||||
logger.warning(f"Cannot execute postgresql_update_combined: missing branch results for {missing_branches}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.info(f"All required branches completed: {wait_for_branches}")
|
|
||||||
|
|
||||||
# Execute the database update
|
|
||||||
execute_postgresql_update_combined(node, action, detection_result, branch_results)
|
|
||||||
else:
|
|
||||||
logger.warning(f"Unknown parallel action type: {action_type}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error executing parallel action {action.get('type', 'unknown')}: {e}")
|
|
||||||
import traceback
|
|
||||||
logger.debug(f"Full traceback: {traceback.format_exc()}")
|
|
||||||
|
|
||||||
def execute_postgresql_update_combined(node, action, detection_result, branch_results):
|
|
||||||
"""Execute a PostgreSQL update with combined branch results."""
|
|
||||||
if not node.get("db_manager"):
|
|
||||||
logger.error("No database manager available for postgresql_update_combined action")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
table = action["table"]
|
|
||||||
key_field = action["key_field"]
|
|
||||||
key_value_template = action["key_value"]
|
|
||||||
fields = action["fields"]
|
|
||||||
|
|
||||||
# Create context for key value formatting
|
|
||||||
action_context = {**detection_result}
|
|
||||||
key_value = key_value_template.format(**action_context)
|
|
||||||
|
|
||||||
logger.info(f"Executing database update: table={table}, {key_field}={key_value}")
|
|
||||||
|
|
||||||
# Process field mappings
|
|
||||||
mapped_fields = {}
|
|
||||||
for db_field, value_template in fields.items():
|
|
||||||
try:
|
|
||||||
mapped_value = resolve_field_mapping(value_template, branch_results, action_context)
|
|
||||||
if mapped_value is not None:
|
|
||||||
mapped_fields[db_field] = mapped_value
|
|
||||||
logger.debug(f"Mapped field: {db_field} = {mapped_value}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"Could not resolve field mapping for {db_field}: {value_template}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error mapping field {db_field} with template '{value_template}': {e}")
|
|
||||||
|
|
||||||
if not mapped_fields:
|
|
||||||
logger.warning("No fields mapped successfully, skipping database update")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Execute the database update
|
|
||||||
success = node["db_manager"].execute_update(table, key_field, key_value, mapped_fields)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
logger.info(f"Successfully updated database: {table} with {len(mapped_fields)} fields")
|
|
||||||
else:
|
|
||||||
logger.error(f"Failed to update database: {table}")
|
|
||||||
|
|
||||||
except KeyError as e:
|
|
||||||
logger.error(f"Missing required field in postgresql_update_combined action: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in postgresql_update_combined action: {e}")
|
|
||||||
import traceback
|
|
||||||
logger.debug(f"Full traceback: {traceback.format_exc()}")
|
|
||||||
|
|
||||||
def resolve_field_mapping(value_template, branch_results, action_context):
|
|
||||||
"""Resolve field mapping templates like {car_brand_cls_v1.brand}."""
|
|
||||||
try:
|
|
||||||
# Handle simple context variables first (non-branch references)
|
|
||||||
if not '.' in value_template:
|
|
||||||
return value_template.format(**action_context)
|
|
||||||
|
|
||||||
# Handle branch result references like {model_id.field}
|
|
||||||
import re
|
|
||||||
branch_refs = re.findall(r'\{([^}]+\.[^}]+)\}', value_template)
|
|
||||||
|
|
||||||
resolved_template = value_template
|
|
||||||
for ref in branch_refs:
|
|
||||||
try:
|
|
||||||
model_id, field_name = ref.split('.', 1)
|
|
||||||
|
|
||||||
if model_id in branch_results:
|
|
||||||
branch_data = branch_results[model_id]
|
|
||||||
if field_name in branch_data:
|
|
||||||
field_value = branch_data[field_name]
|
|
||||||
resolved_template = resolved_template.replace(f'{{{ref}}}', str(field_value))
|
|
||||||
logger.debug(f"Resolved {ref} to {field_value}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"Field '{field_name}' not found in branch '{model_id}' results. Available fields: {list(branch_data.keys())}")
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
logger.warning(f"Branch '{model_id}' not found in results. Available branches: {list(branch_results.keys())}")
|
|
||||||
return None
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"Invalid branch reference format: {ref}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Format any remaining simple variables
|
|
||||||
try:
|
|
||||||
final_value = resolved_template.format(**action_context)
|
|
||||||
return final_value
|
|
||||||
except KeyError as e:
|
|
||||||
logger.warning(f"Could not resolve context variable in template: {e}")
|
|
||||||
return resolved_template
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error resolving field mapping '{value_template}': {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def validate_pipeline_execution(node, regions_dict):
|
|
||||||
"""
|
"""
|
||||||
Pre-validate that all required branches will execute successfully before
|
- For detection nodes (task != 'classify'):
|
||||||
committing to Redis actions and database records.
|
• runs `track(..., classes=triggerClassIndices)`
|
||||||
|
• picks top box ≥ minConfidence
|
||||||
Returns:
|
• optionally crops & resizes → recurse into child
|
||||||
- (True, []) if pipeline can execute completely
|
• else returns (det_dict, bbox)
|
||||||
- (False, missing_branches) if some required branches won't execute
|
- For classify nodes:
|
||||||
"""
|
• runs `predict()`
|
||||||
# Get all branches that parallel actions are waiting for
|
• returns top (class,confidence) and no bbox
|
||||||
required_branches = set()
|
|
||||||
|
|
||||||
for action in node.get("parallelActions", []):
|
|
||||||
if action.get("type") == "postgresql_update_combined":
|
|
||||||
wait_for_branches = action.get("waitForBranches", [])
|
|
||||||
required_branches.update(wait_for_branches)
|
|
||||||
|
|
||||||
if not required_branches:
|
|
||||||
# No parallel actions requiring specific branches
|
|
||||||
logger.debug("No parallel actions with waitForBranches - validation passes")
|
|
||||||
return True, []
|
|
||||||
|
|
||||||
logger.debug(f"Pre-validation: checking if required branches {list(required_branches)} will execute")
|
|
||||||
|
|
||||||
# Check each required branch
|
|
||||||
missing_branches = []
|
|
||||||
|
|
||||||
for branch in node.get("branches", []):
|
|
||||||
branch_id = branch["modelId"]
|
|
||||||
|
|
||||||
if branch_id not in required_branches:
|
|
||||||
continue # This branch is not required by parallel actions
|
|
||||||
|
|
||||||
# Check if this branch would be triggered
|
|
||||||
trigger_classes = branch.get("triggerClasses", [])
|
|
||||||
min_conf = branch.get("minConfidence", 0)
|
|
||||||
|
|
||||||
branch_triggered = False
|
|
||||||
for det_class in regions_dict:
|
|
||||||
det_confidence = regions_dict[det_class]["confidence"]
|
|
||||||
|
|
||||||
if (det_class in trigger_classes and det_confidence >= min_conf):
|
|
||||||
branch_triggered = True
|
|
||||||
logger.debug(f"Pre-validation: branch {branch_id} WILL be triggered by {det_class} (conf={det_confidence:.3f} >= {min_conf})")
|
|
||||||
break
|
|
||||||
|
|
||||||
if not branch_triggered:
|
|
||||||
missing_branches.append(branch_id)
|
|
||||||
logger.warning(f"Pre-validation: branch {branch_id} will NOT be triggered - no matching classes or insufficient confidence")
|
|
||||||
logger.debug(f" Required: {trigger_classes} with min_conf={min_conf}")
|
|
||||||
logger.debug(f" Available: {[(cls, regions_dict[cls]['confidence']) for cls in regions_dict]}")
|
|
||||||
|
|
||||||
if missing_branches:
|
|
||||||
logger.error(f"Pipeline pre-validation FAILED: required branches {missing_branches} will not execute")
|
|
||||||
return False, missing_branches
|
|
||||||
else:
|
|
||||||
logger.info(f"Pipeline pre-validation PASSED: all required branches {list(required_branches)} will execute")
|
|
||||||
return True, []
|
|
||||||
|
|
||||||
def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
|
|
||||||
"""
|
|
||||||
Enhanced pipeline that supports:
|
|
||||||
- Multi-class detection (detecting multiple classes simultaneously)
|
|
||||||
- Parallel branch processing
|
|
||||||
- Region-based actions and cropping
|
|
||||||
- Context passing for session/camera information
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
task = getattr(node["model"], "task", None)
|
task = getattr(node["model"], "task", None)
|
||||||
|
|
||||||
# ─── Classification stage ───────────────────────────────────
|
# ─── Classification stage ───────────────────────────────────
|
||||||
|
# if task == "classify":
|
||||||
|
# results = node["model"].predict(frame, stream=False)
|
||||||
|
# dets = []
|
||||||
|
# for r in results:
|
||||||
|
# probs = r.probs
|
||||||
|
# if probs is not None:
|
||||||
|
# # sort descending
|
||||||
|
# idxs = probs.argsort(descending=True)
|
||||||
|
# for cid in idxs:
|
||||||
|
# dets.append({
|
||||||
|
# "class": node["model"].names[int(cid)],
|
||||||
|
# "confidence": float(probs[int(cid)]),
|
||||||
|
# "id": None
|
||||||
|
# })
|
||||||
|
# if not dets:
|
||||||
|
# return (None, None) if return_bbox else None
|
||||||
|
|
||||||
|
# best = dets[0]
|
||||||
|
# return (best, None) if return_bbox else best
|
||||||
|
|
||||||
if task == "classify":
|
if task == "classify":
|
||||||
|
# run the classifier and grab its top-1 directly via the Probs API
|
||||||
results = node["model"].predict(frame, stream=False)
|
results = node["model"].predict(frame, stream=False)
|
||||||
|
# nothing returned?
|
||||||
if not results:
|
if not results:
|
||||||
return (None, None) if return_bbox else None
|
return (None, None) if return_bbox else None
|
||||||
|
|
||||||
r = results[0]
|
# take the first result's probs object
|
||||||
|
r = results[0]
|
||||||
probs = r.probs
|
probs = r.probs
|
||||||
if probs is None:
|
if probs is None:
|
||||||
return (None, None) if return_bbox else None
|
return (None, None) if return_bbox else None
|
||||||
|
|
||||||
top1_idx = int(probs.top1)
|
# get the top-1 class index and its confidence
|
||||||
|
top1_idx = int(probs.top1)
|
||||||
top1_conf = float(probs.top1conf)
|
top1_conf = float(probs.top1conf)
|
||||||
class_name = node["model"].names[top1_idx]
|
|
||||||
|
|
||||||
det = {
|
det = {
|
||||||
"class": class_name,
|
"class": node["model"].names[top1_idx],
|
||||||
"confidence": top1_conf,
|
"confidence": top1_conf,
|
||||||
"id": None,
|
"id": None
|
||||||
class_name: class_name # Add class name as key for backward compatibility
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add specific field mappings for database operations based on model type
|
|
||||||
model_id = node.get("modelId", "").lower()
|
|
||||||
if "brand" in model_id or "brand_cls" in model_id:
|
|
||||||
det["brand"] = class_name
|
|
||||||
elif "bodytype" in model_id or "body" in model_id:
|
|
||||||
det["body_type"] = class_name
|
|
||||||
elif "color" in model_id:
|
|
||||||
det["color"] = class_name
|
|
||||||
|
|
||||||
execute_actions(node, frame, det)
|
|
||||||
return (det, None) if return_bbox else det
|
return (det, None) if return_bbox else det
|
||||||
|
|
||||||
# ─── Detection stage - Multi-class support ──────────────────
|
|
||||||
tk = node["triggerClassIndices"]
|
|
||||||
logger.debug(f"Running detection for node {node['modelId']} with trigger classes: {node.get('triggerClasses', [])} (indices: {tk})")
|
|
||||||
logger.debug(f"Node configuration: minConfidence={node['minConfidence']}, multiClass={node.get('multiClass', False)}")
|
|
||||||
|
|
||||||
|
# ─── Detection stage ────────────────────────────────────────
|
||||||
|
# only look for your triggerClasses
|
||||||
|
tk = node["triggerClassIndices"]
|
||||||
res = node["model"].track(
|
res = node["model"].track(
|
||||||
frame,
|
frame,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
@ -630,238 +137,46 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
|
||||||
**({"classes": tk} if tk else {})
|
**({"classes": tk} if tk else {})
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# Collect all detections above confidence threshold
|
dets, boxes = [], []
|
||||||
all_detections = []
|
for box in res.boxes:
|
||||||
all_boxes = []
|
|
||||||
regions_dict = {}
|
|
||||||
|
|
||||||
logger.debug(f"Raw detection results from model: {len(res.boxes) if res.boxes is not None else 0} detections")
|
|
||||||
|
|
||||||
for i, box in enumerate(res.boxes):
|
|
||||||
conf = float(box.cpu().conf[0])
|
conf = float(box.cpu().conf[0])
|
||||||
cid = int(box.cpu().cls[0])
|
cid = int(box.cpu().cls[0])
|
||||||
name = node["model"].names[cid]
|
name = node["model"].names[cid]
|
||||||
|
|
||||||
logger.debug(f"Detection {i}: class='{name}' (id={cid}), confidence={conf:.3f}, threshold={node['minConfidence']}")
|
|
||||||
|
|
||||||
if conf < node["minConfidence"]:
|
if conf < node["minConfidence"]:
|
||||||
logger.debug(f" -> REJECTED: confidence {conf:.3f} < threshold {node['minConfidence']}")
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
xy = box.cpu().xyxy[0]
|
xy = box.cpu().xyxy[0]
|
||||||
x1, y1, x2, y2 = map(int, xy)
|
x1,y1,x2,y2 = map(int, xy)
|
||||||
bbox = (x1, y1, x2, y2)
|
dets.append({"class": name, "confidence": conf,
|
||||||
|
"id": box.id.item() if hasattr(box, "id") else None})
|
||||||
|
boxes.append((x1, y1, x2, y2))
|
||||||
|
|
||||||
detection = {
|
if not dets:
|
||||||
"class": name,
|
|
||||||
"confidence": conf,
|
|
||||||
"id": box.id.item() if hasattr(box, "id") else None,
|
|
||||||
"bbox": bbox
|
|
||||||
}
|
|
||||||
|
|
||||||
all_detections.append(detection)
|
|
||||||
all_boxes.append(bbox)
|
|
||||||
|
|
||||||
logger.debug(f" -> ACCEPTED: {name} with confidence {conf:.3f}, bbox={bbox}")
|
|
||||||
|
|
||||||
# Store highest confidence detection for each class
|
|
||||||
if name not in regions_dict or conf > regions_dict[name]["confidence"]:
|
|
||||||
regions_dict[name] = {
|
|
||||||
"bbox": bbox,
|
|
||||||
"confidence": conf,
|
|
||||||
"detection": detection
|
|
||||||
}
|
|
||||||
logger.debug(f" -> Updated regions_dict['{name}'] with confidence {conf:.3f}")
|
|
||||||
|
|
||||||
logger.info(f"Detection summary: {len(all_detections)} accepted detections from {len(res.boxes) if res.boxes is not None else 0} total")
|
|
||||||
logger.info(f"Detected classes: {list(regions_dict.keys())}")
|
|
||||||
|
|
||||||
if not all_detections:
|
|
||||||
logger.warning("No detections above confidence threshold - returning null")
|
|
||||||
return (None, None) if return_bbox else None
|
return (None, None) if return_bbox else None
|
||||||
|
|
||||||
# ─── Multi-class validation ─────────────────────────────────
|
# take highest‐confidence
|
||||||
if node.get("multiClass", False) and node.get("expectedClasses"):
|
best_idx = max(range(len(dets)), key=lambda i: dets[i]["confidence"])
|
||||||
expected_classes = node["expectedClasses"]
|
best_det = dets[best_idx]
|
||||||
detected_classes = list(regions_dict.keys())
|
best_box = boxes[best_idx]
|
||||||
|
|
||||||
logger.info(f"Multi-class validation: expected={expected_classes}, detected={detected_classes}")
|
# ─── Branch (classification) ───────────────────────────────
|
||||||
|
for br in node["branches"]:
|
||||||
|
if (best_det["class"] in br["triggerClasses"]
|
||||||
|
and best_det["confidence"] >= br["minConfidence"]):
|
||||||
|
# crop if requested
|
||||||
|
sub = frame
|
||||||
|
if br["crop"]:
|
||||||
|
x1,y1,x2,y2 = best_box
|
||||||
|
sub = frame[y1:y2, x1:x2]
|
||||||
|
sub = cv2.resize(sub, (224, 224))
|
||||||
|
|
||||||
# Check if at least one expected class is detected (flexible mode)
|
det2, _ = run_pipeline(sub, br, return_bbox=True)
|
||||||
matching_classes = [cls for cls in expected_classes if cls in detected_classes]
|
if det2:
|
||||||
missing_classes = [cls for cls in expected_classes if cls not in detected_classes]
|
# return classification result + original bbox
|
||||||
|
return (det2, best_box) if return_bbox else det2
|
||||||
|
|
||||||
logger.debug(f"Matching classes: {matching_classes}, Missing classes: {missing_classes}")
|
# ─── No branch matched → return this detection ─────────────
|
||||||
|
return (best_det, best_box) if return_bbox else best_det
|
||||||
if not matching_classes:
|
|
||||||
# No expected classes found at all
|
|
||||||
logger.warning(f"PIPELINE REJECTED: No expected classes detected. Expected: {expected_classes}, Detected: {detected_classes}")
|
|
||||||
return (None, None) if return_bbox else None
|
|
||||||
|
|
||||||
if missing_classes:
|
|
||||||
logger.info(f"Partial multi-class detection: {matching_classes} found, {missing_classes} missing")
|
|
||||||
else:
|
|
||||||
logger.info(f"Complete multi-class detection success: {detected_classes}")
|
|
||||||
else:
|
|
||||||
logger.debug("No multi-class validation - proceeding with all detections")
|
|
||||||
|
|
||||||
# ─── Pre-validate pipeline execution ────────────────────────
|
|
||||||
pipeline_valid, missing_branches = validate_pipeline_execution(node, regions_dict)
|
|
||||||
|
|
||||||
if not pipeline_valid:
|
|
||||||
logger.error(f"Pipeline execution validation FAILED - required branches {missing_branches} cannot execute")
|
|
||||||
logger.error("Aborting pipeline: no Redis actions or database records will be created")
|
|
||||||
return (None, None) if return_bbox else None
|
|
||||||
|
|
||||||
# ─── Execute actions with region information ────────────────
|
|
||||||
detection_result = {
|
|
||||||
"detections": all_detections,
|
|
||||||
"regions": regions_dict,
|
|
||||||
**(context or {})
|
|
||||||
}
|
|
||||||
|
|
||||||
# ─── Create initial database record when Car+Frontal detected ────
|
|
||||||
if node.get("db_manager") and node.get("multiClass", False):
|
|
||||||
# Only create database record if we have both Car and Frontal
|
|
||||||
has_car = "Car" in regions_dict
|
|
||||||
has_frontal = "Frontal" in regions_dict
|
|
||||||
|
|
||||||
if has_car and has_frontal:
|
|
||||||
# Generate UUID session_id since client session is None for now
|
|
||||||
import uuid as uuid_lib
|
|
||||||
from datetime import datetime
|
|
||||||
generated_session_id = str(uuid_lib.uuid4())
|
|
||||||
|
|
||||||
# Insert initial detection record
|
|
||||||
display_id = detection_result.get("display_id", "unknown")
|
|
||||||
timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
|
||||||
|
|
||||||
inserted_session_id = node["db_manager"].insert_initial_detection(
|
|
||||||
display_id=display_id,
|
|
||||||
captured_timestamp=timestamp,
|
|
||||||
session_id=generated_session_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if inserted_session_id:
|
|
||||||
# Update detection_result with the generated session_id for actions and branches
|
|
||||||
detection_result["session_id"] = inserted_session_id
|
|
||||||
                detection_result["timestamp"] = timestamp  # Update with proper timestamp
                logger.info(f"Created initial database record with session_id: {inserted_session_id}")
            else:
                logger.debug(f"Database record not created - missing required classes. Has Car: {has_car}, Has Frontal: {has_frontal}")

        execute_actions(node, frame, detection_result, regions_dict)

        # ─── Parallel branch processing ─────────────────────────────
        if node["branches"]:
            branch_results = {}

            # Filter branches that should be triggered
            active_branches = []
            for br in node["branches"]:
                trigger_classes = br.get("triggerClasses", [])
                min_conf = br.get("minConfidence", 0)

                logger.debug(f"Evaluating branch {br['modelId']}: trigger_classes={trigger_classes}, min_conf={min_conf}")

                # Check if any detected class matches branch trigger
                branch_triggered = False
                for det_class in regions_dict:
                    det_confidence = regions_dict[det_class]["confidence"]
                    logger.debug(f"  Checking detected class '{det_class}' (confidence={det_confidence:.3f}) against triggers {trigger_classes}")

                    if det_class in trigger_classes and det_confidence >= min_conf:
                        active_branches.append(br)
                        branch_triggered = True
                        logger.info(f"Branch {br['modelId']} activated by class '{det_class}' (conf={det_confidence:.3f} >= {min_conf})")
                        break

                if not branch_triggered:
                    logger.debug(f"Branch {br['modelId']} not triggered - no matching classes or insufficient confidence")

            if active_branches:
                if node.get("parallel", False) or any(br.get("parallel", False) for br in active_branches):
                    # Run branches in parallel
                    with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_branches)) as executor:
                        futures = {}

                        for br in active_branches:
                            crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None)
                            sub_frame = frame

                            logger.info(f"Starting parallel branch: {br['modelId']}, crop_class: {crop_class}")

                            if br.get("crop", False) and crop_class:
                                cropped = crop_region_by_class(frame, regions_dict, crop_class)
                                if cropped is not None:
                                    sub_frame = cv2.resize(cropped, (224, 224))
                                    logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}")
                                else:
                                    logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
                                    continue

                            future = executor.submit(run_pipeline, sub_frame, br, True, context)
                            futures[future] = br

                        # Collect results
                        for future in concurrent.futures.as_completed(futures):
                            br = futures[future]
                            try:
                                result, _ = future.result()
                                if result:
                                    branch_results[br["modelId"]] = result
                                    logger.info(f"Branch {br['modelId']} completed: {result}")
                            except Exception as e:
                                logger.error(f"Branch {br['modelId']} failed: {e}")
                else:
                    # Run branches sequentially
                    for br in active_branches:
                        crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None)
                        sub_frame = frame

                        logger.info(f"Starting sequential branch: {br['modelId']}, crop_class: {crop_class}")

                        if br.get("crop", False) and crop_class:
                            cropped = crop_region_by_class(frame, regions_dict, crop_class)
                            if cropped is not None:
                                sub_frame = cv2.resize(cropped, (224, 224))
                                logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}")
                            else:
                                logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
                                continue

                        try:
                            result, _ = run_pipeline(sub_frame, br, True, context)
                            if result:
                                branch_results[br["modelId"]] = result
                                logger.info(f"Branch {br['modelId']} completed: {result}")
                            else:
                                logger.warning(f"Branch {br['modelId']} returned no result")
                        except Exception as e:
                            logger.error(f"Error in sequential branch {br['modelId']}: {e}")
                            import traceback
                            logger.debug(f"Branch error traceback: {traceback.format_exc()}")

            # Store branch results in detection_result for parallel actions
            detection_result["branch_results"] = branch_results

        # ─── Execute Parallel Actions ───────────────────────────────
        if node.get("parallelActions") and "branch_results" in detection_result:
            execute_parallel_actions(node, frame, detection_result, regions_dict)

        # ─── Return detection result ────────────────────────────────
        primary_detection = max(all_detections, key=lambda x: x["confidence"])
        primary_bbox = primary_detection["bbox"]

        # Add branch results and session_id to primary detection for compatibility
        if "branch_results" in detection_result:
            primary_detection["branch_results"] = detection_result["branch_results"]
        if "session_id" in detection_result:
            primary_detection["session_id"] = detection_result["session_id"]

        return (primary_detection, primary_bbox) if return_bbox else primary_detection
    except Exception as e:
        logger.error(f"Error in node {node.get('modelId')}: {e}")
        traceback.print_exc()
        return (None, None) if return_bbox else None
test_protocol.py (125 lines removed)
@ -1,125 +0,0 @@
#!/usr/bin/env python3
"""
Test script to verify the worker implementation follows the protocol
"""
import json
import asyncio
import websockets
import time

async def test_protocol():
    """Test the worker protocol implementation"""
    uri = "ws://localhost:8000"

    try:
        async with websockets.connect(uri) as websocket:
            print("✓ Connected to worker")

            # Test 1: Check if we receive heartbeat (stateReport)
            print("\n1. Testing heartbeat...")
            try:
                message = await asyncio.wait_for(websocket.recv(), timeout=5)
                data = json.loads(message)
                if data.get("type") == "stateReport":
                    print("✓ Received stateReport heartbeat")
                    print(f"   - CPU Usage: {data.get('cpuUsage', 'N/A')}%")
                    print(f"   - Memory Usage: {data.get('memoryUsage', 'N/A')}%")
                    print(f"   - Camera Connections: {len(data.get('cameraConnections', []))}")
                else:
                    print(f"✗ Expected stateReport, got {data.get('type')}")
            except asyncio.TimeoutError:
                print("✗ No heartbeat received within 5 seconds")

            # Test 2: Request state
            print("\n2. Testing requestState...")
            await websocket.send(json.dumps({"type": "requestState"}))
            try:
                message = await asyncio.wait_for(websocket.recv(), timeout=5)
                data = json.loads(message)
                if data.get("type") == "stateReport":
                    print("✓ Received stateReport response")
                else:
                    print(f"✗ Expected stateReport, got {data.get('type')}")
            except asyncio.TimeoutError:
                print("✗ No response to requestState within 5 seconds")

            # Test 3: Set session ID
            print("\n3. Testing setSessionId...")
            session_message = {
                "type": "setSessionId",
                "payload": {
                    "displayIdentifier": "display-001",
                    "sessionId": 12345
                }
            }
            await websocket.send(json.dumps(session_message))
            print("✓ Sent setSessionId message")

            # Test 4: Test patchSession
            print("\n4. Testing patchSession...")
            patch_message = {
                "type": "patchSession",
                "sessionId": 12345,
                "data": {
                    "currentCar": {
                        "carModel": "Civic",
                        "carBrand": "Honda"
                    }
                }
            }
            await websocket.send(json.dumps(patch_message))

            # Wait for patchSessionResult
            try:
                message = await asyncio.wait_for(websocket.recv(), timeout=5)
                data = json.loads(message)
                if data.get("type") == "patchSessionResult":
                    print("✓ Received patchSessionResult")
                    print(f"   - Success: {data.get('payload', {}).get('success')}")
                    print(f"   - Message: {data.get('payload', {}).get('message')}")
                else:
                    print(f"✗ Expected patchSessionResult, got {data.get('type')}")
            except asyncio.TimeoutError:
                print("✗ No patchSessionResult received within 5 seconds")

            # Test 5: Test subscribe message format (without actual camera)
            print("\n5. Testing subscribe message format...")
            subscribe_message = {
                "type": "subscribe",
                "payload": {
                    "subscriptionIdentifier": "display-001;cam-001",
                    "snapshotUrl": "http://example.com/snapshot.jpg",
                    "snapshotInterval": 5000,
                    "modelUrl": "http://example.com/model.mpta",
                    "modelName": "Test Model",
                    "modelId": 101,
                    "cropX1": 100,
                    "cropY1": 200,
                    "cropX2": 300,
                    "cropY2": 400
                }
            }
            await websocket.send(json.dumps(subscribe_message))
            print("✓ Sent subscribe message (will fail without actual camera/model)")

            # Listen for a few more messages to catch any errors
            print("\n6. Listening for additional messages...")
            for i in range(3):
                try:
                    message = await asyncio.wait_for(websocket.recv(), timeout=2)
                    data = json.loads(message)
                    msg_type = data.get("type")
                    print(f"   - Received {msg_type}")
                    if msg_type == "error":
                        print(f"     Error: {data.get('error')}")
                except asyncio.TimeoutError:
                    break

            print("\n✓ Protocol test completed successfully!")

    except Exception as e:
        print(f"✗ Connection failed: {e}")
        print("Make sure the worker is running on localhost:8000")

if __name__ == "__main__":
    asyncio.run(test_protocol())
worker.md (495 lines removed)
@ -1,495 +0,0 @@
# Worker Communication Protocol

This document outlines the WebSocket-based communication protocol between the CMS backend and a detector worker. As a worker developer, your primary responsibility is to implement a WebSocket server that adheres to this protocol.

## 1. Connection

The worker must run a WebSocket server, preferably on port `8000`. The backend system, which is managed by a container orchestration service, will automatically discover and establish a WebSocket connection to your worker.

Upon a successful connection from the backend, you should begin sending `stateReport` messages as heartbeats.
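The worker's actual entry point is not part of this document, but as a minimal sketch of what this section asks for, a `websockets`-based server on port `8000` could look like the following. `send_heartbeats` and `dispatch_message` are placeholder names for helpers sketched in later sections, not part of the protocol.

```python
# Minimal sketch, not the reference implementation: accept the backend's
# connection, start heartbeating, and route incoming commands.
import asyncio
import json
import websockets

async def handle_backend(websocket):
    # Older versions of the websockets library also pass a `path` argument here.
    heartbeat = asyncio.create_task(send_heartbeats(websocket))  # sketched in section 4.1
    try:
        async for raw in websocket:
            await dispatch_message(websocket, json.loads(raw))   # sketched in section 2
    finally:
        heartbeat.cancel()

async def main():
    async with websockets.serve(handle_backend, "0.0.0.0", 8000):
        await asyncio.Future()  # run until cancelled

if __name__ == "__main__":
    asyncio.run(main())
```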
## 2. Communication Overview

Communication is bidirectional and asynchronous. All messages are JSON objects with a `type` field that indicates the message's purpose, and an optional `payload` field containing the data.

- **Worker -> Backend:** You will send messages to the backend to report status, forward detection events, or request changes to session data.
- **Backend -> Worker:** The backend will send commands to you to manage camera subscriptions.
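Because every message is routed on its `type`, a worker typically funnels incoming frames through a small dispatcher. The sketch below is illustrative only; the handler behaviour is an assumption, and `send_state_report` is the heartbeat helper sketched in section 4.1.

```python
# Illustrative dispatcher: route backend commands by their "type" field.
import logging

logger = logging.getLogger("worker")

async def dispatch_message(websocket, message: dict):
    msg_type = message.get("type")
    payload = message.get("payload", {})

    if msg_type == "requestState":
        await send_state_report(websocket)                  # reply immediately (section 5.3)
    elif msg_type in ("subscribe", "unsubscribe", "setSessionId"):
        logger.info("%s received: %s", msg_type, payload)   # real handlers update subscriptions here
    elif msg_type == "patchSessionResult":
        logger.debug("patchSession acknowledged: %s", payload)
    else:
        logger.warning("Unknown message type: %s", msg_type)
```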
## 3. Dynamic Configuration via MPTA File

To enable modularity and dynamic configuration, the backend will send you a URL to a `.mpta` file when it issues a `subscribe` command. This file is a renamed `.zip` archive that contains everything your worker needs to perform its task.

**Your worker is responsible for:**

1. Fetching this file from the provided URL.
2. Extracting its contents.
3. Interpreting the contents to configure its internal pipeline.

**The contents of the `.mpta` file are entirely up to the user who configures the model in the CMS.** This allows for maximum flexibility. For example, the archive could contain:

- AI/ML Models: Pre-trained models for libraries like TensorFlow, PyTorch, or ONNX.
- Configuration Files: A `config.json` or `pipeline.yaml` that defines a sequence of operations, specifies model paths, or sets detection thresholds.
- Scripts: Custom Python scripts for pre-processing or post-processing.
- API Integration Details: A JSON file with endpoint information and credentials for interacting with third-party detection services.

Essentially, the `.mpta` file is a self-contained package that tells your worker _how_ to process the video stream for a given subscription.
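How the pipeline is then configured from the archive is up to the worker; as a rough, non-authoritative sketch of the three responsibilities above, the fetch-and-unpack step might look like this. The directory layout and the optional `config.json` lookup are assumptions, not part of the protocol.

```python
# Hedged sketch: download a .mpta archive, unpack it, and read an optional
# config file. Everything except the URL handling is an assumption.
import json
import os
import urllib.request
import zipfile

def load_mpta(model_url: str, model_id: int, workdir: str = "models"):
    extract_dir = os.path.join(workdir, str(model_id))
    os.makedirs(extract_dir, exist_ok=True)

    archive_path = os.path.join(extract_dir, "model.mpta")
    urllib.request.urlretrieve(model_url, archive_path)        # 1. fetch

    with zipfile.ZipFile(archive_path) as archive:             # 2. extract (.mpta is a renamed .zip)
        archive.extractall(extract_dir)

    config_path = os.path.join(extract_dir, "config.json")     # 3. interpret (layout is user-defined)
    if os.path.exists(config_path):
        with open(config_path) as fh:
            return json.load(fh)
    return {"extracted_to": extract_dir}
```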
## 4. Messages from Worker to Backend

These are the messages your worker is expected to send to the backend.

### 4.1. State Report (Heartbeat)

This message is crucial for the backend to monitor your worker's health and status, including GPU usage.

- **Type:** `stateReport`
- **When to Send:** Periodically (e.g., every 2 seconds) after a connection is established.

**Payload:**

```json
{
  "type": "stateReport",
  "cpuUsage": 75.5,
  "memoryUsage": 40.2,
  "gpuUsage": 60.0,
  "gpuMemoryUsage": 25.1,
  "cameraConnections": [
    {
      "subscriptionIdentifier": "display-001;cam-001",
      "modelId": 101,
      "modelName": "General Object Detection",
      "online": true,
      "cropX1": 100,
      "cropY1": 200,
      "cropX2": 300,
      "cropY2": 400
    }
  ]
}
```

> **Note:**
>
> - `cropX1`, `cropY1`, `cropX2`, `cropY2` (optional, integer) should be included in each camera connection to indicate the crop coordinates for that subscription.
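A possible shape for the heartbeat loop is sketched below. It assumes `psutil` for CPU and memory figures and leaves the GPU fields as placeholders, since how this worker gathers GPU metrics is not specified here.

```python
# Illustrative heartbeat loop. camera_connections would be maintained by the
# subscribe/unsubscribe handlers; GPU metrics are omitted in this sketch.
import asyncio
import json
import psutil

camera_connections: list[dict] = []   # filled in by the worker's subscription logic

def build_state_report() -> dict:
    return {
        "type": "stateReport",
        "cpuUsage": psutil.cpu_percent(),
        "memoryUsage": psutil.virtual_memory().percent,
        "gpuUsage": 0.0,          # placeholder - depends on the GPU library in use
        "gpuMemoryUsage": 0.0,    # placeholder
        "cameraConnections": camera_connections,
    }

async def send_state_report(websocket):
    await websocket.send(json.dumps(build_state_report()))

async def send_heartbeats(websocket, interval: float = 2.0):
    while True:
        await send_state_report(websocket)
        await asyncio.sleep(interval)
```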
### 4.2. Image Detection

Sent when the worker detects a relevant object. The `detection` object should be flat and contain key-value pairs corresponding to the detected attributes.

- **Type:** `imageDetection`

**Payload Example:**

```json
{
  "type": "imageDetection",
  "subscriptionIdentifier": "display-001;cam-001",
  "timestamp": "2025-07-14T12:34:56.789Z",
  "data": {
    "detection": {
      "carModel": "Civic",
      "carBrand": "Honda",
      "carYear": 2023,
      "bodyType": "Sedan",
      "licensePlateText": "ABCD1234",
      "licensePlateConfidence": 0.95
    },
    "modelId": 101,
    "modelName": "US-LPR-and-Vehicle-ID"
  }
}
```
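A small helper for emitting this message could look like the sketch below; the flat `detection` dictionary is produced by your pipeline, and the timestamp format follows the example above.

```python
# Sketch: wrap a flat detection dict in the imageDetection envelope and send it.
import json
from datetime import datetime, timezone

async def send_image_detection(websocket, subscription_identifier: str,
                               detection: dict, model_id: int, model_name: str):
    timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
    message = {
        "type": "imageDetection",
        "subscriptionIdentifier": subscription_identifier,
        "timestamp": timestamp,
        "data": {
            "detection": detection,
            "modelId": model_id,
            "modelName": model_name,
        },
    }
    await websocket.send(json.dumps(message))
```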
### 4.3. Patch Session

> **Note:** Patch messages are only used when the worker can't keep up and needs to retroactively send detections. Normally, detections should be sent in real-time using `imageDetection` messages. Use `patchSession` only to update session data after the fact.

Allows the worker to request a modification to an active session's data. The `data` payload must be a partial object of the `DisplayPersistentData` structure.

- **Type:** `patchSession`

**Payload Example:**

```json
{
  "type": "patchSession",
  "sessionId": 12345,
  "data": {
    "currentCar": {
      "carModel": "Civic",
      "carBrand": "Honda",
      "licensePlateText": "ABCD1234"
    }
  }
}
```

The backend will respond with a `patchSessionResult` command.
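For the catch-up case described in the note, the worker could send a patch like this (sketch only; the session ID comes from the `setSessionId` command in section 5.5):

```python
# Sketch: request a partial update of the active session's currentCar data.
import json

async def send_patch_session(websocket, session_id: int, current_car: dict):
    await websocket.send(json.dumps({
        "type": "patchSession",
        "sessionId": session_id,
        "data": {"currentCar": current_car},
    }))
```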
#### `DisplayPersistentData` Structure

The `data` object in the `patchSession` message is merged with the existing `DisplayPersistentData` on the backend. Here is its structure:

```typescript
interface DisplayPersistentData {
  progressionStage:
    | 'welcome'
    | 'car_fueling'
    | 'car_waitpayment'
    | 'car_postpayment'
    | null;
  qrCode: string | null;
  adsPlayback: {
    playlistSlotOrder: number; // The 'order' of the current slot
    adsId: number | null;
    adsUrl: string | null;
  } | null;
  currentCar: {
    carModel?: string;
    carBrand?: string;
    carYear?: number;
    bodyType?: string;
    licensePlateText?: string;
    licensePlateType?: string;
  } | null;
  fuelPump: {
    /* FuelPumpData structure */
  } | null;
  weatherData: {
    /* WeatherResponse structure */
  } | null;
  sessionId: number | null;
}
```
#### Patching Behavior

- The patch is a **deep merge**.
- **`undefined`** values are ignored.
- **`null`** values will set the corresponding field to `null`.
- Nested objects are merged recursively.
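The merge itself is performed by the backend, so the function below is purely illustrative of the rules above: keys that are absent stay untouched, explicit `null` values overwrite, and nested objects merge recursively.

```python
# Illustration only - this logic runs on the backend, not in the worker.
def deep_merge(existing: dict, patch: dict) -> dict:
    merged = dict(existing)
    for key, value in patch.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)   # nested objects merge recursively
        else:
            merged[key] = value                            # includes explicit None (null) overwrites
    return merged

# Example: only carModel changes; licensePlateText is preserved.
state = {"currentCar": {"carBrand": "Honda", "licensePlateText": "ABCD1234"}}
patch = {"currentCar": {"carModel": "Civic"}}
assert deep_merge(state, patch)["currentCar"]["licensePlateText"] == "ABCD1234"
```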
## 5. Commands from Backend to Worker

These are the commands your worker will receive from the backend.
### 5.1. Subscribe to Camera

Instructs the worker to process a camera's RTSP stream using the configuration from the specified `.mpta` file.

- **Type:** `subscribe`

**Payload:**

```json
{
  "type": "subscribe",
  "payload": {
    "subscriptionIdentifier": "display-001;cam-002",
    "rtspUrl": "rtsp://user:pass@host:port/stream",
    "snapshotUrl": "http://go2rtc/snapshot/1",
    "snapshotInterval": 5000,
    "modelUrl": "http://storage/models/us-lpr.mpta",
    "modelName": "US-LPR-and-Vehicle-ID",
    "modelId": 102,
    "cropX1": 100,
    "cropY1": 200,
    "cropX2": 300,
    "cropY2": 400
  }
}
```

> **Note:**
>
> - `cropX1`, `cropY1`, `cropX2`, `cropY2` (optional, integer) specify the crop coordinates for the camera stream. These values are configured per display and passed in the subscription payload. If not provided, the worker should process the full frame.
>
> **Important:**
> If multiple displays are bound to the same camera, your worker must ensure that only **one stream** is opened per camera. When you receive multiple subscriptions for the same camera (with different `subscriptionIdentifier` values), you should:
>
> - Open the RTSP stream **once** for that camera if using RTSP.
> - Capture each frame or snapshot only once per cycle.
> - Reuse that single capture for every display subscription sharing the camera, processing and routing detection results separately for each display as needed.
>
> This avoids unnecessary load and bandwidth usage, and ensures consistent detection results and snapshots across all displays sharing the same camera (see the sketch below).
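One possible, non-authoritative way to implement the one-stream-per-camera rule is a small registry that reference-counts subscriptions per camera and applies the per-display crop when a frame is handed to a subscription:

```python
# Sketch of one-stream-per-camera handling. Real workers also need threading,
# reconnection, snapshot mode, and .mpta pipeline loading - omitted here.
import cv2

streams: dict[str, dict] = {}        # camera_id -> {"capture": cv2.VideoCapture, "refs": int}
subscriptions: dict[str, dict] = {}  # subscriptionIdentifier -> subscription payload

def handle_subscribe(payload: dict):
    sub_id = payload["subscriptionIdentifier"]
    camera_id = sub_id.split(";", 1)[1]          # displayIdentifier;cameraIdentifier
    subscriptions[sub_id] = payload

    entry = streams.get(camera_id)
    if entry is None:
        # Open the RTSP stream once for this camera, no matter how many displays use it.
        streams[camera_id] = {"capture": cv2.VideoCapture(payload["rtspUrl"]), "refs": 1}
    else:
        entry["refs"] += 1                       # reuse the already-open stream

def frame_for_subscription(sub_id: str, frame):
    # Apply the display-specific crop if the subscription carries crop coordinates.
    sub = subscriptions[sub_id]
    if all(k in sub for k in ("cropX1", "cropY1", "cropX2", "cropY2")):
        return frame[sub["cropY1"]:sub["cropY2"], sub["cropX1"]:sub["cropX2"]]
    return frame
```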
### 5.2. Unsubscribe from Camera

Instructs the worker to stop processing a camera's stream.

- **Type:** `unsubscribe`

**Payload:**

```json
{
  "type": "unsubscribe",
  "payload": {
    "subscriptionIdentifier": "display-001;cam-002"
  }
}
```
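Continuing the registry sketch from section 5.1, unsubscribing removes the subscription and closes the camera stream only when no other display still references it (again an implementation assumption, not part of the protocol):

```python
# Sketch: release the shared stream only when its last subscription goes away.
def handle_unsubscribe(payload: dict):
    sub_id = payload["subscriptionIdentifier"]
    subscriptions.pop(sub_id, None)

    camera_id = sub_id.split(";", 1)[1]
    entry = streams.get(camera_id)
    if entry is not None:
        entry["refs"] -= 1
        if entry["refs"] <= 0:
            entry["capture"].release()   # close the RTSP stream
            del streams[camera_id]
```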
### 5.3. Request State

Direct request for the worker's current state. Respond with a `stateReport` message.

- **Type:** `requestState`

**Payload:**

```json
{
  "type": "requestState"
}
```
### 5.4. Patch Session Result

Backend's response to a `patchSession` message.

- **Type:** `patchSessionResult`

**Payload:**

```json
{
  "type": "patchSessionResult",
  "payload": {
    "sessionId": 12345,
    "success": true,
    "message": "Session updated successfully."
  }
}
```
### 5.5. Set Session ID

Allows the backend to instruct the worker to associate a session ID with a subscription. This is useful for linking detection events to a specific session. The session ID can be `null` to indicate no active session.

- **Type:** `setSessionId`

**Payload:**

```json
{
  "type": "setSessionId",
  "payload": {
    "displayIdentifier": "display-001",
    "sessionId": 12345
  }
}
```

Or to clear the session:

```json
{
  "type": "setSessionId",
  "payload": {
    "displayIdentifier": "display-001",
    "sessionId": null
  }
}
```

> **Note:**
>
> - The worker should store the session ID for the given subscription and use it in subsequent detection or patch messages as appropriate. If `sessionId` is `null`, the worker should treat the subscription as having no active session.
## Subscription Identifier Format

The `subscriptionIdentifier` used in all messages is constructed as:

```
displayIdentifier;cameraIdentifier
```

This uniquely identifies a camera subscription for a specific display.
### Session ID Association

When the backend sends a `setSessionId` command, it will only provide the `displayIdentifier` (not the full `subscriptionIdentifier`).

**Worker Responsibility:**

- The worker must match the `displayIdentifier` to all active subscriptions for that display (i.e., all `subscriptionIdentifier` values that start with `displayIdentifier;`).
- The worker should set or clear the session ID for all matching subscriptions, as sketched below.
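A sketch of that matching rule: keep a per-subscription session map and apply the command to every subscription whose identifier starts with the display identifier followed by `;` (the storage layout is an assumption; `subscriptions` is the registry sketched in section 5.1).

```python
# Sketch: associate a session id with every subscription of the given display.
session_ids = {}   # subscriptionIdentifier -> sessionId (or None when no active session)

def handle_set_session_id(payload: dict):
    display_id = payload["displayIdentifier"]
    session_id = payload.get("sessionId")      # may be None to clear the session
    prefix = f"{display_id};"

    for sub_id in subscriptions:               # all active subscriptions for this display
        if sub_id.startswith(prefix):
            session_ids[sub_id] = session_id
```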
## 6. Example Communication Log

This section shows a typical sequence of messages between the backend and the worker. Patch messages are not included, as they are only used when the worker cannot keep up.

> **Note:** Unsubscribe is triggered when a user removes a camera or when the node is too heavily loaded and needs rebalancing.
1. **Connection Established** & **Heartbeat**

   - **Worker -> Backend**

     ```json
     {
       "type": "stateReport",
       "cpuUsage": 70.2,
       "memoryUsage": 38.1,
       "gpuUsage": 55.0,
       "gpuMemoryUsage": 20.0,
       "cameraConnections": []
     }
     ```

2. **Backend Subscribes Camera**

   - **Backend -> Worker**

     ```json
     {
       "type": "subscribe",
       "payload": {
         "subscriptionIdentifier": "display-001;entry-cam-01",
         "rtspUrl": "rtsp://192.168.1.100/stream1",
         "modelUrl": "http://storage/models/vehicle-id.mpta",
         "modelName": "Vehicle Identification",
         "modelId": 201
       }
     }
     ```

3. **Worker Acknowledges in Heartbeat**

   - **Worker -> Backend**

     ```json
     {
       "type": "stateReport",
       "cpuUsage": 72.5,
       "memoryUsage": 39.0,
       "gpuUsage": 57.0,
       "gpuMemoryUsage": 21.0,
       "cameraConnections": [
         {
           "subscriptionIdentifier": "display-001;entry-cam-01",
           "modelId": 201,
           "modelName": "Vehicle Identification",
           "online": true
         }
       ]
     }
     ```

4. **Worker Detects a Car**

   - **Worker -> Backend**

     ```json
     {
       "type": "imageDetection",
       "subscriptionIdentifier": "display-001;entry-cam-01",
       "timestamp": "2025-07-15T10:00:00.000Z",
       "data": {
         "detection": {
           "carBrand": "Honda",
           "carModel": "CR-V",
           "bodyType": "SUV",
           "licensePlateText": "GEMINI-AI",
           "licensePlateConfidence": 0.98
         },
         "modelId": 201,
         "modelName": "Vehicle Identification"
       }
     }
     ```

   - **Worker -> Backend**

     ```json
     {
       "type": "imageDetection",
       "subscriptionIdentifier": "display-001;entry-cam-01",
       "timestamp": "2025-07-15T10:00:01.000Z",
       "data": {
         "detection": {
           "carBrand": "Toyota",
           "carModel": "Corolla",
           "bodyType": "Sedan",
           "licensePlateText": "CMS-1234",
           "licensePlateConfidence": 0.97
         },
         "modelId": 201,
         "modelName": "Vehicle Identification"
       }
     }
     ```

   - **Worker -> Backend**

     ```json
     {
       "type": "imageDetection",
       "subscriptionIdentifier": "display-001;entry-cam-01",
       "timestamp": "2025-07-15T10:00:02.000Z",
       "data": {
         "detection": {
           "carBrand": "Ford",
           "carModel": "Focus",
           "bodyType": "Hatchback",
           "licensePlateText": "CMS-5678",
           "licensePlateConfidence": 0.96
         },
         "modelId": 201,
         "modelName": "Vehicle Identification"
       }
     }
     ```

5. **Backend Unsubscribes Camera**

   - **Backend -> Worker**

     ```json
     {
       "type": "unsubscribe",
       "payload": {
         "subscriptionIdentifier": "display-001;entry-cam-01"
       }
     }
     ```

6. **Worker Acknowledges Unsubscription**

   - **Worker -> Backend**

     ```json
     {
       "type": "stateReport",
       "cpuUsage": 68.0,
       "memoryUsage": 37.0,
       "gpuUsage": 50.0,
       "gpuMemoryUsage": 18.0,
       "cameraConnections": []
     }
     ```
## 7. HTTP API: Image Retrieval

In addition to the WebSocket protocol, the worker exposes an HTTP endpoint for retrieving the latest image frame from a camera.

### Endpoint

```
GET /camera/{camera_id}/image
```

- **`camera_id`**: The full `subscriptionIdentifier` (e.g., `display-001;cam-001`).

### Response

- **Success (200):** Returns the latest JPEG image from the camera stream.
  - `Content-Type: image/jpeg`
  - Binary JPEG data.
- **Error (404):** If the camera is not found or no frame is available.
  - JSON error response.
- **Error (500):** Internal server error.

### Example Request

```
GET /camera/display-001;cam-001/image
```

### Example Response

- **Headers:**
  ```
  Content-Type: image/jpeg
  ```
- **Body:** Binary JPEG image.

### Notes

- The endpoint returns the most recent frame available for the specified camera subscription.
- If multiple displays share the same camera, each subscription has its own buffer; the endpoint uses the buffer for the given `camera_id`.
- This API is useful for debugging, monitoring, or integrating with external systems that require direct image access.
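The worker's HTTP layer is not specified here, so the following is only a sketch of how such an endpoint could be served, assuming FastAPI and an in-memory buffer holding the newest frame per subscription.

```python
# Hedged sketch of the image endpoint: a FastAPI app returning the newest frame.
import cv2
from fastapi import FastAPI
from fastapi.responses import JSONResponse, Response

app = FastAPI()
latest_frames = {}   # camera_id (full subscriptionIdentifier) -> most recent frame (numpy array)

@app.get("/camera/{camera_id}/image")
def get_camera_image(camera_id: str):
    frame = latest_frames.get(camera_id)
    if frame is None:
        return JSONResponse(status_code=404, content={"error": f"No frame available for {camera_id}"})
    ok, jpeg = cv2.imencode(".jpg", frame)
    if not ok:
        return JSONResponse(status_code=500, content={"error": "Failed to encode frame"})
    return Response(content=jpeg.tobytes(), media_type="image/jpeg")
```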