Compare commits

...
Sign in to create a new pull request.

47 commits

Author SHA1 Message Date
ziesorx
007a3d48b9 Refactor: bring back original download_mpta
All checks were successful
Build Worker Base and Application Images / check-base-changes (push) Successful in 7s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / build-docker (push) Successful in 2m15s
Build Worker Base and Application Images / deploy-stack (push) Successful in 9s
2025-08-13 01:02:39 +07:00
ziesorx
4342eb219b Merge branch 'dev' into dev-pond
All checks were successful
Build Worker Base and Application Images / check-base-changes (push) Successful in 7s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / build-docker (push) Successful in 2m3s
Build Worker Base and Application Images / deploy-stack (push) Successful in 11s
2025-08-13 00:20:38 +07:00
ziesorx
c281ca6c6d Feat: pre-evaluation confidence level 2025-08-13 00:07:01 +07:00
0bcf572242 Implement subscription reconciliation logic for improved management of camera streams
All checks were successful
Build Worker Base and Application Images / check-base-changes (push) Successful in 7s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / build-docker (push) Successful in 2m10s
Build Worker Base and Application Images / deploy-stack (push) Successful in 14s
2025-08-13 00:06:27 +07:00
ziesorx
c4ab4d6cde Merge branch 'dev' into dev-pond 2025-08-12 23:33:43 +07:00
838028fcb0 Add subscription parameters to stream detection for improved change tracking
All checks were successful
Build Worker Base and Application Images / check-base-changes (push) Successful in 7s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / build-docker (push) Successful in 2m16s
Build Worker Base and Application Images / deploy-stack (push) Successful in 12s
2025-08-12 23:29:53 +07:00
ziesorx
0f8b575c90 Feat: connect with cms 2025-08-12 23:18:54 +07:00
ziesorx
9a1496f224 Refactor: change output payload to socket 2025-08-12 20:40:23 +07:00
5f9050e04e Implement code changes to enhance functionality and improve performance
All checks were successful
Build Worker Base and Application Images / check-base-changes (push) Successful in 8s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / build-docker (push) Successful in 2m10s
Build Worker Base and Application Images / deploy-stack (push) Successful in 17s
2025-08-12 16:06:37 +07:00
Pongsatorn Kanjanasantisak
aaa90faef9 Add .venv to .gitignore
All checks were successful
Build Worker Base and Application Images / check-base-changes (push) Successful in 8s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / build-docker (push) Successful in 2m10s
Build Worker Base and Application Images / deploy-stack (push) Successful in 12s
2025-08-11 23:57:02 +07:00
Pongsatorn Kanjanasantisak
e0a786a46c Merge branch 'dev' of https://git.siwatsystem.com/adsist-cms/python-detector-worker into dev 2025-08-11 23:50:14 +07:00
Pongsatorn Kanjanasantisak
975e6d03dc add feeder/ in gitignore
All checks were successful
Build Worker Base and Application Images / check-base-changes (push) Successful in 7s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / build-docker (push) Successful in 2m18s
Build Worker Base and Application Images / deploy-stack (push) Successful in 12s
2025-08-11 01:02:21 +07:00
ziesorx
416db7a33a Revert worker.md
All checks were successful
Build Worker Base and Application Images / check-base-changes (push) Successful in 7s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / build-docker (push) Successful in 2m12s
Build Worker Base and Application Images / deploy-stack (push) Successful in 8s
2025-08-10 22:47:16 +07:00
ziesorx
cfc7503a14 Update markdown
All checks were successful
Build Worker Base and Application Images / check-base-changes (push) Successful in 8s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / build-docker (push) Successful in 2m15s
Build Worker Base and Application Images / deploy-stack (push) Successful in 8s
2025-08-10 20:51:16 +07:00
1c21f417ce Merge pull request 'dev' (#3) from dev into main
Some checks failed
Build Worker Base and Application Images / check-base-changes (push) Successful in 9s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / deploy-stack (push) Has been cancelled
Build Worker Base and Application Images / build-docker (push) Has been cancelled
Reviewed-on: #3
2025-08-10 12:54:06 +00:00
10b2048e94 Merge branch 'main' into dev
Some checks failed
Build Worker Base and Application Images / check-base-changes (push) Successful in 8s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / deploy-stack (push) Has been cancelled
Build Worker Base and Application Images / build-docker (push) Has been cancelled
2025-08-10 12:53:59 +00:00
7b9eee1ad9 feat: enhance build workflow to include optional base image rebuild trigger
Some checks failed
Build Worker Base and Application Images / deploy-stack (push) Blocked by required conditions
Build Worker Base and Application Images / check-base-changes (push) Successful in 12s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / build-docker (push) Has been cancelled
2025-08-10 19:53:33 +07:00
244ec65c09 feat: update build workflow to include base image checks and build steps
Some checks failed
Build Worker Base and Application Images / check-base-changes (push) Successful in 9s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / deploy-stack (push) Has been cancelled
Build Worker Base and Application Images / build-docker (push) Has been cancelled
2025-08-10 19:52:32 +07:00
7afd7f0832 Merge pull request 'feat: update Dockerfile and requirements for ML dependencies; add base image build workflow' (#2) from dev into main
Some checks failed
Build Backend Application and Docker Image / deploy-stack (push) Has been cancelled
Build Backend Application and Docker Image / build-docker (push) Has been cancelled
Build Worker Base Image / build-base (push) Has been cancelled
Reviewed-on: #2
2025-08-10 12:50:54 +00:00
252ef468c9 feat: update Dockerfile and requirements for ML dependencies; add base image build workflow
Some checks failed
Build Backend Application and Docker Image / build-docker (push) Failing after 31s
Build Backend Application and Docker Image / deploy-stack (push) Has been skipped
2025-08-10 19:49:24 +07:00
57a51f3ba3 feat: add staging environment
All checks were successful
Build Backend Application and Docker Image / build-docker (push) Successful in 10m26s
Build Backend Application and Docker Image / deploy-stack (push) Successful in 7s
2025-08-10 18:48:10 +07:00
d35a9ae532 Add deployment steps to build workflow for Docker containers
All checks were successful
Build Backend Application and Docker Image / build-docker (push) Successful in 10m45s
Build Backend Application and Docker Image / deploy-stack (push) Successful in 3m41s
2025-08-10 18:27:39 +07:00
29b97ded2a Merge pull request 'dev' (#1) from dev into main
All checks were successful
Build Backend Application and Docker Image / build-docker (push) Successful in 10m53s
Reviewed-on: #1
2025-08-10 11:07:36 +00:00
ziesorx
c4179b3b08 Done feature 3 fully with postgresql integration 2025-08-10 17:55:02 +07:00
ziesorx
81547311d8 Add confidence check on model 2025-08-10 16:50:39 +07:00
ziesorx
8c429cc8f6 Done brand and body type detection with postgresql integration 2025-08-10 16:23:33 +07:00
ziesorx
18c62a2370 Done features 2 vehicle detect and store image to redis 2025-08-10 15:01:18 +07:00
ziesorx
a1d358aead Done setup and integration redis and postgresql 2025-08-10 13:11:38 +07:00
ziesorx
37c2e2a4d4 update requirements 2025-08-09 15:43:36 +07:00
ziesorx
7f9cc3de8d Fix: 401 and buffer 404 2025-08-06 15:16:16 +07:00
e6716bbe73 feat: add comprehensive documentation for Python Detector Worker; include project overview, architecture, core components, and configuration details
All checks were successful
Build Backend Application and Docker Image / build-docker (push) Successful in 8m47s
2025-07-16 03:24:40 +07:00
f50585f26d feat: enhance Redis action handling; add dynamic context for actions and support for expiration time
All checks were successful
Build Backend Application and Docker Image / build-docker (push) Successful in 9m3s
2025-07-15 00:35:22 +07:00
769371a1a3 feat: integrate Redis support in pipeline execution; add actions for saving images and publishing messages 2025-07-15 00:30:09 +07:00
a1f797f564 feat: add HTTP API for image retrieval from camera; implement endpoint for accessing latest image frames 2025-07-15 00:18:28 +07:00
pixchy-commits
428f7a9671 feat: enhance session management in worker communication protocol; implement session ID handling and crop frame processing 2025-07-14 23:40:19 +07:00
c7bb46e1e3 refactor documentation for worker communication protocol; improve formatting and clarify crop coordinates and session ID handling
All checks were successful
Build Backend Application and Docker Image / build-docker (push) Successful in 8m50s
2025-07-14 11:19:11 +07:00
112ca9325d refactor session ID handling in worker communication protocol; replace subscriptionIdentifier with displayIdentifier
Some checks failed
Build Backend Application and Docker Image / build-docker (push) Failing after 8s
2025-07-14 11:05:17 +07:00
700d3b3efe add subscription identifier format and session ID association details to worker communication protocol
Some checks failed
Build Backend Application and Docker Image / build-docker (push) Failing after 8s
2025-07-14 11:02:05 +07:00
3edcd286fd update session ID handling in worker communication protocol; allow null session ID to indicate no active session
Some checks failed
Build Backend Application and Docker Image / build-docker (push) Has been cancelled
2025-07-14 10:57:06 +07:00
8f32de1510 add session ID handling to worker communication protocol; allow backend to associate session IDs with subscriptions
Some checks failed
Build Backend Application and Docker Image / build-docker (push) Has been cancelled
2025-07-14 10:49:59 +07:00
3c67fa933c add crop coordinates handling in camera stream management; update logging and refactor subscription identifiers
All checks were successful
Build Backend Application and Docker Image / build-docker (push) Successful in 8m38s
2025-07-14 01:46:22 +07:00
39d49ba617 update crop coordinate fields in worker communication protocol to support rectangular cropping
All checks were successful
Build Backend Application and Docker Image / build-docker (push) Successful in 11m5s
2025-07-14 01:01:01 +07:00
8e14897a69 add crop coordinates to state report messages for camera connections
All checks were successful
Build Backend Application and Docker Image / build-docker (push) Successful in 8m43s
2025-07-13 23:59:51 +07:00
1ff6108d08 update worker communication protocol to use subscription identifiers; add crop coordinates for camera streams and clarify handling of multiple subscriptions
Some checks failed
Build Backend Application and Docker Image / build-docker (push) Has been cancelled
2025-07-13 23:58:01 +07:00
162f29ec21 remove license plate confidence from detection messages for simplified reporting
All checks were successful
Build Backend Application and Docker Image / build-docker (push) Successful in 8m49s
2025-07-13 19:43:46 +07:00
5cf1bf08cc add WebSocket communication protocol documentation for detector worker; outline connection, message types, and dynamic configuration
Some checks are pending
Build Backend Application and Docker Image / build-docker (push) Waiting to run
2025-07-13 19:39:17 +07:00
22370e2040 add REST API endpoint for image retrieval; implement error handling and response formatting
All checks were successful
Build Backend Application and Docker Image / build-docker (push) Successful in 8m48s
2025-07-06 21:27:17 +07:00
16 changed files with 5777 additions and 602 deletions

View file

@ -1,13 +1,68 @@
name: Build Backend Application and Docker Image name: Build Worker Base and Application Images
on: on:
push: push:
branches: branches:
- main - main
- dev
workflow_dispatch: workflow_dispatch:
inputs:
force_base_build:
description: 'Force base image build regardless of changes'
required: false
default: 'false'
type: boolean
jobs: jobs:
check-base-changes:
runs-on: ubuntu-latest
outputs:
base-changed: ${{ steps.changes.outputs.base-changed }}
steps:
- name: Checkout code
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Check for base changes
id: changes
run: |
if git diff HEAD^ HEAD --name-only | grep -E "(Dockerfile\.base|requirements\.base\.txt)" > /dev/null; then
echo "base-changed=true" >> $GITHUB_OUTPUT
else
echo "base-changed=false" >> $GITHUB_OUTPUT
fi
build-base:
needs: check-base-changes
if: needs.check-base-changes.outputs.base-changed == 'true' || (github.event_name == 'workflow_dispatch' && github.event.inputs.force_base_build == 'true')
runs-on: ubuntu-latest
permissions:
packages: write
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: git.siwatsystem.com
username: ${{ github.actor }}
password: ${{ secrets.RUNNER_TOKEN }}
- name: Build and push base Docker image
uses: docker/build-push-action@v4
with:
context: .
file: ./Dockerfile.base
push: true
tags: git.siwatsystem.com/adsist-cms/worker-base:latest
build-docker: build-docker:
needs: [check-base-changes, build-base]
if: always() && (needs.build-base.result == 'success' || needs.build-base.result == 'skipped')
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
packages: write packages: write
@ -31,4 +86,27 @@ jobs:
context: . context: .
file: ./Dockerfile file: ./Dockerfile
push: true push: true
tags: git.siwatsystem.com/adsist-cms/worker:latest tags: git.siwatsystem.com/adsist-cms/worker:${{ github.ref_name == 'main' && 'latest' || 'dev' }}
deploy-stack:
needs: build-docker
runs-on: adsist
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up SSH connection
run: |
mkdir -p ~/.ssh
echo "${{ secrets.DEPLOY_KEY_CMS }}" > ~/.ssh/id_rsa
chmod 600 ~/.ssh/id_rsa
ssh-keyscan -H ${{ vars.DEPLOY_HOST_CMS }} >> ~/.ssh/known_hosts
- name: Deploy stack
run: |
echo "Pulling and starting containers on server..."
if [ "${{ github.ref_name }}" = "main" ]; then
echo "Deploying production stack..."
ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.production.yml pull && docker compose -f docker-compose.production.yml up -d"
else
echo "Deploying staging stack..."
ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.staging.yml pull && docker compose -f docker-compose.staging.yml up -d"
fi

3
.gitignore vendored
View file

@ -10,3 +10,6 @@ mptas
detector_worker.log detector_worker.log
.gitignore .gitignore
no_frame_debug.log no_frame_debug.log
feeder/
.venv/

277
CLAUDE.md Normal file
View file

@ -0,0 +1,277 @@
# Python Detector Worker - CLAUDE.md
## Project Overview
This is a FastAPI-based computer vision detection worker that processes video streams from RTSP/HTTP sources and runs advanced YOLO-based machine learning pipelines for multi-class object detection and parallel classification. The system features comprehensive database integration, Redis support, and hierarchical pipeline execution designed to work within a larger CMS (Content Management System) architecture.
### Key Features
- **Multi-Class Detection**: Simultaneous detection of multiple object classes (e.g., Car + Frontal)
- **Parallel Processing**: Concurrent execution of classification branches using ThreadPoolExecutor
- **Database Integration**: Automatic PostgreSQL schema management and record updates
- **Redis Actions**: Image storage with region cropping and pub/sub messaging
- **Pipeline Synchronization**: Branch coordination with `waitForBranches` functionality
- **Dynamic Field Mapping**: Template-based field resolution for database operations
## Architecture & Technology Stack
- **Framework**: FastAPI with WebSocket support
- **ML/CV**: PyTorch, Ultralytics YOLO, OpenCV
- **Containerization**: Docker (Python 3.13-bookworm base)
- **Data Storage**: Redis integration for action handling + PostgreSQL for persistent storage
- **Database**: Automatic schema management with gas_station_1 database
- **Parallel Processing**: ThreadPoolExecutor for concurrent classification
- **Communication**: WebSocket-based real-time protocol
## Core Components
### Main Application (`app.py`)
- **FastAPI WebSocket server** for real-time communication
- **Multi-camera stream management** with shared stream optimization
- **HTTP REST endpoint** for image retrieval (`/camera/{camera_id}/image`)
- **Threading-based frame readers** for RTSP streams and HTTP snapshots
- **Model loading and inference** using MPTA (Machine Learning Pipeline Archive) format
- **Session management** with display identifier mapping
- **Resource monitoring** (CPU, memory, GPU usage via psutil)
### Pipeline System (`siwatsystem/pympta.py`)
- **MPTA file handling** - ZIP archives containing model configurations
- **Hierarchical pipeline execution** with detection → classification branching
- **Multi-class detection** - Simultaneous detection of multiple classes (Car + Frontal)
- **Parallel processing** - Concurrent classification branches with ThreadPoolExecutor
- **Redis action system** - Image saving with region cropping and message publishing
- **PostgreSQL integration** - Automatic table creation and combined updates
- **Dynamic model loading** with GPU optimization
- **Configurable trigger classes and confidence thresholds**
- **Branch synchronization** - waitForBranches coordination for database updates
### Database System (`siwatsystem/database.py`)
- **DatabaseManager class** for PostgreSQL operations
- **Automatic table creation** with gas_station_1.car_frontal_info schema
- **Combined update operations** with field mapping from branch results
- **Session management** with UUID generation
- **Error handling** and connection management
### Testing & Debugging
- **Protocol test script** (`test_protocol.py`) for WebSocket communication validation
- **Pipeline webcam utility** (`pipeline_webcam.py`) for local testing with visual output
- **RTSP streaming debug tool** (`debug/rtsp_webcam.py`) using GStreamer
## Code Conventions & Patterns
### Logging
- **Structured logging** using Python's logging module
- **File + console output** to `detector_worker.log`
- **Debug level separation** for detailed troubleshooting
- **Context-aware messages** with camera IDs and model information
### Error Handling
- **Graceful failure handling** with retry mechanisms (configurable max_retries)
- **Thread-safe operations** using locks for streams and models
- **WebSocket disconnect handling** with proper cleanup
- **Model loading validation** with detailed error reporting
### Configuration
- **JSON configuration** (`config.json`) for runtime parameters:
- `poll_interval_ms`: Frame processing interval
- `max_streams`: Concurrent stream limit
- `target_fps`: Target frame rate
- `reconnect_interval_sec`: Stream reconnection delay
- `max_retries`: Maximum retry attempts (-1 for unlimited)
### Threading Model
- **Frame reader threads** for each camera stream (RTSP/HTTP)
- **Shared stream optimization** - multiple subscriptions can reuse the same camera stream
- **Async WebSocket handling** with concurrent task management
- **Thread-safe data structures** with proper locking mechanisms
## WebSocket Protocol
### Message Types
- **subscribe**: Start camera stream with model pipeline
- **unsubscribe**: Stop camera stream processing
- **requestState**: Request current worker status
- **setSessionId**: Associate display with session identifier
- **patchSession**: Update session data
- **stateReport**: Periodic heartbeat with system metrics
- **imageDetection**: Detection results with timestamp and model info
### Subscription Format
```json
{
"type": "subscribe",
"payload": {
"subscriptionIdentifier": "display-001;cam-001",
"rtspUrl": "rtsp://...", // OR snapshotUrl
"snapshotUrl": "http://...",
"snapshotInterval": 5000,
"modelUrl": "http://...model.mpta",
"modelId": 101,
"modelName": "Vehicle Detection",
"cropX1": 100, "cropY1": 200,
"cropX2": 300, "cropY2": 400
}
}
```
## Model Pipeline (MPTA) Format
### Enhanced Structure
- **ZIP archive** containing models and configuration
- **pipeline.json** - Main configuration file with Redis + PostgreSQL settings
- **Model files** - YOLO .pt files for detection/classification
- **Multi-model support** - Detection + multiple classification models
### Advanced Pipeline Flow
1. **Multi-class detection stage** - YOLO detection of Car + Frontal simultaneously
2. **Validation stage** - Check for expected classes (flexible matching)
3. **Database initialization** - Create initial record with session_id
4. **Redis actions** - Save cropped frontal images with expiration
5. **Parallel classification** - Concurrent brand and body type classification
6. **Branch synchronization** - Wait for all classification branches to complete
7. **Database update** - Combined update with all classification results
### Enhanced Branch Configuration
```json
{
"modelId": "car_frontal_detection_v1",
"modelFile": "car_frontal_detection_v1.pt",
"multiClass": true,
"expectedClasses": ["Car", "Frontal"],
"triggerClasses": ["Car", "Frontal"],
"minConfidence": 0.8,
"actions": [
{
"type": "redis_save_image",
"region": "Frontal",
"key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
"expire_seconds": 600
}
],
"branches": [
{
"modelId": "car_brand_cls_v1",
"modelFile": "car_brand_cls_v1.pt",
"parallel": true,
"crop": true,
"cropClass": "Frontal",
"triggerClasses": ["Frontal"],
"minConfidence": 0.85
}
],
"parallelActions": [
{
"type": "postgresql_update_combined",
"table": "car_frontal_info",
"key_field": "session_id",
"waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
"fields": {
"car_brand": "{car_brand_cls_v1.brand}",
"car_body_type": "{car_bodytype_cls_v1.body_type}"
}
}
]
}
```
## Stream Management
### Shared Streams
- Multiple subscriptions can share the same camera URL
- Reference counting prevents premature stream termination
- Automatic cleanup when last subscription ends
### Frame Processing
- **Queue-based buffering** with single frame capacity (latest frame only)
- **Configurable polling interval** based on target FPS
- **Automatic reconnection** with exponential backoff
## Development & Testing
### Local Development
```bash
# Install dependencies
pip install -r requirements.txt
# Run the worker
python app.py
# Test protocol compliance
python test_protocol.py
# Test pipeline with webcam
python pipeline_webcam.py --mpta-file path/to/model.mpta --video 0
```
### Docker Deployment
```bash
# Build container
docker build -t detector-worker .
# Run with volume mounts for models
docker run -p 8000:8000 -v ./models:/app/models detector-worker
```
### Testing Commands
- **Protocol testing**: `python test_protocol.py`
- **Pipeline validation**: `python pipeline_webcam.py --mpta-file <path> --video 0`
- **RTSP debugging**: `python debug/rtsp_webcam.py`
## Dependencies
- **fastapi[standard]**: Web framework with WebSocket support
- **uvicorn**: ASGI server
- **torch, torchvision**: PyTorch for ML inference
- **ultralytics**: YOLO implementation
- **opencv-python**: Computer vision operations
- **websockets**: WebSocket client/server
- **redis**: Redis client for action execution
- **psycopg2-binary**: PostgreSQL database adapter
- **scipy**: Scientific computing for advanced algorithms
- **filterpy**: Kalman filtering and state estimation
## Security Considerations
- Model files are loaded from trusted sources only
- Redis connections use authentication when configured
- WebSocket connections handle disconnects gracefully
- Resource usage is monitored to prevent DoS
## Database Integration
### Schema Management
The system automatically creates and manages PostgreSQL tables:
```sql
CREATE TABLE IF NOT EXISTS gas_station_1.car_frontal_info (
display_id VARCHAR(255),
captured_timestamp VARCHAR(255),
session_id VARCHAR(255) PRIMARY KEY,
license_character VARCHAR(255) DEFAULT NULL,
license_type VARCHAR(255) DEFAULT 'No model available',
car_brand VARCHAR(255) DEFAULT NULL,
car_model VARCHAR(255) DEFAULT NULL,
car_body_type VARCHAR(255) DEFAULT NULL,
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
);
```
### Workflow
1. **Detection**: When both "Car" and "Frontal" are detected, create initial database record with UUID session_id
2. **Redis Storage**: Save cropped frontal image to Redis with session_id in key
3. **Parallel Processing**: Run brand and body type classification concurrently
4. **Synchronization**: Wait for all branches to complete using `waitForBranches`
5. **Database Update**: Update record with combined classification results using field mapping
### Field Mapping
Templates like `{car_brand_cls_v1.brand}` are resolved to actual classification results:
- `car_brand_cls_v1.brand` → "Honda"
- `car_bodytype_cls_v1.body_type` → "Sedan"
## Performance Optimizations
- GPU acceleration when CUDA is available
- Shared camera streams reduce resource usage
- Frame queue optimization (single latest frame)
- Model caching across subscriptions
- Trigger class filtering for faster inference
- Parallel processing with ThreadPoolExecutor for classification branches
- Multi-class detection reduces inference passes
- Region-based cropping minimizes processing overhead
- Database connection pooling and prepared statements
- Redis image storage with automatic expiration

View file

@ -1,19 +1,11 @@
# Use the official Python image from the Docker Hub # Use our pre-built base image with ML dependencies
FROM python:3.13-bookworm FROM git.siwatsystem.com/adsist-cms/worker-base:latest
# Set the working directory in the container # Copy and install application requirements (frequently changing dependencies)
WORKDIR /app
# Copy the requirements file into the container at /app
COPY requirements.txt . COPY requirements.txt .
# Update apt, install libgl1, and clear apt cache
RUN apt update && apt install -y libgl1 && rm -rf /var/lib/apt/lists/*
# Install any dependencies specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt RUN pip install --no-cache-dir -r requirements.txt
# Copy the rest of the application code into the container at /app # Copy the application code
COPY . . COPY . .
# Run the application # Run the application

15
Dockerfile.base Normal file
View file

@ -0,0 +1,15 @@
# Base image with all ML dependencies
FROM python:3.13-bookworm
# Install system dependencies
RUN apt update && apt install -y libgl1 && rm -rf /var/lib/apt/lists/*
# Copy and install base requirements (ML dependencies that rarely change)
COPY requirements.base.txt .
RUN pip install --no-cache-dir -r requirements.base.txt
# Set working directory
WORKDIR /app
# This base image will be reused for all worker builds
CMD ["python3", "-m", "fastapi", "run", "--host", "0.0.0.0", "--port", "8000"]

787
app.py
View file

@ -13,9 +13,16 @@ import requests
import asyncio import asyncio
import psutil import psutil
import zipfile import zipfile
import ssl
import urllib3
import subprocess
import tempfile
from urllib.parse import urlparse from urllib.parse import urlparse
from fastapi import FastAPI, WebSocket from requests.adapters import HTTPAdapter
from urllib3.util.ssl_ import create_urllib3_context
from fastapi import FastAPI, WebSocket, HTTPException
from fastapi.websockets import WebSocketDisconnect from fastapi.websockets import WebSocketDisconnect
from fastapi.responses import Response
from websockets.exceptions import ConnectionClosedError from websockets.exceptions import ConnectionClosedError
from ultralytics import YOLO from ultralytics import YOLO
@ -28,6 +35,14 @@ app = FastAPI()
# "models" now holds a nested dict: { camera_id: { modelId: model_tree } } # "models" now holds a nested dict: { camera_id: { modelId: model_tree } }
models: Dict[str, Dict[str, Any]] = {} models: Dict[str, Dict[str, Any]] = {}
streams: Dict[str, Dict[str, Any]] = {} streams: Dict[str, Dict[str, Any]] = {}
# Store session IDs per display
session_ids: Dict[str, int] = {}
# Track shared camera streams by camera URL
camera_streams: Dict[str, Dict[str, Any]] = {}
# Map subscriptions to their camera URL
subscription_to_camera: Dict[str, str] = {}
# Store latest frames for REST API access (separate from processing buffer)
latest_frames: Dict[str, Any] = {}
with open("config.json", "r") as f: with open("config.json", "r") as f:
config = json.load(f) config = json.load(f)
@ -102,25 +117,115 @@ def download_mpta(url: str, dest_path: str) -> str:
# Add helper to fetch snapshot image from HTTP/HTTPS URL # Add helper to fetch snapshot image from HTTP/HTTPS URL
def fetch_snapshot(url: str): def fetch_snapshot(url: str):
try: try:
response = requests.get(url, timeout=10) from requests.auth import HTTPBasicAuth, HTTPDigestAuth
# Parse URL to extract credentials
parsed = urlparse(url)
# Prepare headers - some cameras require User-Agent
headers = {
'User-Agent': 'Mozilla/5.0 (compatible; DetectorWorker/1.0)'
}
# Reconstruct URL without credentials
clean_url = f"{parsed.scheme}://{parsed.hostname}"
if parsed.port:
clean_url += f":{parsed.port}"
clean_url += parsed.path
if parsed.query:
clean_url += f"?{parsed.query}"
auth = None
if parsed.username and parsed.password:
# Try HTTP Digest authentication first (common for IP cameras)
try:
auth = HTTPDigestAuth(parsed.username, parsed.password)
response = requests.get(clean_url, auth=auth, headers=headers, timeout=10)
if response.status_code == 200:
logger.debug(f"Successfully authenticated using HTTP Digest for {clean_url}")
elif response.status_code == 401:
# If Digest fails, try Basic auth
logger.debug(f"HTTP Digest failed, trying Basic auth for {clean_url}")
auth = HTTPBasicAuth(parsed.username, parsed.password)
response = requests.get(clean_url, auth=auth, headers=headers, timeout=10)
if response.status_code == 200:
logger.debug(f"Successfully authenticated using HTTP Basic for {clean_url}")
except Exception as auth_error:
logger.debug(f"Authentication setup error: {auth_error}")
# Fallback to original URL with embedded credentials
response = requests.get(url, headers=headers, timeout=10)
else:
# No credentials in URL, make request as-is
response = requests.get(url, headers=headers, timeout=10)
if response.status_code == 200: if response.status_code == 200:
# Convert response content to numpy array # Convert response content to numpy array
nparr = np.frombuffer(response.content, np.uint8) nparr = np.frombuffer(response.content, np.uint8)
# Decode image # Decode image
frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR) frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if frame is not None: if frame is not None:
logger.debug(f"Successfully fetched snapshot from {url}, shape: {frame.shape}") logger.debug(f"Successfully fetched snapshot from {clean_url}, shape: {frame.shape}")
return frame return frame
else: else:
logger.error(f"Failed to decode image from snapshot URL: {url}") logger.error(f"Failed to decode image from snapshot URL: {clean_url}")
return None return None
else: else:
logger.error(f"Failed to fetch snapshot (status code {response.status_code}): {url}") logger.error(f"Failed to fetch snapshot (status code {response.status_code}): {clean_url}")
return None return None
except Exception as e: except Exception as e:
logger.error(f"Exception fetching snapshot from {url}: {str(e)}") logger.error(f"Exception fetching snapshot from {url}: {str(e)}")
return None return None
# Helper to get crop coordinates from stream
def get_crop_coords(stream):
return {
"cropX1": stream.get("cropX1"),
"cropY1": stream.get("cropY1"),
"cropX2": stream.get("cropX2"),
"cropY2": stream.get("cropY2")
}
####################################################
# REST API endpoint for image retrieval
####################################################
@app.get("/camera/{camera_id}/image")
async def get_camera_image(camera_id: str):
"""
Get the current frame from a camera as JPEG image
"""
try:
# URL decode the camera_id to handle encoded characters like %3B for semicolon
from urllib.parse import unquote
original_camera_id = camera_id
camera_id = unquote(camera_id)
logger.debug(f"REST API request: original='{original_camera_id}', decoded='{camera_id}'")
with streams_lock:
if camera_id not in streams:
logger.warning(f"Camera ID '{camera_id}' not found in streams. Current streams: {list(streams.keys())}")
raise HTTPException(status_code=404, detail=f"Camera {camera_id} not found or not active")
# Check if we have a cached frame for this camera
if camera_id not in latest_frames:
logger.warning(f"No cached frame available for camera '{camera_id}'.")
raise HTTPException(status_code=404, detail=f"No frame available for camera {camera_id}")
frame = latest_frames[camera_id]
logger.debug(f"Retrieved cached frame for camera '{camera_id}', frame shape: {frame.shape}")
# Encode frame as JPEG
success, buffer_img = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
if not success:
raise HTTPException(status_code=500, detail="Failed to encode image as JPEG")
# Return image as binary response
return Response(content=buffer_img.tobytes(), media_type="image/jpeg")
except HTTPException:
raise
except Exception as e:
logger.error(f"Error retrieving image for camera {camera_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
#################################################### ####################################################
# Detection and frame processing functions # Detection and frame processing functions
#################################################### ####################################################
@ -131,73 +236,118 @@ async def detect(websocket: WebSocket):
async def handle_detection(camera_id, stream, frame, websocket, model_tree, persistent_data): async def handle_detection(camera_id, stream, frame, websocket, model_tree, persistent_data):
try: try:
# Apply crop if specified
cropped_frame = frame
if all(coord is not None for coord in [stream.get("cropX1"), stream.get("cropY1"), stream.get("cropX2"), stream.get("cropY2")]):
cropX1, cropY1, cropX2, cropY2 = stream["cropX1"], stream["cropY1"], stream["cropX2"], stream["cropY2"]
cropped_frame = frame[cropY1:cropY2, cropX1:cropX2]
logger.debug(f"Applied crop coordinates ({cropX1}, {cropY1}, {cropX2}, {cropY2}) to frame for camera {camera_id}")
logger.debug(f"Processing frame for camera {camera_id} with model {stream['modelId']}") logger.debug(f"Processing frame for camera {camera_id} with model {stream['modelId']}")
start_time = time.time() start_time = time.time()
detection_result = run_pipeline(frame, model_tree)
# Extract display identifier for pipeline context
subscription_parts = stream["subscriptionIdentifier"].split(';')
display_identifier = subscription_parts[0] if subscription_parts else None
# Create context for pipeline execution (session_id will be generated by pipeline)
pipeline_context = {
"camera_id": camera_id,
"display_id": display_identifier
}
detection_result = run_pipeline(cropped_frame, model_tree, context=pipeline_context)
process_time = (time.time() - start_time) * 1000 process_time = (time.time() - start_time) * 1000
logger.debug(f"Detection for camera {camera_id} completed in {process_time:.2f}ms") logger.debug(f"Detection for camera {camera_id} completed in {process_time:.2f}ms")
# Log the raw detection result for debugging # Log the raw detection result for debugging
logger.debug(f"Raw detection result for camera {camera_id}:\n{json.dumps(detection_result, indent=2, default=str)}") logger.debug(f"Raw detection result for camera {camera_id}:\n{json.dumps(detection_result, indent=2, default=str)}")
# Direct class result (no detections/classifications structure) # Extract session_id from pipeline result (generated during database record creation)
if detection_result and isinstance(detection_result, dict) and "class" in detection_result and "confidence" in detection_result: session_id = None
highest_confidence_detection = { if detection_result and isinstance(detection_result, dict):
"class": detection_result.get("class", "none"), # Check if pipeline generated a session_id (happens when Car+Frontal detected together)
"confidence": detection_result.get("confidence", 1.0), if "session_id" in detection_result:
"box": [0, 0, 0, 0] # Empty bounding box for classifications session_id = detection_result["session_id"]
} logger.debug(f"Extracted session_id from pipeline result: {session_id}")
# Handle case when no detections found or result is empty
elif not detection_result or not detection_result.get("detections"):
# Check if we have classification results
if detection_result and detection_result.get("classifications"):
# Get the highest confidence classification
classifications = detection_result.get("classifications", [])
highest_confidence_class = max(classifications, key=lambda x: x.get("confidence", 0)) if classifications else None
if highest_confidence_class: # Process detection result - run_pipeline returns the primary detection directly
highest_confidence_detection = { if detection_result and isinstance(detection_result, dict) and "class" in detection_result:
"class": highest_confidence_class.get("class", "none"), highest_confidence_detection = detection_result
"confidence": highest_confidence_class.get("confidence", 1.0),
"box": [0, 0, 0, 0] # Empty bounding box for classifications
}
else:
highest_confidence_detection = {
"class": "none",
"confidence": 1.0,
"box": [0, 0, 0, 0]
}
else:
highest_confidence_detection = {
"class": "none",
"confidence": 1.0,
"box": [0, 0, 0, 0]
}
else: else:
# Find detection with highest confidence # No detection found
detections = detection_result.get("detections", []) highest_confidence_detection = {
highest_confidence_detection = max(detections, key=lambda x: x.get("confidence", 0)) if detections else {
"class": "none", "class": "none",
"confidence": 1.0, "confidence": 1.0,
"box": [0, 0, 0, 0] "bbox": [0, 0, 0, 0],
"branch_results": {}
} }
# Convert detection format to match backend expectations exactly as in worker.md section 4.2
detection_dict = {
"carModel": None,
"carBrand": None,
"carYear": None,
"bodyType": None,
"licensePlateText": None,
"licensePlateConfidence": None
}
# Extract and process branch results from parallel classification
branch_results = highest_confidence_detection.get("branch_results", {})
if branch_results:
logger.debug(f"Processing branch results: {branch_results}")
# Transform branch results into backend-expected detection attributes
for branch_id, branch_data in branch_results.items():
if isinstance(branch_data, dict):
logger.debug(f"Processing branch {branch_id}: {branch_data}")
# Map common classification fields to backend-expected names
if "brand" in branch_data:
detection_dict["carBrand"] = branch_data["brand"]
if "body_type" in branch_data:
detection_dict["bodyType"] = branch_data["body_type"]
if "class" in branch_data:
class_name = branch_data["class"]
# Map based on branch/model type
if "brand" in branch_id.lower():
detection_dict["carBrand"] = class_name
elif "bodytype" in branch_id.lower() or "body" in branch_id.lower():
detection_dict["bodyType"] = class_name
logger.info(f"Detection payload after branch processing: {detection_dict}")
else:
logger.debug("No branch results found in detection result")
detection_data = { detection_data = {
"type": "imageDetection", "type": "imageDetection",
"cameraIdentifier": camera_id, "subscriptionIdentifier": stream["subscriptionIdentifier"],
"timestamp": time.time(), "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.%fZ", time.gmtime()),
"data": { "data": {
"detection": highest_confidence_detection, # Send only the highest confidence detection "detection": detection_dict,
"modelId": stream["modelId"], "modelId": stream["modelId"],
"modelName": stream["modelName"] "modelName": stream["modelName"]
} }
} }
if highest_confidence_detection["class"] != "none": # Add session ID if available (generated by pipeline when Car+Frontal detected)
logger.info(f"Camera {camera_id}: Detected {highest_confidence_detection['class']} with confidence {highest_confidence_detection['confidence']:.2f} using model {stream['modelName']}") if session_id is not None:
detection_data["sessionId"] = session_id
logger.debug(f"Added session_id to WebSocket response: {session_id}")
if highest_confidence_detection.get("class") != "none":
confidence = highest_confidence_detection.get("confidence", 0.0)
logger.info(f"Camera {camera_id}: Detected {highest_confidence_detection['class']} with confidence {confidence:.2f} using model {stream['modelName']}")
# Log session ID if available
if session_id:
logger.debug(f"Detection associated with session ID: {session_id}")
await websocket.send_json(detection_data) await websocket.send_json(detection_data)
logger.debug(f"Sent detection data to client for camera {camera_id}:\n{json.dumps(detection_data, indent=2)}") logger.debug(f"Sent detection data to client for camera {camera_id}")
logger.debug(f"Sent this detection data: {detection_data}")
return persistent_data return persistent_data
except Exception as e: except Exception as e:
logger.error(f"Error in handle_detection for camera {camera_id}: {str(e)}", exc_info=True) logger.error(f"Error in handle_detection for camera {camera_id}: {str(e)}", exc_info=True)
@ -264,12 +414,11 @@ async def detect(websocket: WebSocket):
if not buffer.empty(): if not buffer.empty():
try: try:
buffer.get_nowait() buffer.get_nowait()
logger.debug(f"Removed old frame from buffer for camera {camera_id}") logger.debug(f"[frame_reader] Removed old frame from buffer for camera {camera_id}")
except queue.Empty: except queue.Empty:
pass pass
buffer.put(frame) buffer.put(frame)
logger.debug(f"Added new frame to buffer for camera {camera_id}") logger.debug(f"[frame_reader] Added new frame to buffer for camera {camera_id}. Buffer size: {buffer.qsize()}")
# Short sleep to avoid CPU overuse # Short sleep to avoid CPU overuse
time.sleep(0.01) time.sleep(0.01)
@ -340,12 +489,11 @@ async def detect(websocket: WebSocket):
if not buffer.empty(): if not buffer.empty():
try: try:
buffer.get_nowait() buffer.get_nowait()
logger.debug(f"Removed old snapshot from buffer for camera {camera_id}") logger.debug(f"[snapshot_reader] Removed old snapshot from buffer for camera {camera_id}")
except queue.Empty: except queue.Empty:
pass pass
buffer.put(frame) buffer.put(frame)
logger.debug(f"Added new snapshot to buffer for camera {camera_id}") logger.debug(f"[snapshot_reader] Added new snapshot to buffer for camera {camera_id}. Buffer size: {buffer.qsize()}")
# Wait for the specified interval # Wait for the specified interval
elapsed = time.time() - start_time elapsed = time.time() - start_time
@ -365,6 +513,199 @@ async def detect(websocket: WebSocket):
finally: finally:
logger.info(f"Snapshot reader thread for camera {camera_id} is exiting") logger.info(f"Snapshot reader thread for camera {camera_id} is exiting")
async def reconcile_subscriptions(desired_subscriptions, websocket):
"""
Declarative reconciliation: Compare desired vs current subscriptions and make changes
"""
logger.info(f"Reconciling subscriptions: {len(desired_subscriptions)} desired")
with streams_lock:
# Get current subscriptions
current_subscription_ids = set(streams.keys())
desired_subscription_ids = set(sub["subscriptionIdentifier"] for sub in desired_subscriptions)
# Find what to add and remove
to_add = desired_subscription_ids - current_subscription_ids
to_remove = current_subscription_ids - desired_subscription_ids
to_check_for_changes = current_subscription_ids & desired_subscription_ids
logger.info(f"Reconciliation: {len(to_add)} to add, {len(to_remove)} to remove, {len(to_check_for_changes)} to check for changes")
# Remove subscriptions that are no longer wanted
for subscription_id in to_remove:
await unsubscribe_internal(subscription_id)
# Check existing subscriptions for parameter changes
for subscription_id in to_check_for_changes:
desired_sub = next(sub for sub in desired_subscriptions if sub["subscriptionIdentifier"] == subscription_id)
current_stream = streams[subscription_id]
# Check if parameters changed
if has_subscription_changed(desired_sub, current_stream):
logger.info(f"Parameters changed for {subscription_id}, resubscribing")
await unsubscribe_internal(subscription_id)
await subscribe_internal(desired_sub, websocket)
# Add new subscriptions
for subscription_id in to_add:
desired_sub = next(sub for sub in desired_subscriptions if sub["subscriptionIdentifier"] == subscription_id)
await subscribe_internal(desired_sub, websocket)
def has_subscription_changed(desired_sub, current_stream):
"""Check if subscription parameters have changed"""
return (
desired_sub.get("rtspUrl") != current_stream.get("rtsp_url") or
desired_sub.get("snapshotUrl") != current_stream.get("snapshot_url") or
desired_sub.get("snapshotInterval") != current_stream.get("snapshot_interval") or
desired_sub.get("cropX1") != current_stream.get("cropX1") or
desired_sub.get("cropY1") != current_stream.get("cropY1") or
desired_sub.get("cropX2") != current_stream.get("cropX2") or
desired_sub.get("cropY2") != current_stream.get("cropY2") or
desired_sub.get("modelId") != current_stream.get("modelId") or
desired_sub.get("modelName") != current_stream.get("modelName")
)
async def subscribe_internal(subscription, websocket):
"""Internal subscription logic extracted from original subscribe handler"""
subscriptionIdentifier = subscription.get("subscriptionIdentifier")
rtsp_url = subscription.get("rtspUrl")
snapshot_url = subscription.get("snapshotUrl")
snapshot_interval = subscription.get("snapshotInterval")
model_url = subscription.get("modelUrl")
modelId = subscription.get("modelId")
modelName = subscription.get("modelName")
cropX1 = subscription.get("cropX1")
cropY1 = subscription.get("cropY1")
cropX2 = subscription.get("cropX2")
cropY2 = subscription.get("cropY2")
# Extract camera_id from subscriptionIdentifier
parts = subscriptionIdentifier.split(';')
if len(parts) != 2:
logger.error(f"Invalid subscriptionIdentifier format: {subscriptionIdentifier}")
return
display_identifier, camera_identifier = parts
camera_id = subscriptionIdentifier
# Load model if needed
if model_url:
with models_lock:
if (camera_id not in models) or (modelId not in models[camera_id]):
logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}")
extraction_dir = os.path.join("models", camera_identifier, str(modelId))
os.makedirs(extraction_dir, exist_ok=True)
# Handle model loading (same as original)
parsed = urlparse(model_url)
if parsed.scheme in ("http", "https"):
filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta"
local_mpta = os.path.join(extraction_dir, filename)
local_path = download_mpta(model_url, local_mpta)
if not local_path:
logger.error(f"Failed to download model from {model_url}")
return
model_tree = load_pipeline_from_zip(local_path, extraction_dir)
else:
if not os.path.exists(model_url):
logger.error(f"Model file not found: {model_url}")
return
model_tree = load_pipeline_from_zip(model_url, extraction_dir)
if model_tree is None:
logger.error(f"Failed to load model {modelId}")
return
if camera_id not in models:
models[camera_id] = {}
models[camera_id][modelId] = model_tree
# Create stream (same logic as original)
if camera_id and (rtsp_url or snapshot_url) and len(streams) < max_streams:
camera_url = snapshot_url if snapshot_url else rtsp_url
# Check if we already have a stream for this camera URL
shared_stream = camera_streams.get(camera_url)
if shared_stream:
# Reuse existing stream
buffer = shared_stream["buffer"]
stop_event = shared_stream["stop_event"]
thread = shared_stream["thread"]
mode = shared_stream["mode"]
shared_stream["ref_count"] = shared_stream.get("ref_count", 0) + 1
else:
# Create new stream
buffer = queue.Queue(maxsize=1)
stop_event = threading.Event()
if snapshot_url and snapshot_interval:
thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event))
thread.daemon = True
thread.start()
mode = "snapshot"
shared_stream = {
"buffer": buffer, "thread": thread, "stop_event": stop_event,
"mode": mode, "url": snapshot_url, "snapshot_interval": snapshot_interval, "ref_count": 1
}
camera_streams[camera_url] = shared_stream
elif rtsp_url:
cap = cv2.VideoCapture(rtsp_url)
if not cap.isOpened():
logger.error(f"Failed to open RTSP stream for camera {camera_id}")
return
thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
thread.daemon = True
thread.start()
mode = "rtsp"
shared_stream = {
"buffer": buffer, "thread": thread, "stop_event": stop_event,
"mode": mode, "url": rtsp_url, "cap": cap, "ref_count": 1
}
camera_streams[camera_url] = shared_stream
else:
logger.error(f"No valid URL provided for camera {camera_id}")
return
# Create stream info
stream_info = {
"buffer": buffer, "thread": thread, "stop_event": stop_event,
"modelId": modelId, "modelName": modelName, "subscriptionIdentifier": subscriptionIdentifier,
"cropX1": cropX1, "cropY1": cropY1, "cropX2": cropX2, "cropY2": cropY2,
"mode": mode, "camera_url": camera_url, "modelUrl": model_url
}
if mode == "snapshot":
stream_info["snapshot_url"] = snapshot_url
stream_info["snapshot_interval"] = snapshot_interval
elif mode == "rtsp":
stream_info["rtsp_url"] = rtsp_url
stream_info["cap"] = shared_stream["cap"]
streams[camera_id] = stream_info
subscription_to_camera[camera_id] = camera_url
logger.info(f"Subscribed to camera {camera_id}")
async def unsubscribe_internal(subscription_id):
"""Internal unsubscription logic"""
if subscription_id in streams:
stream = streams.pop(subscription_id)
camera_url = subscription_to_camera.pop(subscription_id, None)
if camera_url and camera_url in camera_streams:
shared_stream = camera_streams[camera_url]
shared_stream["ref_count"] -= 1
if shared_stream["ref_count"] <= 0:
shared_stream["stop_event"].set()
shared_stream["thread"].join()
if "cap" in shared_stream:
shared_stream["cap"].release()
del camera_streams[camera_url]
latest_frames.pop(subscription_id, None)
logger.info(f"Unsubscribed from camera {subscription_id}")
async def process_streams(): async def process_streams():
logger.info("Started processing streams") logger.info("Started processing streams")
try: try:
@ -386,6 +727,10 @@ async def detect(websocket: WebSocket):
logger.debug(f"Got frame from buffer for camera {camera_id}") logger.debug(f"Got frame from buffer for camera {camera_id}")
frame = buffer.get() frame = buffer.get()
# Cache the frame for REST API access
latest_frames[camera_id] = frame.copy()
logger.debug(f"Cached frame for REST API access for camera {camera_id}")
with models_lock: with models_lock:
model_tree = models.get(camera_id, {}).get(stream["modelId"]) model_tree = models.get(camera_id, {}).get(stream["modelId"])
if not model_tree: if not model_tree:
@ -416,18 +761,23 @@ async def detect(websocket: WebSocket):
cpu_usage = psutil.cpu_percent() cpu_usage = psutil.cpu_percent()
memory_usage = psutil.virtual_memory().percent memory_usage = psutil.virtual_memory().percent
if torch.cuda.is_available(): if torch.cuda.is_available():
gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) # MB gpu_usage = torch.cuda.utilization() if hasattr(torch.cuda, 'utilization') else None
gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) # MB gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)
else: else:
gpu_usage = None gpu_usage = None
gpu_memory_usage = None gpu_memory_usage = None
camera_connections = [ camera_connections = [
{ {
"cameraIdentifier": camera_id, "subscriptionIdentifier": stream["subscriptionIdentifier"],
"modelId": stream["modelId"], "modelId": stream["modelId"],
"modelName": stream["modelName"], "modelName": stream["modelName"],
"online": True "online": True,
# Include all subscription parameters for proper change detection
"rtspUrl": stream.get("rtsp_url"),
"snapshotUrl": stream.get("snapshot_url"),
"snapshotInterval": stream.get("snapshot_interval"),
**{k: v for k, v in get_crop_coords(stream).items() if v is not None}
} }
for camera_id, stream in streams.items() for camera_id, stream in streams.items()
] ]
@ -455,58 +805,87 @@ async def detect(websocket: WebSocket):
data = json.loads(msg) data = json.loads(msg)
msg_type = data.get("type") msg_type = data.get("type")
if msg_type == "subscribe": if msg_type == "setSubscriptionList":
payload = data.get("payload", {}) # Declarative approach: Backend sends list of subscriptions this worker should have
camera_id = payload.get("cameraIdentifier") desired_subscriptions = data.get("subscriptions", [])
rtsp_url = payload.get("rtspUrl") logger.info(f"Received subscription list with {len(desired_subscriptions)} subscriptions")
snapshot_url = payload.get("snapshotUrl")
snapshot_interval = payload.get("snapshotInterval") # in milliseconds
model_url = payload.get("modelUrl") # may be remote or local
modelId = payload.get("modelId")
modelName = payload.get("modelName")
await reconcile_subscriptions(desired_subscriptions, websocket)
elif msg_type == "subscribe":
# Legacy support - convert single subscription to list
payload = data.get("payload", {})
await reconcile_subscriptions([payload], websocket)
elif msg_type == "unsubscribe":
# Legacy support - remove subscription
payload = data.get("payload", {})
subscriptionIdentifier = payload.get("subscriptionIdentifier")
# Remove from current subscriptions and reconcile
current_subs = []
with streams_lock:
for camera_id, stream in streams.items():
if stream["subscriptionIdentifier"] != subscriptionIdentifier:
# Convert stream back to subscription format
current_subs.append({
"subscriptionIdentifier": stream["subscriptionIdentifier"],
"rtspUrl": stream.get("rtsp_url"),
"snapshotUrl": stream.get("snapshot_url"),
"snapshotInterval": stream.get("snapshot_interval"),
"modelId": stream["modelId"],
"modelName": stream["modelName"],
"modelUrl": stream.get("modelUrl", ""),
"cropX1": stream.get("cropX1"),
"cropY1": stream.get("cropY1"),
"cropX2": stream.get("cropX2"),
"cropY2": stream.get("cropY2")
})
await reconcile_subscriptions(current_subs, websocket)
elif msg_type == "old_subscribe_logic_removed":
if model_url: if model_url:
with models_lock: with models_lock:
if (camera_id not in models) or (modelId not in models[camera_id]): if (camera_id not in models) or (modelId not in models[camera_id]):
logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}") logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}")
extraction_dir = os.path.join("models", camera_id, str(modelId)) extraction_dir = os.path.join("models", camera_identifier, str(modelId))
os.makedirs(extraction_dir, exist_ok=True) os.makedirs(extraction_dir, exist_ok=True)
# If model_url is remote, download it first. # If model_url is remote, download it first.
parsed = urlparse(model_url) parsed = urlparse(model_url)
if parsed.scheme in ("http", "https"): if parsed.scheme in ("http", "https"):
logger.info(f"Downloading remote model from {model_url}") logger.info(f"Downloading remote .mpta file from {model_url}")
local_mpta = os.path.join(extraction_dir, os.path.basename(parsed.path)) filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta"
local_mpta = os.path.join(extraction_dir, filename)
logger.debug(f"Download destination: {local_mpta}") logger.debug(f"Download destination: {local_mpta}")
local_path = download_mpta(model_url, local_mpta) local_path = download_mpta(model_url, local_mpta)
if not local_path: if not local_path:
logger.error(f"Failed to download the remote mpta file from {model_url}") logger.error(f"Failed to download the remote .mpta file from {model_url}")
error_response = { error_response = {
"type": "error", "type": "error",
"cameraIdentifier": camera_id, "subscriptionIdentifier": subscriptionIdentifier,
"error": f"Failed to download model from {model_url}" "error": f"Failed to download model from {model_url}"
} }
await websocket.send_json(error_response) await websocket.send_json(error_response)
continue continue
model_tree = load_pipeline_from_zip(local_path, extraction_dir) model_tree = load_pipeline_from_zip(local_path, extraction_dir)
else: else:
logger.info(f"Loading local model from {model_url}") logger.info(f"Loading local .mpta file from {model_url}")
# Check if file exists before attempting to load # Check if file exists before attempting to load
if not os.path.exists(model_url): if not os.path.exists(model_url):
logger.error(f"Local model file not found: {model_url}") logger.error(f"Local .mpta file not found: {model_url}")
logger.debug(f"Current working directory: {os.getcwd()}") logger.debug(f"Current working directory: {os.getcwd()}")
error_response = { error_response = {
"type": "error", "type": "error",
"cameraIdentifier": camera_id, "subscriptionIdentifier": subscriptionIdentifier,
"error": f"Model file not found: {model_url}" "error": f"Model file not found: {model_url}"
} }
await websocket.send_json(error_response) await websocket.send_json(error_response)
continue continue
model_tree = load_pipeline_from_zip(model_url, extraction_dir) model_tree = load_pipeline_from_zip(model_url, extraction_dir)
if model_tree is None: if model_tree is None:
logger.error(f"Failed to load model {modelId} from mpta file for camera {camera_id}") logger.error(f"Failed to load model {modelId} from .mpta file for camera {camera_id}")
error_response = { error_response = {
"type": "error", "type": "error",
"cameraIdentifier": camera_id, "subscriptionIdentifier": subscriptionIdentifier,
"error": f"Failed to load model {modelId}" "error": f"Failed to load model {modelId}"
} }
await websocket.send_json(error_response) await websocket.send_json(error_response)
@ -515,95 +894,139 @@ async def detect(websocket: WebSocket):
models[camera_id] = {} models[camera_id] = {}
models[camera_id][modelId] = model_tree models[camera_id][modelId] = model_tree
logger.info(f"Successfully loaded model {modelId} for camera {camera_id}") logger.info(f"Successfully loaded model {modelId} for camera {camera_id}")
success_response = { logger.debug(f"Model extraction directory: {extraction_dir}")
"type": "modelLoaded",
"cameraIdentifier": camera_id,
"modelId": modelId
}
await websocket.send_json(success_response)
if camera_id and (rtsp_url or snapshot_url): if camera_id and (rtsp_url or snapshot_url):
with streams_lock: with streams_lock:
if camera_id not in streams and len(streams) < max_streams: # Determine camera URL for shared stream management
buffer = queue.Queue(maxsize=1) camera_url = snapshot_url if snapshot_url else rtsp_url
stop_event = threading.Event()
# Choose between snapshot and RTSP based on availability if camera_id not in streams and len(streams) < max_streams:
if snapshot_url and snapshot_interval: # Check if we already have a stream for this camera URL
logger.info(f"Using snapshot mode for camera {camera_id}: {snapshot_url}") shared_stream = camera_streams.get(camera_url)
thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event))
thread.daemon = True if shared_stream:
thread.start() # Reuse existing stream
streams[camera_id] = { logger.info(f"Reusing existing stream for camera URL: {camera_url}")
"buffer": buffer, buffer = shared_stream["buffer"]
"thread": thread, stop_event = shared_stream["stop_event"]
"snapshot_url": snapshot_url, thread = shared_stream["thread"]
"snapshot_interval": snapshot_interval, mode = shared_stream["mode"]
"stop_event": stop_event,
"modelId": modelId, # Increment reference count
"modelName": modelName, shared_stream["ref_count"] = shared_stream.get("ref_count", 0) + 1
"mode": "snapshot"
}
logger.info(f"Subscribed to camera {camera_id} (snapshot mode) with modelId {modelId}, modelName {modelName}, URL {snapshot_url}, interval {snapshot_interval}ms")
elif rtsp_url:
logger.info(f"Using RTSP mode for camera {camera_id}: {rtsp_url}")
cap = cv2.VideoCapture(rtsp_url)
if not cap.isOpened():
logger.error(f"Failed to open RTSP stream for camera {camera_id}")
continue
thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
thread.daemon = True
thread.start()
streams[camera_id] = {
"cap": cap,
"buffer": buffer,
"thread": thread,
"rtsp_url": rtsp_url,
"stop_event": stop_event,
"modelId": modelId,
"modelName": modelName,
"mode": "rtsp"
}
logger.info(f"Subscribed to camera {camera_id} (RTSP mode) with modelId {modelId}, modelName {modelName}, URL {rtsp_url}")
else: else:
logger.error(f"No valid URL provided for camera {camera_id}") # Create new stream
continue buffer = queue.Queue(maxsize=1)
stop_event = threading.Event()
if snapshot_url and snapshot_interval:
logger.info(f"Creating new snapshot stream for camera {camera_id}: {snapshot_url}")
thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event))
thread.daemon = True
thread.start()
mode = "snapshot"
# Store shared stream info
shared_stream = {
"buffer": buffer,
"thread": thread,
"stop_event": stop_event,
"mode": mode,
"url": snapshot_url,
"snapshot_interval": snapshot_interval,
"ref_count": 1
}
camera_streams[camera_url] = shared_stream
elif rtsp_url:
logger.info(f"Creating new RTSP stream for camera {camera_id}: {rtsp_url}")
cap = cv2.VideoCapture(rtsp_url)
if not cap.isOpened():
logger.error(f"Failed to open RTSP stream for camera {camera_id}")
continue
thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
thread.daemon = True
thread.start()
mode = "rtsp"
# Store shared stream info
shared_stream = {
"buffer": buffer,
"thread": thread,
"stop_event": stop_event,
"mode": mode,
"url": rtsp_url,
"cap": cap,
"ref_count": 1
}
camera_streams[camera_url] = shared_stream
else:
logger.error(f"No valid URL provided for camera {camera_id}")
continue
# Create stream info for this subscription
stream_info = {
"buffer": buffer,
"thread": thread,
"stop_event": stop_event,
"modelId": modelId,
"modelName": modelName,
"subscriptionIdentifier": subscriptionIdentifier,
"cropX1": cropX1,
"cropY1": cropY1,
"cropX2": cropX2,
"cropY2": cropY2,
"mode": mode,
"camera_url": camera_url
}
if mode == "snapshot":
stream_info["snapshot_url"] = snapshot_url
stream_info["snapshot_interval"] = snapshot_interval
elif mode == "rtsp":
stream_info["rtsp_url"] = rtsp_url
stream_info["cap"] = shared_stream["cap"]
streams[camera_id] = stream_info
subscription_to_camera[camera_id] = camera_url
elif camera_id and camera_id in streams: elif camera_id and camera_id in streams:
# If already subscribed, unsubscribe first # If already subscribed, unsubscribe first
stream = streams.pop(camera_id) logger.info(f"Resubscribing to camera {camera_id}")
stream["stop_event"].set() # Note: Keep models in memory for reuse across subscriptions
stream["thread"].join()
if "cap" in stream:
stream["cap"].release()
logger.info(f"Unsubscribed from camera {camera_id} for resubscription")
with models_lock:
if camera_id in models and modelId in models[camera_id]:
del models[camera_id][modelId]
if not models[camera_id]:
del models[camera_id]
elif msg_type == "unsubscribe": elif msg_type == "unsubscribe":
payload = data.get("payload", {}) payload = data.get("payload", {})
camera_id = payload.get("cameraIdentifier") subscriptionIdentifier = payload.get("subscriptionIdentifier")
logger.debug(f"Unsubscribing from camera {camera_id}") camera_id = subscriptionIdentifier
with streams_lock: with streams_lock:
if camera_id and camera_id in streams: if camera_id and camera_id in streams:
stream = streams.pop(camera_id) stream = streams.pop(camera_id)
stream["stop_event"].set() camera_url = subscription_to_camera.pop(camera_id, None)
stream["thread"].join()
# Only release cap if it exists (RTSP mode) if camera_url and camera_url in camera_streams:
if "cap" in stream: shared_stream = camera_streams[camera_url]
stream["cap"].release() shared_stream["ref_count"] -= 1
logger.info(f"Released RTSP capture for camera {camera_id}")
else: # If no more references, stop the shared stream
logger.info(f"Released snapshot reader for camera {camera_id}") if shared_stream["ref_count"] <= 0:
logger.info(f"Stopping shared stream for camera URL: {camera_url}")
shared_stream["stop_event"].set()
shared_stream["thread"].join()
if "cap" in shared_stream:
shared_stream["cap"].release()
del camera_streams[camera_url]
else:
logger.info(f"Shared stream for {camera_url} still has {shared_stream['ref_count']} references")
# Clean up cached frame
latest_frames.pop(camera_id, None)
logger.info(f"Unsubscribed from camera {camera_id}") logger.info(f"Unsubscribed from camera {camera_id}")
with models_lock: # Note: Keep models in memory for potential reuse
if camera_id in models:
del models[camera_id]
elif msg_type == "requestState": elif msg_type == "requestState":
cpu_usage = psutil.cpu_percent() cpu_usage = psutil.cpu_percent()
memory_usage = psutil.virtual_memory().percent memory_usage = psutil.virtual_memory().percent
if torch.cuda.is_available(): if torch.cuda.is_available():
gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) gpu_usage = torch.cuda.utilization() if hasattr(torch.cuda, 'utilization') else None
gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)
else: else:
gpu_usage = None gpu_usage = None
@ -611,10 +1034,15 @@ async def detect(websocket: WebSocket):
camera_connections = [ camera_connections = [
{ {
"cameraIdentifier": camera_id, "subscriptionIdentifier": stream["subscriptionIdentifier"],
"modelId": stream["modelId"], "modelId": stream["modelId"],
"modelName": stream["modelName"], "modelName": stream["modelName"],
"online": True "online": True,
# Include all subscription parameters for proper change detection
"rtspUrl": stream.get("rtsp_url"),
"snapshotUrl": stream.get("snapshot_url"),
"snapshotInterval": stream.get("snapshot_interval"),
**{k: v for k, v in get_crop_coords(stream).items() if v is not None}
} }
for camera_id, stream in streams.items() for camera_id, stream in streams.items()
] ]
@ -628,6 +1056,37 @@ async def detect(websocket: WebSocket):
"cameraConnections": camera_connections "cameraConnections": camera_connections
} }
await websocket.send_text(json.dumps(state_report)) await websocket.send_text(json.dumps(state_report))
elif msg_type == "setSessionId":
payload = data.get("payload", {})
display_identifier = payload.get("displayIdentifier")
session_id = payload.get("sessionId")
if display_identifier:
# Store session ID for this display
if session_id is None:
session_ids.pop(display_identifier, None)
logger.info(f"Cleared session ID for display {display_identifier}")
else:
session_ids[display_identifier] = session_id
logger.info(f"Set session ID {session_id} for display {display_identifier}")
elif msg_type == "patchSession":
session_id = data.get("sessionId")
patch_data = data.get("data", {})
# For now, just acknowledge the patch - actual implementation depends on backend requirements
response = {
"type": "patchSessionResult",
"payload": {
"sessionId": session_id,
"success": True,
"message": "Session patch acknowledged"
}
}
await websocket.send_json(response)
logger.info(f"Acknowledged patch for session {session_id}")
else: else:
logger.error(f"Unknown message type: {msg_type}") logger.error(f"Unknown message type: {msg_type}")
except json.JSONDecodeError: except json.JSONDecodeError:
@ -638,7 +1097,6 @@ async def detect(websocket: WebSocket):
except Exception as e: except Exception as e:
logger.error(f"Error handling message: {e}") logger.error(f"Error handling message: {e}")
break break
try: try:
await websocket.accept() await websocket.accept()
stream_task = asyncio.create_task(process_streams()) stream_task = asyncio.create_task(process_streams())
@ -651,19 +1109,24 @@ async def detect(websocket: WebSocket):
stream_task.cancel() stream_task.cancel()
await stream_task await stream_task
with streams_lock: with streams_lock:
for camera_id, stream in streams.items(): # Clean up shared camera streams
stream["stop_event"].set() for camera_url, shared_stream in camera_streams.items():
stream["thread"].join() shared_stream["stop_event"].set()
# Only release cap if it exists (RTSP mode) shared_stream["thread"].join()
if "cap" in stream: if "cap" in shared_stream:
stream["cap"].release() shared_stream["cap"].release()
while not stream["buffer"].empty(): while not shared_stream["buffer"].empty():
try: try:
stream["buffer"].get_nowait() shared_stream["buffer"].get_nowait()
except queue.Empty: except queue.Empty:
pass pass
logger.info(f"Released camera {camera_id} and cleaned up resources") logger.info(f"Released shared camera stream for {camera_url}")
streams.clear() streams.clear()
camera_streams.clear()
subscription_to_camera.clear()
with models_lock: with models_lock:
models.clear() models.clear()
latest_frames.clear()
session_ids.clear()
logger.info("WebSocket connection closed") logger.info("WebSocket connection closed")

View file

@ -1,366 +0,0 @@
from typing import List
from fastapi import FastAPI, WebSocket
from fastapi.websockets import WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError
from ultralytics import YOLO
import torch
import cv2
import base64
import numpy as np
import json
import logging
import threading
import queue
import os
import requests
from urllib.parse import urlparse
import asyncio
import psutil
app = FastAPI()
models = {}
with open("config.json", "r") as f:
config = json.load(f)
poll_interval = config.get("poll_interval_ms", 100)
reconnect_interval = config.get("reconnect_interval_sec", 5)
TARGET_FPS = config.get("target_fps", 10)
poll_interval = 1000 / TARGET_FPS
logging.info(f"Poll interval: {poll_interval}ms")
max_streams = config.get("max_streams", 5)
max_retries = config.get("max_retries", 3)
# Configure logging
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[
logging.FileHandler("app.log"),
logging.StreamHandler()
]
)
# Ensure the models directory exists
os.makedirs("models", exist_ok=True)
# Add constants for heartbeat
HEARTBEAT_INTERVAL = 2 # seconds
WORKER_TIMEOUT_MS = 10000
# Add a lock for thread-safe operations on shared resources
streams_lock = threading.Lock()
models_lock = threading.Lock()
@app.websocket("/")
async def detect(websocket: WebSocket):
import asyncio
import time
logging.info("WebSocket connection accepted")
streams = {}
# This function is user-modifiable
# Save data you want to persist across frames in the persistent_data dictionary
async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data):
try:
highest_conf_box = None
max_conf = -1
for r in model.track(frame, stream=False, persist=True):
for box in r.boxes:
box_cpu = box.cpu()
conf = float(box_cpu.conf[0])
if conf > max_conf and hasattr(box, "id") and box.id is not None:
max_conf = conf
highest_conf_box = {
"class": model.names[int(box_cpu.cls[0])],
"confidence": conf,
"id": box.id.item(),
}
# Broadcast to all subscribers of this URL
detection_data = {
"type": "imageDetection",
"cameraIdentifier": camera_id,
"timestamp": time.time(),
"data": {
"detections": highest_conf_box if highest_conf_box else None,
"modelId": stream['modelId'],
"modelName": stream['modelName']
}
}
logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}")
await websocket.send_json(detection_data)
return persistent_data
except Exception as e:
logging.error(f"Error in handle_detection for camera {camera_id}: {e}")
return persistent_data
def frame_reader(camera_id, cap, buffer, stop_event):
import time
retries = 0
try:
while not stop_event.is_set():
try:
ret, frame = cap.read()
if not ret:
logging.warning(f"Connection lost for camera: {camera_id}, retry {retries+1}/{max_retries}")
cap.release()
time.sleep(reconnect_interval)
retries += 1
if retries > max_retries and max_retries != -1:
logging.error(f"Max retries reached for camera: {camera_id}")
break
# Re-open the VideoCapture
cap = cv2.VideoCapture(streams[camera_id]['rtsp_url'])
if not cap.isOpened():
logging.error(f"Failed to reopen RTSP stream for camera: {camera_id}")
continue
continue
retries = 0 # Reset on success
if not buffer.empty():
try:
buffer.get_nowait() # Discard the old frame
except queue.Empty:
pass
buffer.put(frame)
except cv2.error as e:
logging.error(f"OpenCV error for camera {camera_id}: {e}")
cap.release()
time.sleep(reconnect_interval)
retries += 1
if retries > max_retries and max_retries != -1:
logging.error(f"Max retries reached after OpenCV error for camera: {camera_id}")
break
# Re-open the VideoCapture
cap = cv2.VideoCapture(streams[camera_id]['rtsp_url'])
if not cap.isOpened():
logging.error(f"Failed to reopen RTSP stream for camera {camera_id} after OpenCV error")
continue
except Exception as e:
logging.error(f"Unexpected error for camera {camera_id}: {e}")
cap.release()
break
except Exception as e:
logging.error(f"Error in frame_reader thread for camera {camera_id}: {e}")
async def process_streams():
global models
logging.info("Started processing streams")
persistent_data_dict = {}
try:
while True:
start_time = time.time()
# Round-robin processing
with streams_lock:
current_streams = list(streams.items())
for camera_id, stream in current_streams:
buffer = stream['buffer']
if not buffer.empty():
frame = buffer.get()
with models_lock:
model = models.get(camera_id, {}).get(stream['modelId'])
key = (camera_id, stream['modelId'])
persistent_data = persistent_data_dict.get(key, {})
updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data)
persistent_data_dict[key] = updated_persistent_data
elapsed_time = (time.time() - start_time) * 1000 # in ms
sleep_time = max(poll_interval - elapsed_time, 0)
logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms")
await asyncio.sleep(sleep_time / 1000.0)
except asyncio.CancelledError:
logging.info("Stream processing task cancelled")
except Exception as e:
logging.error(f"Error in process_streams: {e}")
async def send_heartbeat():
while True:
try:
cpu_usage = psutil.cpu_percent()
memory_usage = psutil.virtual_memory().percent
if torch.cuda.is_available():
gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) # Convert to MB
gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) # Convert to MB
else:
gpu_usage = None
gpu_memory_usage = None
camera_connections = [
{
"cameraIdentifier": camera_id,
"modelId": stream['modelId'],
"modelName": stream['modelName'],
"online": True
}
for camera_id, stream in streams.items()
]
state_report = {
"type": "stateReport",
"cpuUsage": cpu_usage,
"memoryUsage": memory_usage,
"gpuUsage": gpu_usage,
"gpuMemoryUsage": gpu_memory_usage,
"cameraConnections": camera_connections
}
await websocket.send_text(json.dumps(state_report))
logging.debug("Sent stateReport as heartbeat")
await asyncio.sleep(HEARTBEAT_INTERVAL)
except Exception as e:
logging.error(f"Error sending stateReport heartbeat: {e}")
break
async def on_message():
global models
while True:
try:
msg = await websocket.receive_text()
logging.debug(f"Received message: {msg}")
print(f"Received message: {msg}")
data = json.loads(msg)
msg_type = data.get("type")
if msg_type == "subscribe":
payload = data.get("payload", {})
camera_id = payload.get("cameraIdentifier")
rtsp_url = payload.get("rtspUrl")
model_url = payload.get("modelUrl")
modelId = payload.get("modelId")
modelName = payload.get("modelName")
if model_url:
with models_lock:
if camera_id not in models:
models[camera_id] = {}
if modelId not in models[camera_id]:
print(f"Downloading model from {model_url}")
parsed_url = urlparse(model_url)
filename = os.path.basename(parsed_url.path)
model_filename = os.path.join("models", filename)
# Download the model
response = requests.get(model_url, stream=True)
if response.status_code == 200:
with open(model_filename, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
logging.info(f"Downloaded model from {model_url} to {model_filename}")
model = YOLO(model_filename)
if torch.cuda.is_available():
model.to('cuda')
models[camera_id][modelId] = model
logging.info(f"Loaded model {modelId} for camera {camera_id}")
else:
logging.error(f"Failed to download model from {model_url}")
continue
if camera_id and rtsp_url:
with streams_lock:
if camera_id not in streams and len(streams) < max_streams:
cap = cv2.VideoCapture(rtsp_url)
if not cap.isOpened():
logging.error(f"Failed to open RTSP stream for camera {camera_id}")
continue
buffer = queue.Queue(maxsize=1)
stop_event = threading.Event()
thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
thread.daemon = True
thread.start()
streams[camera_id] = {
'cap': cap,
'buffer': buffer,
'thread': thread,
'rtsp_url': rtsp_url,
'stop_event': stop_event,
'modelId': modelId,
'modelName': modelName
}
logging.info(f"Subscribed to camera {camera_id} with modelId {modelId}, modelName {modelName} and URL {rtsp_url}")
elif camera_id and camera_id in streams:
stream = streams.pop(camera_id)
stream['cap'].release()
logging.info(f"Unsubscribed from camera {camera_id}")
if camera_id in models and modelId in models[camera_id]:
del models[camera_id][modelId]
if not models[camera_id]:
del models[camera_id]
elif msg_type == "unsubscribe":
payload = data.get("payload", {})
camera_id = payload.get("cameraIdentifier")
logging.debug(f"Unsubscribing from camera {camera_id}")
with streams_lock:
if camera_id and camera_id in streams:
stream = streams.pop(camera_id)
stream['stop_event'].set()
stream['thread'].join()
stream['cap'].release()
logging.info(f"Unsubscribed from camera {camera_id}")
if camera_id in models and modelId in models[camera_id]:
del models[camera_id][modelId]
if not models[camera_id]:
del models[camera_id]
elif msg_type == "requestState":
# Handle state request
cpu_usage = psutil.cpu_percent()
memory_usage = psutil.virtual_memory().percent
if torch.cuda.is_available():
gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) # Convert to MB
gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) # Convert to MB
else:
gpu_usage = None
gpu_memory_usage = None
camera_connections = [
{
"cameraIdentifier": camera_id,
"modelId": stream['modelId'],
"modelName": stream['modelName'],
"online": True
}
for camera_id, stream in streams.items()
]
state_report = {
"type": "stateReport",
"cpuUsage": cpu_usage,
"memoryUsage": memory_usage,
"gpuUsage": gpu_usage,
"gpuMemoryUsage": gpu_memory_usage,
"cameraConnections": camera_connections
}
await websocket.send_text(json.dumps(state_report))
else:
logging.error(f"Unknown message type: {msg_type}")
except json.JSONDecodeError:
logging.error("Received invalid JSON message")
except (WebSocketDisconnect, ConnectionClosedError) as e:
logging.warning(f"WebSocket disconnected: {e}")
break
except Exception as e:
logging.error(f"Error handling message: {e}")
break
try:
await websocket.accept()
task = asyncio.create_task(process_streams())
heartbeat_task = asyncio.create_task(send_heartbeat())
message_task = asyncio.create_task(on_message())
await asyncio.gather(heartbeat_task, message_task)
except Exception as e:
logging.error(f"Error in detect websocket: {e}")
finally:
task.cancel()
await task
with streams_lock:
for camera_id, stream in streams.items():
stream['stop_event'].set()
stream['thread'].join()
stream['cap'].release()
stream['buffer'].queue.clear()
logging.info(f"Released camera {camera_id} and cleaned up resources")
streams.clear()
with models_lock:
models.clear()
logging.info("WebSocket connection closed")

1449
docs/MasterElection.md Normal file

File diff suppressed because it is too large Load diff

1498
docs/WorkerConnection.md Normal file

File diff suppressed because it is too large Load diff

327
pympta.md Normal file
View file

@ -0,0 +1,327 @@
# pympta: Modular Pipeline Task Executor
`pympta` is a Python module designed to load and execute modular, multi-stage AI pipelines defined in a special package format (`.mpta`). It is primarily used within the detector worker to run complex computer vision tasks where the output of one model can trigger a subsequent model on a specific region of interest.
## Core Concepts
### 1. MPTA Package (`.mpta`)
An `.mpta` file is a standard `.zip` archive with a different extension. It bundles all the necessary components for a pipeline to run.
A typical `.mpta` file has the following structure:
```
my_pipeline.mpta/
├── pipeline.json
├── model1.pt
├── model2.pt
└── ...
```
- **`pipeline.json`**: (Required) The manifest file that defines the structure of the pipeline, the models to use, and the logic connecting them.
- **Model Files (`.pt`, etc.)**: The actual pre-trained model files (e.g., PyTorch, ONNX). The pipeline currently uses `ultralytics.YOLO` models.
### 2. Pipeline Structure
A pipeline is a tree-like structure of "nodes," defined in `pipeline.json`.
- **Root Node**: The entry point of the pipeline. It processes the initial, full-frame image.
- **Branch Nodes**: Child nodes that are triggered by specific detection results from their parent. For example, a root node might detect a "vehicle," which then triggers a branch node to detect a "license plate" within the vehicle's bounding box.
This modular structure allows for creating complex and efficient inference logic, avoiding the need to run every model on every frame.
## `pipeline.json` Specification
This file defines the entire pipeline logic. The root object contains a `pipeline` key for the pipeline definition, optional `redis` key for Redis configuration, and optional `postgresql` key for database integration.
### Top-Level Object Structure
| Key | Type | Required | Description |
| ------------ | ------ | -------- | ------------------------------------------------------- |
| `pipeline` | Object | Yes | The root node object of the pipeline. |
| `redis` | Object | No | Configuration for connecting to a Redis server. |
| `postgresql` | Object | No | Configuration for connecting to a PostgreSQL database. |
### Redis Configuration (`redis`)
| Key | Type | Required | Description |
| ---------- | ------ | -------- | ------------------------------------------------------- |
| `host` | String | Yes | The hostname or IP address of the Redis server. |
| `port` | Number | Yes | The port number of the Redis server. |
| `password` | String | No | The password for Redis authentication. |
| `db` | Number | No | The Redis database number to use. Defaults to `0`. |
### PostgreSQL Configuration (`postgresql`)
| Key | Type | Required | Description |
| ---------- | ------ | -------- | ------------------------------------------------------- |
| `host` | String | Yes | The hostname or IP address of the PostgreSQL server. |
| `port` | Number | Yes | The port number of the PostgreSQL server. |
| `database` | String | Yes | The database name to connect to. |
| `username` | String | Yes | The username for database authentication. |
| `password` | String | Yes | The password for database authentication. |
### Node Object Structure
| Key | Type | Required | Description |
| ------------------- | ------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------------- |
| `modelId` | String | Yes | A unique identifier for this model node (e.g., "vehicle-detector"). |
| `modelFile` | String | Yes | The path to the model file within the `.mpta` archive (e.g., "yolov8n.pt"). |
| `minConfidence` | Float | Yes | The minimum confidence score (0.0 to 1.0) required for a detection to be considered valid and potentially trigger a branch. |
| `triggerClasses` | Array<String> | Yes | A list of class names that, when detected by the parent, can trigger this node. For the root node, this lists all classes of interest. |
| `crop` | Boolean | No | If `true`, the image is cropped to the parent's detection bounding box before being passed to this node's model. Defaults to `false`. |
| `cropClass` | String | No | The specific class to use for cropping (e.g., "Frontal" for frontal view cropping). |
| `multiClass` | Boolean | No | If `true`, enables multi-class detection mode where multiple classes can be detected simultaneously. |
| `expectedClasses` | Array<String> | No | When `multiClass` is true, defines which classes are expected. At least one must be detected for processing to continue. |
| `parallel` | Boolean | No | If `true`, this branch will be processed in parallel with other parallel branches. |
| `branches` | Array<Node> | No | A list of child node objects that can be triggered by this node's detections. |
| `actions` | Array<Action> | No | A list of actions to execute upon a successful detection in this node. |
| `parallelActions` | Array<Action> | No | A list of actions to execute after all specified branches have completed. |
### Action Object Structure
Actions allow the pipeline to interact with Redis and PostgreSQL databases. They are executed sequentially for a given detection.
#### Action Context & Dynamic Keys
All actions have access to a dynamic context for formatting keys and messages. The context is created for each detection event and includes:
- All key-value pairs from the detection result (e.g., `class`, `confidence`, `id`).
- `{timestamp_ms}`: The current Unix timestamp in milliseconds.
- `{timestamp}`: Formatted timestamp string (YYYY-MM-DDTHH-MM-SS).
- `{uuid}`: A unique identifier (UUID4) for the detection event.
- `{filename}`: Generated filename with UUID.
- `{camera_id}`: Full camera subscription identifier.
- `{display_id}`: Display identifier extracted from subscription.
- `{session_id}`: Session ID for database operations.
- `{image_key}`: If a `redis_save_image` action has already been executed for this event, this placeholder will be replaced with the key where the image was stored.
#### `redis_save_image`
Saves the current image frame (or cropped sub-image) to a Redis key.
| Key | Type | Required | Description |
| ---------------- | ------ | -------- | ------------------------------------------------------------------------------------------------------- |
| `type` | String | Yes | Must be `"redis_save_image"`. |
| `key` | String | Yes | The Redis key to save the image to. Can contain any of the dynamic placeholders. |
| `region` | String | No | Specific detected region to crop and save (e.g., "Frontal"). |
| `format` | String | No | Image format: "jpeg" or "png". Defaults to "jpeg". |
| `quality` | Number | No | JPEG quality (1-100). Defaults to 90. |
| `expire_seconds` | Number | No | If provided, sets an expiration time (in seconds) for the Redis key. |
#### `redis_publish`
Publishes a message to a Redis channel.
| Key | Type | Required | Description |
| --------- | ------ | -------- | ------------------------------------------------------------------------------------------------------- |
| `type` | String | Yes | Must be `"redis_publish"`. |
| `channel` | String | Yes | The Redis channel to publish the message to. |
| `message` | String | Yes | The message to publish. Can contain any of the dynamic placeholders, including `{image_key}`. |
#### `postgresql_update_combined`
Updates PostgreSQL database with results from multiple branches after they complete.
| Key | Type | Required | Description |
| ------------------ | ------------- | -------- | ------------------------------------------------------------------------------------------------------- |
| `type` | String | Yes | Must be `"postgresql_update_combined"`. |
| `table` | String | Yes | The database table name (will be prefixed with `gas_station_1.` schema). |
| `key_field` | String | Yes | The field to use as the update key (typically "session_id"). |
| `key_value` | String | Yes | Template for the key value (e.g., "{session_id}"). |
| `waitForBranches` | Array<String> | Yes | List of branch model IDs to wait for completion before executing update. |
| `fields` | Object | Yes | Field mapping object where keys are database columns and values are templates (e.g., "{branch.field}").|
### Complete Example `pipeline.json`
This example demonstrates a comprehensive pipeline for vehicle detection with parallel classification and database integration:
```json
{
"redis": {
"host": "10.100.1.3",
"port": 6379,
"password": "your-redis-password",
"db": 0
},
"postgresql": {
"host": "10.100.1.3",
"port": 5432,
"database": "inference",
"username": "root",
"password": "your-db-password"
},
"pipeline": {
"modelId": "car_frontal_detection_v1",
"modelFile": "car_frontal_detection_v1.pt",
"crop": false,
"triggerClasses": ["Car", "Frontal"],
"minConfidence": 0.8,
"multiClass": true,
"expectedClasses": ["Car", "Frontal"],
"actions": [
{
"type": "redis_save_image",
"region": "Frontal",
"key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
"expire_seconds": 600,
"format": "jpeg",
"quality": 90
},
{
"type": "redis_publish",
"channel": "car_detections",
"message": "{\"event\":\"frontal_detected\"}"
}
],
"branches": [
{
"modelId": "car_brand_cls_v1",
"modelFile": "car_brand_cls_v1.pt",
"crop": true,
"cropClass": "Frontal",
"resizeTarget": [224, 224],
"triggerClasses": ["Frontal"],
"minConfidence": 0.85,
"parallel": true,
"branches": []
},
{
"modelId": "car_bodytype_cls_v1",
"modelFile": "car_bodytype_cls_v1.pt",
"crop": true,
"cropClass": "Car",
"resizeTarget": [224, 224],
"triggerClasses": ["Car"],
"minConfidence": 0.85,
"parallel": true,
"branches": []
}
],
"parallelActions": [
{
"type": "postgresql_update_combined",
"table": "car_frontal_info",
"key_field": "session_id",
"key_value": "{session_id}",
"waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
"fields": {
"car_brand": "{car_brand_cls_v1.brand}",
"car_body_type": "{car_bodytype_cls_v1.body_type}"
}
}
]
}
}
```
## API Reference
The `pympta` module exposes two main functions.
### `load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict`
Loads, extracts, and parses an `.mpta` file to build a pipeline tree in memory. It also establishes Redis and PostgreSQL connections if configured in `pipeline.json`.
- **Parameters:**
- `zip_source` (str): The file path to the local `.mpta` zip archive.
- `target_dir` (str): A directory path where the archive's contents will be extracted.
- **Returns:**
- A dictionary representing the root node of the pipeline, ready to be used with `run_pipeline`. Returns `None` if loading fails.
### `run_pipeline(frame, node: dict, return_bbox: bool = False, context: dict = None)`
Executes the inference pipeline on a single image frame.
- **Parameters:**
- `frame`: The input image frame (e.g., a NumPy array from OpenCV).
- `node` (dict): The pipeline node to execute (typically the root node returned by `load_pipeline_from_zip`).
- `return_bbox` (bool): If `True`, the function returns a tuple `(detection, bounding_box)`. Otherwise, it returns only the `detection`.
- `context` (dict): Optional context dictionary containing camera_id, display_id, session_id for action formatting.
- **Returns:**
- The final detection result from the last executed node in the chain. A detection is a dictionary like `{'class': 'car', 'confidence': 0.95, 'id': 1}`. If no detection meets the criteria, it returns `None` (or `(None, None)` if `return_bbox` is `True`).
## Database Integration
The pipeline system includes automatic PostgreSQL database management:
### Table Schema (`gas_station_1.car_frontal_info`)
The system automatically creates and manages the following table structure:
```sql
CREATE TABLE IF NOT EXISTS gas_station_1.car_frontal_info (
display_id VARCHAR(255),
captured_timestamp VARCHAR(255),
session_id VARCHAR(255) PRIMARY KEY,
license_character VARCHAR(255) DEFAULT NULL,
license_type VARCHAR(255) DEFAULT 'No model available',
car_brand VARCHAR(255) DEFAULT NULL,
car_model VARCHAR(255) DEFAULT NULL,
car_body_type VARCHAR(255) DEFAULT NULL,
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
);
```
### Workflow
1. **Initial Record Creation**: When both "Car" and "Frontal" are detected, an initial database record is created with a UUID session_id.
2. **Redis Storage**: Vehicle images are stored in Redis with keys containing the session_id.
3. **Parallel Classification**: Brand and body type classification run concurrently.
4. **Database Update**: After all branches complete, the database record is updated with classification results.
## Usage Example
This snippet shows how to use `pympta` with the enhanced features:
```python
import cv2
from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline
# 1. Define paths
MPTA_FILE = "path/to/your/pipeline.mpta"
CACHE_DIR = ".mptacache"
# 2. Load the pipeline from the .mpta file
# This reads pipeline.json and loads the YOLO models into memory.
model_tree = load_pipeline_from_zip(MPTA_FILE, CACHE_DIR)
if not model_tree:
print("Failed to load pipeline.")
exit()
# 3. Open a video source
cap = cv2.VideoCapture(0)
while True:
ret, frame = cap.read()
if not ret:
break
# 4. Run the pipeline on the current frame with context
context = {
"camera_id": "display-001;cam-001",
"display_id": "display-001",
"session_id": None # Will be generated automatically
}
detection_result, bounding_box = run_pipeline(frame, model_tree, return_bbox=True, context=context)
# 5. Display the results
if detection_result:
print(f"Detected: {detection_result['class']} with confidence {detection_result['confidence']:.2f}")
if bounding_box:
x1, y1, x2, y2 = bounding_box
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
cv2.putText(frame, detection_result['class'], (x1, y1 - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
cv2.imshow("Pipeline Output", frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
```

7
requirements.base.txt Normal file
View file

@ -0,0 +1,7 @@
torch
torchvision
ultralytics
opencv-python
scipy
filterpy
psycopg2-binary

View file

@ -1,8 +1,6 @@
fastapi fastapi
uvicorn uvicorn
torch
torchvision
ultralytics
opencv-python
websockets websockets
fastapi[standard] fastapi[standard]
redis
urllib3<2.0.0

211
siwatsystem/database.py Normal file
View file

@ -0,0 +1,211 @@
import psycopg2
import psycopg2.extras
from typing import Optional, Dict, Any
import logging
import uuid
logger = logging.getLogger(__name__)
class DatabaseManager:
def __init__(self, config: Dict[str, Any]):
self.config = config
self.connection: Optional[psycopg2.extensions.connection] = None
def connect(self) -> bool:
try:
self.connection = psycopg2.connect(
host=self.config['host'],
port=self.config['port'],
database=self.config['database'],
user=self.config['username'],
password=self.config['password']
)
logger.info("PostgreSQL connection established successfully")
return True
except Exception as e:
logger.error(f"Failed to connect to PostgreSQL: {e}")
return False
def disconnect(self):
if self.connection:
self.connection.close()
self.connection = None
logger.info("PostgreSQL connection closed")
def is_connected(self) -> bool:
try:
if self.connection and not self.connection.closed:
cur = self.connection.cursor()
cur.execute("SELECT 1")
cur.fetchone()
cur.close()
return True
except:
pass
return False
def update_car_info(self, session_id: str, brand: str, model: str, body_type: str) -> bool:
if not self.is_connected():
if not self.connect():
return False
try:
cur = self.connection.cursor()
query = """
INSERT INTO car_frontal_info (session_id, car_brand, car_model, car_body_type, updated_at)
VALUES (%s, %s, %s, %s, NOW())
ON CONFLICT (session_id)
DO UPDATE SET
car_brand = EXCLUDED.car_brand,
car_model = EXCLUDED.car_model,
car_body_type = EXCLUDED.car_body_type,
updated_at = NOW()
"""
cur.execute(query, (session_id, brand, model, body_type))
self.connection.commit()
cur.close()
logger.info(f"Updated car info for session {session_id}: {brand} {model} ({body_type})")
return True
except Exception as e:
logger.error(f"Failed to update car info: {e}")
if self.connection:
self.connection.rollback()
return False
def execute_update(self, table: str, key_field: str, key_value: str, fields: Dict[str, str]) -> bool:
if not self.is_connected():
if not self.connect():
return False
try:
cur = self.connection.cursor()
# Build the UPDATE query dynamically
set_clauses = []
values = []
for field, value in fields.items():
if value == "NOW()":
set_clauses.append(f"{field} = NOW()")
else:
set_clauses.append(f"{field} = %s")
values.append(value)
# Add schema prefix if table doesn't already have it
full_table_name = table if '.' in table else f"gas_station_1.{table}"
query = f"""
INSERT INTO {full_table_name} ({key_field}, {', '.join(fields.keys())})
VALUES (%s, {', '.join(['%s'] * len(fields))})
ON CONFLICT ({key_field})
DO UPDATE SET {', '.join(set_clauses)}
"""
# Add key_value to the beginning of values list
all_values = [key_value] + list(fields.values()) + values
cur.execute(query, all_values)
self.connection.commit()
cur.close()
logger.info(f"Updated {table} for {key_field}={key_value}")
return True
except Exception as e:
logger.error(f"Failed to execute update on {table}: {e}")
if self.connection:
self.connection.rollback()
return False
def create_car_frontal_info_table(self) -> bool:
"""Create the car_frontal_info table in gas_station_1 schema if it doesn't exist."""
if not self.is_connected():
if not self.connect():
return False
try:
cur = self.connection.cursor()
# Create schema if it doesn't exist
cur.execute("CREATE SCHEMA IF NOT EXISTS gas_station_1")
# Create table if it doesn't exist
create_table_query = """
CREATE TABLE IF NOT EXISTS gas_station_1.car_frontal_info (
display_id VARCHAR(255),
captured_timestamp VARCHAR(255),
session_id VARCHAR(255) PRIMARY KEY,
license_character VARCHAR(255) DEFAULT NULL,
license_type VARCHAR(255) DEFAULT 'No model available',
car_brand VARCHAR(255) DEFAULT NULL,
car_model VARCHAR(255) DEFAULT NULL,
car_body_type VARCHAR(255) DEFAULT NULL,
updated_at TIMESTAMP DEFAULT NOW()
)
"""
cur.execute(create_table_query)
# Add columns if they don't exist (for existing tables)
alter_queries = [
"ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS car_brand VARCHAR(255) DEFAULT NULL",
"ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS car_model VARCHAR(255) DEFAULT NULL",
"ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS car_body_type VARCHAR(255) DEFAULT NULL",
"ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT NOW()"
]
for alter_query in alter_queries:
try:
cur.execute(alter_query)
logger.debug(f"Executed: {alter_query}")
except Exception as e:
# Ignore errors if column already exists (for older PostgreSQL versions)
if "already exists" in str(e).lower():
logger.debug(f"Column already exists, skipping: {alter_query}")
else:
logger.warning(f"Error in ALTER TABLE: {e}")
self.connection.commit()
cur.close()
logger.info("Successfully created/verified car_frontal_info table with all required columns")
return True
except Exception as e:
logger.error(f"Failed to create car_frontal_info table: {e}")
if self.connection:
self.connection.rollback()
return False
def insert_initial_detection(self, display_id: str, captured_timestamp: str, session_id: str = None) -> str:
"""Insert initial detection record and return the session_id."""
if not self.is_connected():
if not self.connect():
return None
# Generate session_id if not provided
if not session_id:
session_id = str(uuid.uuid4())
try:
# Ensure table exists
if not self.create_car_frontal_info_table():
logger.error("Failed to create/verify table before insertion")
return None
cur = self.connection.cursor()
insert_query = """
INSERT INTO gas_station_1.car_frontal_info
(display_id, captured_timestamp, session_id, license_character, license_type, car_brand, car_model, car_body_type)
VALUES (%s, %s, %s, NULL, 'No model available', NULL, NULL, NULL)
ON CONFLICT (session_id) DO NOTHING
"""
cur.execute(insert_query, (display_id, captured_timestamp, session_id))
self.connection.commit()
cur.close()
logger.info(f"Inserted initial detection record with session_id: {session_id}")
return session_id
except Exception as e:
logger.error(f"Failed to insert initial detection record: {e}")
if self.connection:
self.connection.rollback()
return None

View file

@ -3,17 +3,72 @@ import json
import logging import logging
import torch import torch
import cv2 import cv2
import requests
import zipfile import zipfile
import shutil import shutil
import traceback import traceback
import redis
import time
import uuid
import concurrent.futures
from ultralytics import YOLO from ultralytics import YOLO
from urllib.parse import urlparse from urllib.parse import urlparse
from .database import DatabaseManager
# Create a logger specifically for this module # Create a logger specifically for this module
logger = logging.getLogger("detector_worker.pympta") logger = logging.getLogger("detector_worker.pympta")
def load_pipeline_node(node_config: dict, mpta_dir: str) -> dict: def validate_redis_config(redis_config: dict) -> bool:
"""Validate Redis configuration parameters."""
required_fields = ["host", "port"]
for field in required_fields:
if field not in redis_config:
logger.error(f"Missing required Redis config field: {field}")
return False
if not isinstance(redis_config["port"], int) or redis_config["port"] <= 0:
logger.error(f"Invalid Redis port: {redis_config['port']}")
return False
return True
def validate_postgresql_config(pg_config: dict) -> bool:
"""Validate PostgreSQL configuration parameters."""
required_fields = ["host", "port", "database", "username", "password"]
for field in required_fields:
if field not in pg_config:
logger.error(f"Missing required PostgreSQL config field: {field}")
return False
if not isinstance(pg_config["port"], int) or pg_config["port"] <= 0:
logger.error(f"Invalid PostgreSQL port: {pg_config['port']}")
return False
return True
def crop_region_by_class(frame, regions_dict, class_name):
"""Crop a specific region from frame based on detected class."""
if class_name not in regions_dict:
logger.warning(f"Class '{class_name}' not found in detected regions")
return None
bbox = regions_dict[class_name]['bbox']
x1, y1, x2, y2 = bbox
cropped = frame[y1:y2, x1:x2]
if cropped.size == 0:
logger.warning(f"Empty crop for class '{class_name}' with bbox {bbox}")
return None
return cropped
def format_action_context(base_context, additional_context=None):
"""Format action context with dynamic values."""
context = {**base_context}
if additional_context:
context.update(additional_context)
return context
def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client, db_manager=None) -> dict:
# Recursively load a model node from configuration. # Recursively load a model node from configuration.
model_path = os.path.join(mpta_dir, node_config["modelFile"]) model_path = os.path.join(mpta_dir, node_config["modelFile"])
if not os.path.exists(model_path): if not os.path.exists(model_path):
@ -43,14 +98,22 @@ def load_pipeline_node(node_config: dict, mpta_dir: str) -> dict:
"triggerClasses": trigger_classes, "triggerClasses": trigger_classes,
"triggerClassIndices": trigger_class_indices, "triggerClassIndices": trigger_class_indices,
"crop": node_config.get("crop", False), "crop": node_config.get("crop", False),
"cropClass": node_config.get("cropClass"),
"minConfidence": node_config.get("minConfidence", None), "minConfidence": node_config.get("minConfidence", None),
"multiClass": node_config.get("multiClass", False),
"expectedClasses": node_config.get("expectedClasses", []),
"parallel": node_config.get("parallel", False),
"actions": node_config.get("actions", []),
"parallelActions": node_config.get("parallelActions", []),
"model": model, "model": model,
"branches": [] "branches": [],
"redis_client": redis_client,
"db_manager": db_manager
} }
logger.debug(f"Configured node {node_config['modelId']} with trigger classes: {node['triggerClasses']}") logger.debug(f"Configured node {node_config['modelId']} with trigger classes: {node['triggerClasses']}")
for child in node_config.get("branches", []): for child in node_config.get("branches", []):
logger.debug(f"Loading branch for parent node {node_config['modelId']}") logger.debug(f"Loading branch for parent node {node_config['modelId']}")
node["branches"].append(load_pipeline_node(child, mpta_dir)) node["branches"].append(load_pipeline_node(child, mpta_dir, redis_client, db_manager))
return node return node
def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict: def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
@ -158,7 +221,47 @@ def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
pipeline_config = json.load(f) pipeline_config = json.load(f)
logger.info(f"Successfully loaded pipeline configuration from {pipeline_json_path}") logger.info(f"Successfully loaded pipeline configuration from {pipeline_json_path}")
logger.debug(f"Pipeline config: {json.dumps(pipeline_config, indent=2)}") logger.debug(f"Pipeline config: {json.dumps(pipeline_config, indent=2)}")
return load_pipeline_node(pipeline_config["pipeline"], mpta_dir)
# Establish Redis connection if configured
redis_client = None
if "redis" in pipeline_config:
redis_config = pipeline_config["redis"]
if not validate_redis_config(redis_config):
logger.error("Invalid Redis configuration, skipping Redis connection")
else:
try:
redis_client = redis.Redis(
host=redis_config["host"],
port=redis_config["port"],
password=redis_config.get("password"),
db=redis_config.get("db", 0),
decode_responses=True
)
redis_client.ping()
logger.info(f"Successfully connected to Redis at {redis_config['host']}:{redis_config['port']}")
except redis.exceptions.ConnectionError as e:
logger.error(f"Failed to connect to Redis: {e}")
redis_client = None
# Establish PostgreSQL connection if configured
db_manager = None
if "postgresql" in pipeline_config:
pg_config = pipeline_config["postgresql"]
if not validate_postgresql_config(pg_config):
logger.error("Invalid PostgreSQL configuration, skipping database connection")
else:
try:
db_manager = DatabaseManager(pg_config)
if db_manager.connect():
logger.info(f"Successfully connected to PostgreSQL at {pg_config['host']}:{pg_config['port']}")
else:
logger.error("Failed to connect to PostgreSQL")
db_manager = None
except Exception as e:
logger.error(f"Error initializing PostgreSQL connection: {e}")
db_manager = None
return load_pipeline_node(pipeline_config["pipeline"], mpta_dir, redis_client, db_manager)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"Error parsing pipeline.json: {str(e)}", exc_info=True) logger.error(f"Error parsing pipeline.json: {str(e)}", exc_info=True)
return None return None
@ -169,49 +272,357 @@ def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True) logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True)
return None return None
def run_pipeline(frame, node: dict, return_bbox: bool=False): def execute_actions(node, frame, detection_result, regions_dict=None):
if not node["redis_client"] or not node["actions"]:
return
# Create a dynamic context for this detection event
from datetime import datetime
action_context = {
**detection_result,
"timestamp_ms": int(time.time() * 1000),
"uuid": str(uuid.uuid4()),
"timestamp": datetime.now().strftime("%Y-%m-%dT%H-%M-%S"),
"filename": f"{uuid.uuid4()}.jpg"
}
for action in node["actions"]:
try:
if action["type"] == "redis_save_image":
key = action["key"].format(**action_context)
# Check if we need to crop a specific region
region_name = action.get("region")
image_to_save = frame
if region_name and regions_dict:
cropped_image = crop_region_by_class(frame, regions_dict, region_name)
if cropped_image is not None:
image_to_save = cropped_image
logger.debug(f"Cropped region '{region_name}' for redis_save_image")
else:
logger.warning(f"Could not crop region '{region_name}', saving full frame instead")
# Encode image with specified format and quality (default to JPEG)
img_format = action.get("format", "jpeg").lower()
quality = action.get("quality", 90)
if img_format == "jpeg":
encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
success, buffer = cv2.imencode('.jpg', image_to_save, encode_params)
elif img_format == "png":
success, buffer = cv2.imencode('.png', image_to_save)
else:
success, buffer = cv2.imencode('.jpg', image_to_save, [cv2.IMWRITE_JPEG_QUALITY, quality])
if not success:
logger.error(f"Failed to encode image for redis_save_image")
continue
expire_seconds = action.get("expire_seconds")
if expire_seconds:
node["redis_client"].setex(key, expire_seconds, buffer.tobytes())
logger.info(f"Saved image to Redis with key: {key} (expires in {expire_seconds}s)")
else:
node["redis_client"].set(key, buffer.tobytes())
logger.info(f"Saved image to Redis with key: {key}")
action_context["image_key"] = key
elif action["type"] == "redis_publish":
channel = action["channel"]
try:
# Handle JSON message format by creating it programmatically
message_template = action["message"]
# Check if the message is JSON-like (starts and ends with braces)
if message_template.strip().startswith('{') and message_template.strip().endswith('}'):
# Create JSON data programmatically to avoid formatting issues
json_data = {}
# Add common fields
json_data["event"] = "frontal_detected"
json_data["display_id"] = action_context.get("display_id", "unknown")
json_data["session_id"] = action_context.get("session_id")
json_data["timestamp"] = action_context.get("timestamp", "")
json_data["image_key"] = action_context.get("image_key", "")
# Convert to JSON string
message = json.dumps(json_data)
else:
# Use regular string formatting for non-JSON messages
message = message_template.format(**action_context)
# Publish to Redis
if not node["redis_client"]:
logger.error("Redis client is None, cannot publish message")
continue
# Test Redis connection
try:
node["redis_client"].ping()
logger.debug("Redis connection is active")
except Exception as ping_error:
logger.error(f"Redis connection test failed: {ping_error}")
continue
result = node["redis_client"].publish(channel, message)
logger.info(f"Published message to Redis channel '{channel}': {message}")
logger.info(f"Redis publish result (subscribers count): {result}")
# Additional debug info
if result == 0:
logger.warning(f"No subscribers listening to channel '{channel}'")
else:
logger.info(f"Message delivered to {result} subscriber(s)")
except KeyError as e:
logger.error(f"Missing key in redis_publish message template: {e}")
logger.debug(f"Available context keys: {list(action_context.keys())}")
except Exception as e:
logger.error(f"Error in redis_publish action: {e}")
logger.debug(f"Message template: {action['message']}")
logger.debug(f"Available context keys: {list(action_context.keys())}")
import traceback
logger.debug(f"Full traceback: {traceback.format_exc()}")
except Exception as e:
logger.error(f"Error executing action {action['type']}: {e}")
def execute_parallel_actions(node, frame, detection_result, regions_dict):
"""Execute parallel actions after all required branches have completed."""
if not node.get("parallelActions"):
return
logger.debug("Executing parallel actions...")
branch_results = detection_result.get("branch_results", {})
for action in node["parallelActions"]:
try:
action_type = action.get("type")
logger.debug(f"Processing parallel action: {action_type}")
if action_type == "postgresql_update_combined":
# Check if all required branches have completed
wait_for_branches = action.get("waitForBranches", [])
missing_branches = [branch for branch in wait_for_branches if branch not in branch_results]
if missing_branches:
logger.warning(f"Cannot execute postgresql_update_combined: missing branch results for {missing_branches}")
continue
logger.info(f"All required branches completed: {wait_for_branches}")
# Execute the database update
execute_postgresql_update_combined(node, action, detection_result, branch_results)
else:
logger.warning(f"Unknown parallel action type: {action_type}")
except Exception as e:
logger.error(f"Error executing parallel action {action.get('type', 'unknown')}: {e}")
import traceback
logger.debug(f"Full traceback: {traceback.format_exc()}")
def execute_postgresql_update_combined(node, action, detection_result, branch_results):
"""Execute a PostgreSQL update with combined branch results."""
if not node.get("db_manager"):
logger.error("No database manager available for postgresql_update_combined action")
return
try:
table = action["table"]
key_field = action["key_field"]
key_value_template = action["key_value"]
fields = action["fields"]
# Create context for key value formatting
action_context = {**detection_result}
key_value = key_value_template.format(**action_context)
logger.info(f"Executing database update: table={table}, {key_field}={key_value}")
# Process field mappings
mapped_fields = {}
for db_field, value_template in fields.items():
try:
mapped_value = resolve_field_mapping(value_template, branch_results, action_context)
if mapped_value is not None:
mapped_fields[db_field] = mapped_value
logger.debug(f"Mapped field: {db_field} = {mapped_value}")
else:
logger.warning(f"Could not resolve field mapping for {db_field}: {value_template}")
except Exception as e:
logger.error(f"Error mapping field {db_field} with template '{value_template}': {e}")
if not mapped_fields:
logger.warning("No fields mapped successfully, skipping database update")
return
# Execute the database update
success = node["db_manager"].execute_update(table, key_field, key_value, mapped_fields)
if success:
logger.info(f"Successfully updated database: {table} with {len(mapped_fields)} fields")
else:
logger.error(f"Failed to update database: {table}")
except KeyError as e:
logger.error(f"Missing required field in postgresql_update_combined action: {e}")
except Exception as e:
logger.error(f"Error in postgresql_update_combined action: {e}")
import traceback
logger.debug(f"Full traceback: {traceback.format_exc()}")
def resolve_field_mapping(value_template, branch_results, action_context):
"""Resolve field mapping templates like {car_brand_cls_v1.brand}."""
try:
# Handle simple context variables first (non-branch references)
if not '.' in value_template:
return value_template.format(**action_context)
# Handle branch result references like {model_id.field}
import re
branch_refs = re.findall(r'\{([^}]+\.[^}]+)\}', value_template)
resolved_template = value_template
for ref in branch_refs:
try:
model_id, field_name = ref.split('.', 1)
if model_id in branch_results:
branch_data = branch_results[model_id]
if field_name in branch_data:
field_value = branch_data[field_name]
resolved_template = resolved_template.replace(f'{{{ref}}}', str(field_value))
logger.debug(f"Resolved {ref} to {field_value}")
else:
logger.warning(f"Field '{field_name}' not found in branch '{model_id}' results. Available fields: {list(branch_data.keys())}")
return None
else:
logger.warning(f"Branch '{model_id}' not found in results. Available branches: {list(branch_results.keys())}")
return None
except ValueError as e:
logger.error(f"Invalid branch reference format: {ref}")
return None
# Format any remaining simple variables
try:
final_value = resolved_template.format(**action_context)
return final_value
except KeyError as e:
logger.warning(f"Could not resolve context variable in template: {e}")
return resolved_template
except Exception as e:
logger.error(f"Error resolving field mapping '{value_template}': {e}")
return None
def validate_pipeline_execution(node, regions_dict):
""" """
- For detection nodes (task != 'classify'): Pre-validate that all required branches will execute successfully before
runs `track(..., classes=triggerClassIndices)` committing to Redis actions and database records.
picks top box minConfidence
optionally crops & resizes recurse into child Returns:
else returns (det_dict, bbox) - (True, []) if pipeline can execute completely
- For classify nodes: - (False, missing_branches) if some required branches won't execute
runs `predict()` """
returns top (class,confidence) and no bbox # Get all branches that parallel actions are waiting for
required_branches = set()
for action in node.get("parallelActions", []):
if action.get("type") == "postgresql_update_combined":
wait_for_branches = action.get("waitForBranches", [])
required_branches.update(wait_for_branches)
if not required_branches:
# No parallel actions requiring specific branches
logger.debug("No parallel actions with waitForBranches - validation passes")
return True, []
logger.debug(f"Pre-validation: checking if required branches {list(required_branches)} will execute")
# Check each required branch
missing_branches = []
for branch in node.get("branches", []):
branch_id = branch["modelId"]
if branch_id not in required_branches:
continue # This branch is not required by parallel actions
# Check if this branch would be triggered
trigger_classes = branch.get("triggerClasses", [])
min_conf = branch.get("minConfidence", 0)
branch_triggered = False
for det_class in regions_dict:
det_confidence = regions_dict[det_class]["confidence"]
if (det_class in trigger_classes and det_confidence >= min_conf):
branch_triggered = True
logger.debug(f"Pre-validation: branch {branch_id} WILL be triggered by {det_class} (conf={det_confidence:.3f} >= {min_conf})")
break
if not branch_triggered:
missing_branches.append(branch_id)
logger.warning(f"Pre-validation: branch {branch_id} will NOT be triggered - no matching classes or insufficient confidence")
logger.debug(f" Required: {trigger_classes} with min_conf={min_conf}")
logger.debug(f" Available: {[(cls, regions_dict[cls]['confidence']) for cls in regions_dict]}")
if missing_branches:
logger.error(f"Pipeline pre-validation FAILED: required branches {missing_branches} will not execute")
return False, missing_branches
else:
logger.info(f"Pipeline pre-validation PASSED: all required branches {list(required_branches)} will execute")
return True, []
def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
"""
Enhanced pipeline that supports:
- Multi-class detection (detecting multiple classes simultaneously)
- Parallel branch processing
- Region-based actions and cropping
- Context passing for session/camera information
""" """
try: try:
task = getattr(node["model"], "task", None) task = getattr(node["model"], "task", None)
# ─── Classification stage ─────────────────────────────────── # ─── Classification stage ───────────────────────────────────
if task == "classify": if task == "classify":
# run the classifier and grab its top-1 directly via the Probs API
results = node["model"].predict(frame, stream=False) results = node["model"].predict(frame, stream=False)
# nothing returned?
if not results: if not results:
return (None, None) if return_bbox else None return (None, None) if return_bbox else None
# take the first result's probs object r = results[0]
r = results[0]
probs = r.probs probs = r.probs
if probs is None: if probs is None:
return (None, None) if return_bbox else None return (None, None) if return_bbox else None
# get the top-1 class index and its confidence top1_idx = int(probs.top1)
top1_idx = int(probs.top1)
top1_conf = float(probs.top1conf) top1_conf = float(probs.top1conf)
class_name = node["model"].names[top1_idx]
det = { det = {
"class": node["model"].names[top1_idx], "class": class_name,
"confidence": top1_conf, "confidence": top1_conf,
"id": None "id": None,
class_name: class_name # Add class name as key for backward compatibility
} }
# Add specific field mappings for database operations based on model type
model_id = node.get("modelId", "").lower()
if "brand" in model_id or "brand_cls" in model_id:
det["brand"] = class_name
elif "bodytype" in model_id or "body" in model_id:
det["body_type"] = class_name
elif "color" in model_id:
det["color"] = class_name
execute_actions(node, frame, det)
return (det, None) if return_bbox else det return (det, None) if return_bbox else det
# ─── Detection stage - Multi-class support ──────────────────
# ─── Detection stage ────────────────────────────────────────
# only look for your triggerClasses
tk = node["triggerClassIndices"] tk = node["triggerClassIndices"]
logger.debug(f"Running detection for node {node['modelId']} with trigger classes: {node.get('triggerClasses', [])} (indices: {tk})")
logger.debug(f"Node configuration: minConfidence={node['minConfidence']}, multiClass={node.get('multiClass', False)}")
res = node["model"].track( res = node["model"].track(
frame, frame,
stream=False, stream=False,
@ -219,46 +630,238 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False):
**({"classes": tk} if tk else {}) **({"classes": tk} if tk else {})
)[0] )[0]
dets, boxes = [], [] # Collect all detections above confidence threshold
for box in res.boxes: all_detections = []
conf = float(box.cpu().conf[0]) all_boxes = []
cid = int(box.cpu().cls[0]) regions_dict = {}
name = node["model"].names[cid]
if conf < node["minConfidence"]:
continue
xy = box.cpu().xyxy[0]
x1,y1,x2,y2 = map(int, xy)
dets.append({"class": name, "confidence": conf,
"id": box.id.item() if hasattr(box, "id") else None})
boxes.append((x1, y1, x2, y2))
if not dets: logger.debug(f"Raw detection results from model: {len(res.boxes) if res.boxes is not None else 0} detections")
for i, box in enumerate(res.boxes):
conf = float(box.cpu().conf[0])
cid = int(box.cpu().cls[0])
name = node["model"].names[cid]
logger.debug(f"Detection {i}: class='{name}' (id={cid}), confidence={conf:.3f}, threshold={node['minConfidence']}")
if conf < node["minConfidence"]:
logger.debug(f" -> REJECTED: confidence {conf:.3f} < threshold {node['minConfidence']}")
continue
xy = box.cpu().xyxy[0]
x1, y1, x2, y2 = map(int, xy)
bbox = (x1, y1, x2, y2)
detection = {
"class": name,
"confidence": conf,
"id": box.id.item() if hasattr(box, "id") else None,
"bbox": bbox
}
all_detections.append(detection)
all_boxes.append(bbox)
logger.debug(f" -> ACCEPTED: {name} with confidence {conf:.3f}, bbox={bbox}")
# Store highest confidence detection for each class
if name not in regions_dict or conf > regions_dict[name]["confidence"]:
regions_dict[name] = {
"bbox": bbox,
"confidence": conf,
"detection": detection
}
logger.debug(f" -> Updated regions_dict['{name}'] with confidence {conf:.3f}")
logger.info(f"Detection summary: {len(all_detections)} accepted detections from {len(res.boxes) if res.boxes is not None else 0} total")
logger.info(f"Detected classes: {list(regions_dict.keys())}")
if not all_detections:
logger.warning("No detections above confidence threshold - returning null")
return (None, None) if return_bbox else None return (None, None) if return_bbox else None
# take highestconfidence # ─── Multi-class validation ─────────────────────────────────
best_idx = max(range(len(dets)), key=lambda i: dets[i]["confidence"]) if node.get("multiClass", False) and node.get("expectedClasses"):
best_det = dets[best_idx] expected_classes = node["expectedClasses"]
best_box = boxes[best_idx] detected_classes = list(regions_dict.keys())
# ─── Branch (classification) ─────────────────────────────── logger.info(f"Multi-class validation: expected={expected_classes}, detected={detected_classes}")
for br in node["branches"]:
if (best_det["class"] in br["triggerClasses"]
and best_det["confidence"] >= br["minConfidence"]):
# crop if requested
sub = frame
if br["crop"]:
x1,y1,x2,y2 = best_box
sub = frame[y1:y2, x1:x2]
sub = cv2.resize(sub, (224, 224))
det2, _ = run_pipeline(sub, br, return_bbox=True) # Check if at least one expected class is detected (flexible mode)
if det2: matching_classes = [cls for cls in expected_classes if cls in detected_classes]
# return classification result + original bbox missing_classes = [cls for cls in expected_classes if cls not in detected_classes]
return (det2, best_box) if return_bbox else det2
# ─── No branch matched → return this detection ───────────── logger.debug(f"Matching classes: {matching_classes}, Missing classes: {missing_classes}")
return (best_det, best_box) if return_bbox else best_det
if not matching_classes:
# No expected classes found at all
logger.warning(f"PIPELINE REJECTED: No expected classes detected. Expected: {expected_classes}, Detected: {detected_classes}")
return (None, None) if return_bbox else None
if missing_classes:
logger.info(f"Partial multi-class detection: {matching_classes} found, {missing_classes} missing")
else:
logger.info(f"Complete multi-class detection success: {detected_classes}")
else:
logger.debug("No multi-class validation - proceeding with all detections")
# ─── Pre-validate pipeline execution ────────────────────────
pipeline_valid, missing_branches = validate_pipeline_execution(node, regions_dict)
if not pipeline_valid:
logger.error(f"Pipeline execution validation FAILED - required branches {missing_branches} cannot execute")
logger.error("Aborting pipeline: no Redis actions or database records will be created")
return (None, None) if return_bbox else None
# ─── Execute actions with region information ────────────────
detection_result = {
"detections": all_detections,
"regions": regions_dict,
**(context or {})
}
# ─── Create initial database record when Car+Frontal detected ────
if node.get("db_manager") and node.get("multiClass", False):
# Only create database record if we have both Car and Frontal
has_car = "Car" in regions_dict
has_frontal = "Frontal" in regions_dict
if has_car and has_frontal:
# Generate UUID session_id since client session is None for now
import uuid as uuid_lib
from datetime import datetime
generated_session_id = str(uuid_lib.uuid4())
# Insert initial detection record
display_id = detection_result.get("display_id", "unknown")
timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
inserted_session_id = node["db_manager"].insert_initial_detection(
display_id=display_id,
captured_timestamp=timestamp,
session_id=generated_session_id
)
if inserted_session_id:
# Update detection_result with the generated session_id for actions and branches
detection_result["session_id"] = inserted_session_id
detection_result["timestamp"] = timestamp # Update with proper timestamp
logger.info(f"Created initial database record with session_id: {inserted_session_id}")
else:
logger.debug(f"Database record not created - missing required classes. Has Car: {has_car}, Has Frontal: {has_frontal}")
execute_actions(node, frame, detection_result, regions_dict)
# ─── Parallel branch processing ─────────────────────────────
if node["branches"]:
branch_results = {}
# Filter branches that should be triggered
active_branches = []
for br in node["branches"]:
trigger_classes = br.get("triggerClasses", [])
min_conf = br.get("minConfidence", 0)
logger.debug(f"Evaluating branch {br['modelId']}: trigger_classes={trigger_classes}, min_conf={min_conf}")
# Check if any detected class matches branch trigger
branch_triggered = False
for det_class in regions_dict:
det_confidence = regions_dict[det_class]["confidence"]
logger.debug(f" Checking detected class '{det_class}' (confidence={det_confidence:.3f}) against triggers {trigger_classes}")
if (det_class in trigger_classes and det_confidence >= min_conf):
active_branches.append(br)
branch_triggered = True
logger.info(f"Branch {br['modelId']} activated by class '{det_class}' (conf={det_confidence:.3f} >= {min_conf})")
break
if not branch_triggered:
logger.debug(f"Branch {br['modelId']} not triggered - no matching classes or insufficient confidence")
if active_branches:
if node.get("parallel", False) or any(br.get("parallel", False) for br in active_branches):
# Run branches in parallel
with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_branches)) as executor:
futures = {}
for br in active_branches:
crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None)
sub_frame = frame
logger.info(f"Starting parallel branch: {br['modelId']}, crop_class: {crop_class}")
if br.get("crop", False) and crop_class:
cropped = crop_region_by_class(frame, regions_dict, crop_class)
if cropped is not None:
sub_frame = cv2.resize(cropped, (224, 224))
logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}")
else:
logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
continue
future = executor.submit(run_pipeline, sub_frame, br, True, context)
futures[future] = br
# Collect results
for future in concurrent.futures.as_completed(futures):
br = futures[future]
try:
result, _ = future.result()
if result:
branch_results[br["modelId"]] = result
logger.info(f"Branch {br['modelId']} completed: {result}")
except Exception as e:
logger.error(f"Branch {br['modelId']} failed: {e}")
else:
# Run branches sequentially
for br in active_branches:
crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None)
sub_frame = frame
logger.info(f"Starting sequential branch: {br['modelId']}, crop_class: {crop_class}")
if br.get("crop", False) and crop_class:
cropped = crop_region_by_class(frame, regions_dict, crop_class)
if cropped is not None:
sub_frame = cv2.resize(cropped, (224, 224))
logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}")
else:
logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
continue
try:
result, _ = run_pipeline(sub_frame, br, True, context)
if result:
branch_results[br["modelId"]] = result
logger.info(f"Branch {br['modelId']} completed: {result}")
else:
logger.warning(f"Branch {br['modelId']} returned no result")
except Exception as e:
logger.error(f"Error in sequential branch {br['modelId']}: {e}")
import traceback
logger.debug(f"Branch error traceback: {traceback.format_exc()}")
# Store branch results in detection_result for parallel actions
detection_result["branch_results"] = branch_results
# ─── Execute Parallel Actions ───────────────────────────────
if node.get("parallelActions") and "branch_results" in detection_result:
execute_parallel_actions(node, frame, detection_result, regions_dict)
# ─── Return detection result ────────────────────────────────
primary_detection = max(all_detections, key=lambda x: x["confidence"])
primary_bbox = primary_detection["bbox"]
# Add branch results and session_id to primary detection for compatibility
if "branch_results" in detection_result:
primary_detection["branch_results"] = detection_result["branch_results"]
if "session_id" in detection_result:
primary_detection["session_id"] = detection_result["session_id"]
return (primary_detection, primary_bbox) if return_bbox else primary_detection
except Exception as e: except Exception as e:
logging.error(f"Error in node {node.get('modelId')}: {e}") logger.error(f"Error in node {node.get('modelId')}: {e}")
traceback.print_exc()
return (None, None) if return_bbox else None return (None, None) if return_bbox else None

125
test_protocol.py Normal file
View file

@ -0,0 +1,125 @@
#!/usr/bin/env python3
"""
Test script to verify the worker implementation follows the protocol
"""
import json
import asyncio
import websockets
import time
async def test_protocol():
"""Test the worker protocol implementation"""
uri = "ws://localhost:8000"
try:
async with websockets.connect(uri) as websocket:
print("✓ Connected to worker")
# Test 1: Check if we receive heartbeat (stateReport)
print("\n1. Testing heartbeat...")
try:
message = await asyncio.wait_for(websocket.recv(), timeout=5)
data = json.loads(message)
if data.get("type") == "stateReport":
print("✓ Received stateReport heartbeat")
print(f" - CPU Usage: {data.get('cpuUsage', 'N/A')}%")
print(f" - Memory Usage: {data.get('memoryUsage', 'N/A')}%")
print(f" - Camera Connections: {len(data.get('cameraConnections', []))}")
else:
print(f"✗ Expected stateReport, got {data.get('type')}")
except asyncio.TimeoutError:
print("✗ No heartbeat received within 5 seconds")
# Test 2: Request state
print("\n2. Testing requestState...")
await websocket.send(json.dumps({"type": "requestState"}))
try:
message = await asyncio.wait_for(websocket.recv(), timeout=5)
data = json.loads(message)
if data.get("type") == "stateReport":
print("✓ Received stateReport response")
else:
print(f"✗ Expected stateReport, got {data.get('type')}")
except asyncio.TimeoutError:
print("✗ No response to requestState within 5 seconds")
# Test 3: Set session ID
print("\n3. Testing setSessionId...")
session_message = {
"type": "setSessionId",
"payload": {
"displayIdentifier": "display-001",
"sessionId": 12345
}
}
await websocket.send(json.dumps(session_message))
print("✓ Sent setSessionId message")
# Test 4: Test patchSession
print("\n4. Testing patchSession...")
patch_message = {
"type": "patchSession",
"sessionId": 12345,
"data": {
"currentCar": {
"carModel": "Civic",
"carBrand": "Honda"
}
}
}
await websocket.send(json.dumps(patch_message))
# Wait for patchSessionResult
try:
message = await asyncio.wait_for(websocket.recv(), timeout=5)
data = json.loads(message)
if data.get("type") == "patchSessionResult":
print("✓ Received patchSessionResult")
print(f" - Success: {data.get('payload', {}).get('success')}")
print(f" - Message: {data.get('payload', {}).get('message')}")
else:
print(f"✗ Expected patchSessionResult, got {data.get('type')}")
except asyncio.TimeoutError:
print("✗ No patchSessionResult received within 5 seconds")
# Test 5: Test subscribe message format (without actual camera)
print("\n5. Testing subscribe message format...")
subscribe_message = {
"type": "subscribe",
"payload": {
"subscriptionIdentifier": "display-001;cam-001",
"snapshotUrl": "http://example.com/snapshot.jpg",
"snapshotInterval": 5000,
"modelUrl": "http://example.com/model.mpta",
"modelName": "Test Model",
"modelId": 101,
"cropX1": 100,
"cropY1": 200,
"cropX2": 300,
"cropY2": 400
}
}
await websocket.send(json.dumps(subscribe_message))
print("✓ Sent subscribe message (will fail without actual camera/model)")
# Listen for a few more messages to catch any errors
print("\n6. Listening for additional messages...")
for i in range(3):
try:
message = await asyncio.wait_for(websocket.recv(), timeout=2)
data = json.loads(message)
msg_type = data.get("type")
print(f" - Received {msg_type}")
if msg_type == "error":
print(f" Error: {data.get('error')}")
except asyncio.TimeoutError:
break
print("\n✓ Protocol test completed successfully!")
except Exception as e:
print(f"✗ Connection failed: {e}")
print("Make sure the worker is running on localhost:8000")
if __name__ == "__main__":
asyncio.run(test_protocol())

495
worker.md Normal file
View file

@ -0,0 +1,495 @@
# Worker Communication Protocol
This document outlines the WebSocket-based communication protocol between the CMS backend and a detector worker. As a worker developer, your primary responsibility is to implement a WebSocket server that adheres to this protocol.
## 1. Connection
The worker must run a WebSocket server, preferably on port `8000`. The backend system, which is managed by a container orchestration service, will automatically discover and establish a WebSocket connection to your worker.
Upon a successful connection from the backend, you should begin sending `stateReport` messages as heartbeats.
## 2. Communication Overview
Communication is bidirectional and asynchronous. All messages are JSON objects with a `type` field that indicates the message's purpose, and an optional `payload` field containing the data.
- **Worker -> Backend:** You will send messages to the backend to report status, forward detection events, or request changes to session data.
- **Backend -> Worker:** The backend will send commands to you to manage camera subscriptions.
## 3. Dynamic Configuration via MPTA File
To enable modularity and dynamic configuration, the backend will send you a URL to a `.mpta` file when it issues a `subscribe` command. This file is a renamed `.zip` archive that contains everything your worker needs to perform its task.
**Your worker is responsible for:**
1. Fetching this file from the provided URL.
2. Extracting its contents.
3. Interpreting the contents to configure its internal pipeline.
**The contents of the `.mpta` file are entirely up to the user who configures the model in the CMS.** This allows for maximum flexibility. For example, the archive could contain:
- AI/ML Models: Pre-trained models for libraries like TensorFlow, PyTorch, or ONNX.
- Configuration Files: A `config.json` or `pipeline.yaml` that defines a sequence of operations, specifies model paths, or sets detection thresholds.
- Scripts: Custom Python scripts for pre-processing or post-processing.
- API Integration Details: A JSON file with endpoint information and credentials for interacting with third-party detection services.
Essentially, the `.mpta` file is a self-contained package that tells your worker _how_ to process the video stream for a given subscription.
## 4. Messages from Worker to Backend
These are the messages your worker is expected to send to the backend.
### 4.1. State Report (Heartbeat)
This message is crucial for the backend to monitor your worker's health and status, including GPU usage.
- **Type:** `stateReport`
- **When to Send:** Periodically (e.g., every 2 seconds) after a connection is established.
**Payload:**
```json
{
"type": "stateReport",
"cpuUsage": 75.5,
"memoryUsage": 40.2,
"gpuUsage": 60.0,
"gpuMemoryUsage": 25.1,
"cameraConnections": [
{
"subscriptionIdentifier": "display-001;cam-001",
"modelId": 101,
"modelName": "General Object Detection",
"online": true,
"cropX1": 100,
"cropY1": 200,
"cropX2": 300,
"cropY2": 400
}
]
}
```
> **Note:**
>
> - `cropX1`, `cropY1`, `cropX2`, `cropY2` (optional, integer) should be included in each camera connection to indicate the crop coordinates for that subscription.
### 4.2. Image Detection
Sent when the worker detects a relevant object. The `detection` object should be flat and contain key-value pairs corresponding to the detected attributes.
- **Type:** `imageDetection`
**Payload Example:**
```json
{
"type": "imageDetection",
"subscriptionIdentifier": "display-001;cam-001",
"timestamp": "2025-07-14T12:34:56.789Z",
"data": {
"detection": {
"carModel": "Civic",
"carBrand": "Honda",
"carYear": 2023,
"bodyType": "Sedan",
"licensePlateText": "ABCD1234",
"licensePlateConfidence": 0.95
},
"modelId": 101,
"modelName": "US-LPR-and-Vehicle-ID"
}
}
```
### 4.3. Patch Session
> **Note:** Patch messages are only used when the worker can't keep up and needs to retroactively send detections. Normally, detections should be sent in real-time using `imageDetection` messages. Use `patchSession` only to update session data after the fact.
Allows the worker to request a modification to an active session's data. The `data` payload must be a partial object of the `DisplayPersistentData` structure.
- **Type:** `patchSession`
**Payload Example:**
```json
{
"type": "patchSession",
"sessionId": 12345,
"data": {
"currentCar": {
"carModel": "Civic",
"carBrand": "Honda",
"licensePlateText": "ABCD1234"
}
}
}
```
The backend will respond with a `patchSessionResult` command.
#### `DisplayPersistentData` Structure
The `data` object in the `patchSession` message is merged with the existing `DisplayPersistentData` on the backend. Here is its structure:
```typescript
interface DisplayPersistentData {
progressionStage:
| 'welcome'
| 'car_fueling'
| 'car_waitpayment'
| 'car_postpayment'
| null;
qrCode: string | null;
adsPlayback: {
playlistSlotOrder: number; // The 'order' of the current slot
adsId: number | null;
adsUrl: string | null;
} | null;
currentCar: {
carModel?: string;
carBrand?: string;
carYear?: number;
bodyType?: string;
licensePlateText?: string;
licensePlateType?: string;
} | null;
fuelPump: {
/* FuelPumpData structure */
} | null;
weatherData: {
/* WeatherResponse structure */
} | null;
sessionId: number | null;
}
```
#### Patching Behavior
- The patch is a **deep merge**.
- **`undefined`** values are ignored.
- **`null`** values will set the corresponding field to `null`.
- Nested objects are merged recursively.
## 5. Commands from Backend to Worker
These are the commands your worker will receive from the backend.
### 5.1. Subscribe to Camera
Instructs the worker to process a camera's RTSP stream using the configuration from the specified `.mpta` file.
- **Type:** `subscribe`
**Payload:**
```json
{
"type": "subscribe",
"payload": {
"subscriptionIdentifier": "display-001;cam-002",
"rtspUrl": "rtsp://user:pass@host:port/stream",
"snapshotUrl": "http://go2rtc/snapshot/1",
"snapshotInterval": 5000,
"modelUrl": "http://storage/models/us-lpr.mpta",
"modelName": "US-LPR-and-Vehicle-ID",
"modelId": 102,
"cropX1": 100,
"cropY1": 200,
"cropX2": 300,
"cropY2": 400
}
}
```
> **Note:**
>
> - `cropX1`, `cropY1`, `cropX2`, `cropY2` (optional, integer) specify the crop coordinates for the camera stream. These values are configured per display and passed in the subscription payload. If not provided, the worker should process the full frame.
>
> **Important:**
> If multiple displays are bound to the same camera, your worker must ensure that only **one stream** is opened per camera. When you receive multiple subscriptions for the same camera (with different `subscriptionIdentifier` values), you should:
>
> - Open the RTSP stream **once** for that camera if using RTSP.
> - Capture each snapshot only once per cycle, and reuse it for all display subscriptions sharing that camera.
> - Capture each frame/image only once per cycle.
> - Reuse the same captured image and snapshot for all display subscriptions that share the camera, processing and routing detection results separately for each display as needed.
> This avoids unnecessary load and bandwidth usage, and ensures consistent detection results and snapshots across all displays sharing the same camera.
### 5.2. Unsubscribe from Camera
Instructs the worker to stop processing a camera's stream.
- **Type:** `unsubscribe`
**Payload:**
```json
{
"type": "unsubscribe",
"payload": {
"subscriptionIdentifier": "display-001;cam-002"
}
}
```
### 5.3. Request State
Direct request for the worker's current state. Respond with a `stateReport` message.
- **Type:** `requestState`
**Payload:**
```json
{
"type": "requestState"
}
```
### 5.4. Patch Session Result
Backend's response to a `patchSession` message.
- **Type:** `patchSessionResult`
**Payload:**
```json
{
"type": "patchSessionResult",
"payload": {
"sessionId": 12345,
"success": true,
"message": "Session updated successfully."
}
}
```
### 5.5. Set Session ID
Allows the backend to instruct the worker to associate a session ID with a subscription. This is useful for linking detection events to a specific session. The session ID can be `null` to indicate no active session.
- **Type:** `setSessionId`
**Payload:**
```json
{
"type": "setSessionId",
"payload": {
"displayIdentifier": "display-001",
"sessionId": 12345
}
}
```
Or to clear the session:
```json
{
"type": "setSessionId",
"payload": {
"displayIdentifier": "display-001",
"sessionId": null
}
}
```
> **Note:**
>
> - The worker should store the session ID for the given subscription and use it in subsequent detection or patch messages as appropriate. If `sessionId` is `null`, the worker should treat the subscription as having no active session.
## Subscription Identifier Format
The `subscriptionIdentifier` used in all messages is constructed as:
```
displayIdentifier;cameraIdentifier
```
This uniquely identifies a camera subscription for a specific display.
### Session ID Association
When the backend sends a `setSessionId` command, it will only provide the `displayIdentifier` (not the full `subscriptionIdentifier`).
**Worker Responsibility:**
- The worker must match the `displayIdentifier` to all active subscriptions for that display (i.e., all `subscriptionIdentifier` values that start with `displayIdentifier;`).
- The worker should set or clear the session ID for all matching subscriptions.
## 6. Example Communication Log
This section shows a typical sequence of messages between the backend and the worker. Patch messages are not included, as they are only used when the worker cannot keep up.
> **Note:** Unsubscribe is triggered when a user removes a camera or when the node is too heavily loaded and needs rebalancing.
1. **Connection Established** & **Heartbeat**
- **Worker -> Backend**
```json
{
"type": "stateReport",
"cpuUsage": 70.2,
"memoryUsage": 38.1,
"gpuUsage": 55.0,
"gpuMemoryUsage": 20.0,
"cameraConnections": []
}
```
2. **Backend Subscribes Camera**
- **Backend -> Worker**
```json
{
"type": "subscribe",
"payload": {
"subscriptionIdentifier": "display-001;entry-cam-01",
"rtspUrl": "rtsp://192.168.1.100/stream1",
"modelUrl": "http://storage/models/vehicle-id.mpta",
"modelName": "Vehicle Identification",
"modelId": 201
}
}
```
3. **Worker Acknowledges in Heartbeat**
- **Worker -> Backend**
```json
{
"type": "stateReport",
"cpuUsage": 72.5,
"memoryUsage": 39.0,
"gpuUsage": 57.0,
"gpuMemoryUsage": 21.0,
"cameraConnections": [
{
"subscriptionIdentifier": "display-001;entry-cam-01",
"modelId": 201,
"modelName": "Vehicle Identification",
"online": true
}
]
}
```
4. **Worker Detects a Car**
- **Worker -> Backend**
```json
{
"type": "imageDetection",
"subscriptionIdentifier": "display-001;entry-cam-01",
"timestamp": "2025-07-15T10:00:00.000Z",
"data": {
"detection": {
"carBrand": "Honda",
"carModel": "CR-V",
"bodyType": "SUV",
"licensePlateText": "GEMINI-AI",
"licensePlateConfidence": 0.98
},
"modelId": 201,
"modelName": "Vehicle Identification"
}
}
```
- **Worker -> Backend**
```json
{
"type": "imageDetection",
"subscriptionIdentifier": "display-001;entry-cam-01",
"timestamp": "2025-07-15T10:00:01.000Z",
"data": {
"detection": {
"carBrand": "Toyota",
"carModel": "Corolla",
"bodyType": "Sedan",
"licensePlateText": "CMS-1234",
"licensePlateConfidence": 0.97
},
"modelId": 201,
"modelName": "Vehicle Identification"
}
}
```
- **Worker -> Backend**
```json
{
"type": "imageDetection",
"subscriptionIdentifier": "display-001;entry-cam-01",
"timestamp": "2025-07-15T10:00:02.000Z",
"data": {
"detection": {
"carBrand": "Ford",
"carModel": "Focus",
"bodyType": "Hatchback",
"licensePlateText": "CMS-5678",
"licensePlateConfidence": 0.96
},
"modelId": 201,
"modelName": "Vehicle Identification"
}
}
```
5. **Backend Unsubscribes Camera**
- **Backend -> Worker**
```json
{
"type": "unsubscribe",
"payload": {
"subscriptionIdentifier": "display-001;entry-cam-01"
}
}
```
6. **Worker Acknowledges Unsubscription**
- **Worker -> Backend**
```json
{
"type": "stateReport",
"cpuUsage": 68.0,
"memoryUsage": 37.0,
"gpuUsage": 50.0,
"gpuMemoryUsage": 18.0,
"cameraConnections": []
}
```
## 7. HTTP API: Image Retrieval
In addition to the WebSocket protocol, the worker exposes an HTTP endpoint for retrieving the latest image frame from a camera.
### Endpoint
```
GET /camera/{camera_id}/image
```
- **`camera_id`**: The full `subscriptionIdentifier` (e.g., `display-001;cam-001`).
### Response
- **Success (200):** Returns the latest JPEG image from the camera stream.
- `Content-Type: image/jpeg`
- Binary JPEG data.
- **Error (404):** If the camera is not found or no frame is available.
- JSON error response.
- **Error (500):** Internal server error.
### Example Request
```
GET /camera/display-001;cam-001/image
```
### Example Response
- **Headers:**
```
Content-Type: image/jpeg
```
- **Body:** Binary JPEG image.
### Notes
- The endpoint returns the most recent frame available for the specified camera subscription.
- If multiple displays share the same camera, each subscription has its own buffer; the endpoint uses the buffer for the given `camera_id`.
- This API is useful for debugging, monitoring, or integrating with external systems that require direct image access.