resolve merge conflicts by accepting main branch versions

This commit is contained in:
Pongsatorn Kanjanasantisak 2025-08-11 00:57:21 +07:00
commit 48db3234ed
11 changed files with 1278 additions and 243 deletions

@@ -1,13 +1,68 @@
-name: Build Backend Application and Docker Image
+name: Build Worker Base and Application Images
 
 on:
   push:
     branches:
       - main
+      - dev
   workflow_dispatch:
+    inputs:
+      force_base_build:
+        description: 'Force base image build regardless of changes'
+        required: false
+        default: 'false'
+        type: boolean
 
 jobs:
+  check-base-changes:
+    runs-on: ubuntu-latest
+    outputs:
+      base-changed: ${{ steps.changes.outputs.base-changed }}
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 2
+      - name: Check for base changes
+        id: changes
+        run: |
+          if git diff HEAD^ HEAD --name-only | grep -E "(Dockerfile\.base|requirements\.base\.txt)" > /dev/null; then
+            echo "base-changed=true" >> $GITHUB_OUTPUT
+          else
+            echo "base-changed=false" >> $GITHUB_OUTPUT
+          fi
+
+  build-base:
+    needs: check-base-changes
+    if: needs.check-base-changes.outputs.base-changed == 'true' || (github.event_name == 'workflow_dispatch' && github.event.inputs.force_base_build == 'true')
+    runs-on: ubuntu-latest
+    permissions:
+      packages: write
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v2
+      - name: Login to GitHub Container Registry
+        uses: docker/login-action@v3
+        with:
+          registry: git.siwatsystem.com
+          username: ${{ github.actor }}
+          password: ${{ secrets.RUNNER_TOKEN }}
+      - name: Build and push base Docker image
+        uses: docker/build-push-action@v4
+        with:
+          context: .
+          file: ./Dockerfile.base
+          push: true
+          tags: git.siwatsystem.com/adsist-cms/worker-base:latest
+
   build-docker:
+    needs: [check-base-changes, build-base]
+    if: always() && (needs.build-base.result == 'success' || needs.build-base.result == 'skipped')
     runs-on: ubuntu-latest
     permissions:
       packages: write
@@ -31,4 +86,27 @@ jobs:
           context: .
           file: ./Dockerfile
           push: true
-          tags: git.siwatsystem.com/adsist-cms/worker:latest
+          tags: git.siwatsystem.com/adsist-cms/worker:${{ github.ref_name == 'main' && 'latest' || 'dev' }}
+
+  deploy-stack:
+    needs: build-docker
+    runs-on: adsist
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+      - name: Set up SSH connection
+        run: |
+          mkdir -p ~/.ssh
+          echo "${{ secrets.DEPLOY_KEY_CMS }}" > ~/.ssh/id_rsa
+          chmod 600 ~/.ssh/id_rsa
+          ssh-keyscan -H ${{ vars.DEPLOY_HOST_CMS }} >> ~/.ssh/known_hosts
+      - name: Deploy stack
+        run: |
+          echo "Pulling and starting containers on server..."
+          if [ "${{ github.ref_name }}" = "main" ]; then
+            echo "Deploying production stack..."
+            ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.production.yml pull && docker compose -f docker-compose.production.yml up -d"
+          else
+            echo "Deploying staging stack..."
+            ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.staging.yml pull && docker compose -f docker-compose.staging.yml up -d"
+          fi

CLAUDE.md

@@ -1,13 +1,23 @@
 # Python Detector Worker - CLAUDE.md
 
 ## Project Overview
-This is a FastAPI-based computer vision detection worker that processes video streams from RTSP/HTTP sources and runs YOLO-based machine learning pipelines for object detection and classification. The system is designed to work within a larger CMS (Content Management System) architecture.
+This is a FastAPI-based computer vision detection worker that processes video streams from RTSP/HTTP sources and runs advanced YOLO-based machine learning pipelines for multi-class object detection and parallel classification. The system features comprehensive database integration, Redis support, and hierarchical pipeline execution designed to work within a larger CMS (Content Management System) architecture.
+
+### Key Features
+- **Multi-Class Detection**: Simultaneous detection of multiple object classes (e.g., Car + Frontal)
+- **Parallel Processing**: Concurrent execution of classification branches using ThreadPoolExecutor
+- **Database Integration**: Automatic PostgreSQL schema management and record updates
+- **Redis Actions**: Image storage with region cropping and pub/sub messaging
+- **Pipeline Synchronization**: Branch coordination with `waitForBranches` functionality
+- **Dynamic Field Mapping**: Template-based field resolution for database operations
 
 ## Architecture & Technology Stack
 - **Framework**: FastAPI with WebSocket support
 - **ML/CV**: PyTorch, Ultralytics YOLO, OpenCV
 - **Containerization**: Docker (Python 3.13-bookworm base)
-- **Data Storage**: Redis integration for action handling
+- **Data Storage**: Redis integration for action handling + PostgreSQL for persistent storage
+- **Database**: Automatic schema management with gas_station_1 database
+- **Parallel Processing**: ThreadPoolExecutor for concurrent classification
 - **Communication**: WebSocket-based real-time protocol
 
 ## Core Components
@@ -24,9 +34,20 @@
 ### Pipeline System (`siwatsystem/pympta.py`)
 - **MPTA file handling** - ZIP archives containing model configurations
 - **Hierarchical pipeline execution** with detection → classification branching
-- **Redis action system** for image saving and message publishing
+- **Multi-class detection** - Simultaneous detection of multiple classes (Car + Frontal)
+- **Parallel processing** - Concurrent classification branches with ThreadPoolExecutor
+- **Redis action system** - Image saving with region cropping and message publishing
+- **PostgreSQL integration** - Automatic table creation and combined updates
 - **Dynamic model loading** with GPU optimization
 - **Configurable trigger classes and confidence thresholds**
+- **Branch synchronization** - waitForBranches coordination for database updates
+
+### Database System (`siwatsystem/database.py`)
+- **DatabaseManager class** for PostgreSQL operations
+- **Automatic table creation** with gas_station_1.car_frontal_info schema
+- **Combined update operations** with field mapping from branch results
+- **Session management** with UUID generation
+- **Error handling** and connection management
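As a minimal usage sketch of how these pieces fit together (config values are illustrative; in the worker the manager is constructed by `load_pipeline_from_zip` from the `postgresql` block of `pipeline.json`):

```python
# Minimal sketch, assuming the DatabaseManager API shown in
# siwatsystem/database.py later in this diff; values are illustrative.
from siwatsystem.database import DatabaseManager

db = DatabaseManager({
    "host": "10.100.1.3",
    "port": 5432,
    "database": "inference",
    "username": "root",
    "password": "your-db-password",
})
if db.connect():
    # Creates gas_station_1.car_frontal_info if needed and inserts a stub row
    session_id = db.insert_initial_detection("display-001", "2025-08-11T00-57-21")
    # Later, after the classification branches finish
    db.update_car_info(session_id, "Honda", "Civic", "Sedan")
    db.disconnect()
```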
 
 ### Testing & Debugging
 - **Protocol test script** (`test_protocol.py`) for WebSocket communication validation
@@ -92,33 +113,61 @@
 ## Model Pipeline (MPTA) Format
 
-### Structure
+### Enhanced Structure
 - **ZIP archive** containing models and configuration
-- **pipeline.json** - Main configuration file
+- **pipeline.json** - Main configuration file with Redis + PostgreSQL settings
 - **Model files** - YOLO .pt files for detection/classification
-- **Redis configuration** - Optional for action execution
+- **Multi-model support** - Detection + multiple classification models
 
-### Pipeline Flow
-1. **Detection stage** - YOLO object detection with bounding boxes
-2. **Trigger evaluation** - Check if detected class matches trigger conditions
-3. **Classification stage** - Crop detected region and run classification model
-4. **Action execution** - Redis operations (image saving, message publishing)
+### Advanced Pipeline Flow
+1. **Multi-class detection stage** - YOLO detection of Car + Frontal simultaneously
+2. **Validation stage** - Check for expected classes (flexible matching)
+3. **Database initialization** - Create initial record with session_id
+4. **Redis actions** - Save cropped frontal images with expiration
+5. **Parallel classification** - Concurrent brand and body type classification
+6. **Branch synchronization** - Wait for all classification branches to complete
+7. **Database update** - Combined update with all classification results
 
-### Branch Configuration
+### Enhanced Branch Configuration
 ```json
 {
-  "modelId": "detector-v1",
-  "modelFile": "detector.pt",
-  "triggerClasses": ["car", "truck"],
-  "minConfidence": 0.5,
-  "branches": [{
-    "modelId": "classifier-v1",
-    "modelFile": "classifier.pt",
-    "crop": true,
-    "triggerClasses": ["car"],
-    "minConfidence": 0.3,
-    "actions": [...]
-  }]
+  "modelId": "car_frontal_detection_v1",
+  "modelFile": "car_frontal_detection_v1.pt",
+  "multiClass": true,
+  "expectedClasses": ["Car", "Frontal"],
+  "triggerClasses": ["Car", "Frontal"],
+  "minConfidence": 0.8,
+  "actions": [
+    {
+      "type": "redis_save_image",
+      "region": "Frontal",
+      "key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
+      "expire_seconds": 600
+    }
+  ],
+  "branches": [
+    {
+      "modelId": "car_brand_cls_v1",
+      "modelFile": "car_brand_cls_v1.pt",
+      "parallel": true,
+      "crop": true,
+      "cropClass": "Frontal",
+      "triggerClasses": ["Frontal"],
+      "minConfidence": 0.85
+    }
+  ],
+  "parallelActions": [
+    {
+      "type": "postgresql_update_combined",
+      "table": "car_frontal_info",
+      "key_field": "session_id",
+      "waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
+      "fields": {
+        "car_brand": "{car_brand_cls_v1.brand}",
+        "car_body_type": "{car_bodytype_cls_v1.body_type}"
+      }
+    }
+  ]
 }
 ```
@@ -173,6 +222,9 @@ docker run -p 8000:8000 -v ./models:/app/models detector-worker
 - **opencv-python**: Computer vision operations
 - **websockets**: WebSocket client/server
 - **redis**: Redis client for action execution
+- **psycopg2-binary**: PostgreSQL database adapter
+- **scipy**: Scientific computing for advanced algorithms
+- **filterpy**: Kalman filtering and state estimation
 
 ## Security Considerations
 - Model files are loaded from trusted sources only
@@ -180,9 +232,46 @@
 - WebSocket connections handle disconnects gracefully
 - Resource usage is monitored to prevent DoS
 
+## Database Integration
+
+### Schema Management
+The system automatically creates and manages PostgreSQL tables:
+
+```sql
+CREATE TABLE IF NOT EXISTS gas_station_1.car_frontal_info (
+    display_id VARCHAR(255),
+    captured_timestamp VARCHAR(255),
+    session_id VARCHAR(255) PRIMARY KEY,
+    license_character VARCHAR(255) DEFAULT NULL,
+    license_type VARCHAR(255) DEFAULT 'No model available',
+    car_brand VARCHAR(255) DEFAULT NULL,
+    car_model VARCHAR(255) DEFAULT NULL,
+    car_body_type VARCHAR(255) DEFAULT NULL,
+    created_at TIMESTAMP DEFAULT NOW(),
+    updated_at TIMESTAMP DEFAULT NOW()
+);
+```
+
+### Workflow
+1. **Detection**: When both "Car" and "Frontal" are detected, create initial database record with UUID session_id
+2. **Redis Storage**: Save cropped frontal image to Redis with session_id in key
+3. **Parallel Processing**: Run brand and body type classification concurrently
+4. **Synchronization**: Wait for all branches to complete using `waitForBranches`
+5. **Database Update**: Update record with combined classification results using field mapping
+
+### Field Mapping
+Templates like `{car_brand_cls_v1.brand}` are resolved to actual classification results:
+- `car_brand_cls_v1.brand` → "Honda"
+- `car_bodytype_cls_v1.body_type` → "Sedan"
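A rough sketch of how this template resolution can work (hypothetical helper; the actual logic lives in `siwatsystem/pympta.py`):

```python
import re

def resolve_field_templates(fields: dict, branch_results: dict) -> dict:
    """Resolve '{model_id.key}' templates against collected branch results.
    Illustrative only; mirrors the field mapping described above."""
    resolved = {}
    for column, template in fields.items():
        match = re.fullmatch(r"\{(\w+)\.(\w+)\}", template)
        if match:
            model_id, key = match.groups()
            value = branch_results.get(model_id, {}).get(key)
            if value is not None:
                resolved[column] = value
        else:
            resolved[column] = template
    return resolved

# Prints {'car_brand': 'Honda'}
print(resolve_field_templates(
    {"car_brand": "{car_brand_cls_v1.brand}"},
    {"car_brand_cls_v1": {"brand": "Honda"}},
))
```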
 
 ## Performance Optimizations
 - GPU acceleration when CUDA is available
 - Shared camera streams reduce resource usage
 - Frame queue optimization (single latest frame)
 - Model caching across subscriptions
 - Trigger class filtering for faster inference
+- Parallel processing with ThreadPoolExecutor for classification branches
+- Multi-class detection reduces inference passes
+- Region-based cropping minimizes processing overhead
+- Database connection pooling and prepared statements
+- Redis image storage with automatic expiration

Dockerfile

@@ -1,20 +1,12 @@
-# Use newer, more secure base image
-FROM python:3.13-alpine
-
-# Update system packages first
-RUN apk update && apk upgrade
-
-# Install minimal dependencies
-RUN apk add --no-cache mesa-gl
-
-# Use specific package versions
+# Use our pre-built base image with ML dependencies
+FROM git.siwatsystem.com/adsist-cms/worker-base:latest
+
+# Copy and install application requirements (frequently changing dependencies)
 COPY requirements.txt .
-RUN pip install --no-cache-dir --upgrade pip && \
-    pip install --no-cache-dir -r requirements.txt
-
-# Run as non-root user
-RUN adduser -D -s /bin/sh appuser
-USER appuser
-
-# Copy the application code
+RUN pip install --no-cache-dir -r requirements.txt
+
 COPY . .
-
-# Run the application
 CMD ["python3", "-m", "fastapi", "run", "--host", "0.0.0.0", "--port", "8000"]

Dockerfile.base (new file)

@@ -0,0 +1,15 @@
# Base image with all ML dependencies
FROM python:3.13-bookworm
# Install system dependencies
RUN apt update && apt install -y libgl1 && rm -rf /var/lib/apt/lists/*
# Copy and install base requirements (ML dependencies that rarely change)
COPY requirements.base.txt .
RUN pip install --no-cache-dir -r requirements.base.txt
# Set working directory
WORKDIR /app
# This base image will be reused for all worker builds
CMD ["python3", "-m", "fastapi", "run", "--host", "0.0.0.0", "--port", "8000"]

app.py

@@ -35,6 +35,8 @@ session_ids: Dict[str, int] = {}
 camera_streams: Dict[str, Dict[str, Any]] = {}
 # Map subscriptions to their camera URL
 subscription_to_camera: Dict[str, str] = {}
+# Store latest frames for REST API access (separate from processing buffer)
+latest_frames: Dict[str, Any] = {}
 
 with open("config.json", "r") as f:
     config = json.load(f)
@@ -109,20 +111,60 @@ def download_mpta(url: str, dest_path: str) -> str:
 # Add helper to fetch snapshot image from HTTP/HTTPS URL
 def fetch_snapshot(url: str):
     try:
-        response = requests.get(url, timeout=10)
+        from requests.auth import HTTPBasicAuth, HTTPDigestAuth
+
+        # Parse URL to extract credentials
+        parsed = urlparse(url)
+
+        # Prepare headers - some cameras require User-Agent
+        headers = {
+            'User-Agent': 'Mozilla/5.0 (compatible; DetectorWorker/1.0)'
+        }
+
+        # Reconstruct URL without credentials
+        clean_url = f"{parsed.scheme}://{parsed.hostname}"
+        if parsed.port:
+            clean_url += f":{parsed.port}"
+        clean_url += parsed.path
+        if parsed.query:
+            clean_url += f"?{parsed.query}"
+
+        auth = None
+        if parsed.username and parsed.password:
+            # Try HTTP Digest authentication first (common for IP cameras)
+            try:
+                auth = HTTPDigestAuth(parsed.username, parsed.password)
+                response = requests.get(clean_url, auth=auth, headers=headers, timeout=10)
+                if response.status_code == 200:
+                    logger.debug(f"Successfully authenticated using HTTP Digest for {clean_url}")
+                elif response.status_code == 401:
+                    # If Digest fails, try Basic auth
+                    logger.debug(f"HTTP Digest failed, trying Basic auth for {clean_url}")
+                    auth = HTTPBasicAuth(parsed.username, parsed.password)
+                    response = requests.get(clean_url, auth=auth, headers=headers, timeout=10)
+                    if response.status_code == 200:
+                        logger.debug(f"Successfully authenticated using HTTP Basic for {clean_url}")
+            except Exception as auth_error:
+                logger.debug(f"Authentication setup error: {auth_error}")
+                # Fallback to original URL with embedded credentials
+                response = requests.get(url, headers=headers, timeout=10)
+        else:
+            # No credentials in URL, make request as-is
+            response = requests.get(url, headers=headers, timeout=10)
+
         if response.status_code == 200:
             # Convert response content to numpy array
             nparr = np.frombuffer(response.content, np.uint8)
             # Decode image
             frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
             if frame is not None:
-                logger.debug(f"Successfully fetched snapshot from {url}, shape: {frame.shape}")
+                logger.debug(f"Successfully fetched snapshot from {clean_url}, shape: {frame.shape}")
                 return frame
             else:
-                logger.error(f"Failed to decode image from snapshot URL: {url}")
+                logger.error(f"Failed to decode image from snapshot URL: {clean_url}")
                 return None
         else:
-            logger.error(f"Failed to fetch snapshot (status code {response.status_code}): {url}")
+            logger.error(f"Failed to fetch snapshot (status code {response.status_code}): {clean_url}")
             return None
     except Exception as e:
         logger.error(f"Exception fetching snapshot from {url}: {str(e)}")
@@ -146,26 +188,24 @@ async def get_camera_image(camera_id: str):
     Get the current frame from a camera as JPEG image
     """
     try:
+        # URL decode the camera_id to handle encoded characters like %3B for semicolon
+        from urllib.parse import unquote
+        original_camera_id = camera_id
+        camera_id = unquote(camera_id)
+        logger.debug(f"REST API request: original='{original_camera_id}', decoded='{camera_id}'")
+
         with streams_lock:
             if camera_id not in streams:
                 logger.warning(f"Camera ID '{camera_id}' not found in streams. Current streams: {list(streams.keys())}")
                 raise HTTPException(status_code=404, detail=f"Camera {camera_id} not found or not active")
 
-            stream = streams[camera_id]
-            buffer = stream["buffer"]
-            logger.debug(f"Camera '{camera_id}' buffer size: {buffer.qsize()}, buffer empty: {buffer.empty()}")
-            logger.debug(f"Buffer queue contents: {getattr(buffer, 'queue', None)}")
-            if buffer.empty():
-                logger.warning(f"No frame available for camera '{camera_id}'. Buffer is empty.")
+            # Check if we have a cached frame for this camera
+            if camera_id not in latest_frames:
+                logger.warning(f"No cached frame available for camera '{camera_id}'.")
                 raise HTTPException(status_code=404, detail=f"No frame available for camera {camera_id}")
 
-            # Get the latest frame (non-blocking)
-            try:
-                frame = buffer.queue[-1]  # Get the most recent frame without removing it
-            except IndexError:
-                logger.warning(f"Buffer queue is empty for camera '{camera_id}' when trying to access last frame.")
-                raise HTTPException(status_code=404, detail=f"No frame available for camera {camera_id}")
+            frame = latest_frames[camera_id]
+            logger.debug(f"Retrieved cached frame for camera '{camera_id}', frame shape: {frame.shape}")
 
         # Encode frame as JPEG
         success, buffer_img = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
         if not success:
@@ -199,7 +239,20 @@ async def detect(websocket: WebSocket):
                 logger.debug(f"Processing frame for camera {camera_id} with model {stream['modelId']}")
                 start_time = time.time()
-                detection_result = run_pipeline(cropped_frame, model_tree)
+
+                # Extract display identifier for session ID lookup
+                subscription_parts = stream["subscriptionIdentifier"].split(';')
+                display_identifier = subscription_parts[0] if subscription_parts else None
+                session_id = session_ids.get(display_identifier) if display_identifier else None
+
+                # Create context for pipeline execution
+                pipeline_context = {
+                    "camera_id": camera_id,
+                    "display_id": display_identifier,
+                    "session_id": session_id
+                }
+
+                detection_result = run_pipeline(cropped_frame, model_tree, context=pipeline_context)
                 process_time = (time.time() - start_time) * 1000
                 logger.debug(f"Detection for camera {camera_id} completed in {process_time:.2f}ms")
@@ -258,11 +311,6 @@
                     if key not in ["box", "id"]:  # Skip internal fields
                         detection_dict[key] = value
 
-                # Extract display identifier for session ID lookup
-                subscription_parts = stream["subscriptionIdentifier"].split(';')
-                display_identifier = subscription_parts[0] if subscription_parts else None
-                session_id = session_ids.get(display_identifier) if display_identifier else None
-
                 detection_data = {
                     "type": "imageDetection",
                     "subscriptionIdentifier": stream["subscriptionIdentifier"],
@@ -282,9 +330,6 @@
                     logger.info(f"Camera {camera_id}: Detected {highest_confidence_detection['class']} with confidence {highest_confidence_detection['confidence']:.2f} using model {stream['modelName']}")
 
                     # Log session ID if available
-                    subscription_parts = stream["subscriptionIdentifier"].split(';')
-                    display_identifier = subscription_parts[0] if subscription_parts else None
-                    session_id = session_ids.get(display_identifier) if display_identifier else None
                     if session_id:
                         logger.debug(f"Detection associated with session ID: {session_id}")
@@ -476,6 +521,10 @@
                         logger.debug(f"Got frame from buffer for camera {camera_id}")
                         frame = buffer.get()
 
+                        # Cache the frame for REST API access
+                        latest_frames[camera_id] = frame.copy()
+                        logger.debug(f"Cached frame for REST API access for camera {camera_id}")
+
                         with models_lock:
                             model_tree = models.get(camera_id, {}).get(stream["modelId"])
                             if not model_tree:
@@ -647,7 +696,7 @@
                     if snapshot_url and snapshot_interval:
                         logger.info(f"Creating new snapshot stream for camera {camera_id}: {snapshot_url}")
-                        thread = threading.Thread(target=snapshot_reader, args=(camera_identifier, snapshot_url, snapshot_interval, buffer, stop_event))
+                        thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event))
                         thread.daemon = True
                         thread.start()
                         mode = "snapshot"
@@ -670,7 +719,7 @@
                         if not cap.isOpened():
                             logger.error(f"Failed to open RTSP stream for camera {camera_id}")
                             continue
-                        thread = threading.Thread(target=frame_reader, args=(camera_identifier, cap, buffer, stop_event))
+                        thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
                         thread.daemon = True
                         thread.start()
                         mode = "rtsp"
@@ -744,6 +793,8 @@
                     else:
                         logger.info(f"Shared stream for {camera_url} still has {shared_stream['ref_count']} references")
 
+                # Clean up cached frame
+                latest_frames.pop(camera_id, None)
                 logger.info(f"Unsubscribed from camera {camera_id}")
                 # Note: Keep models in memory for potential reuse
             elif msg_type == "requestState":
@@ -847,5 +898,6 @@
         subscription_to_camera.clear()
         with models_lock:
             models.clear()
+        latest_frames.clear()
         session_ids.clear()
         logger.info("WebSocket connection closed")

pympta.md

@@ -32,14 +32,15 @@ This modular structure allows for creating complex and efficient inference logic
 ## `pipeline.json` Specification
 
-This file defines the entire pipeline logic. The root object contains a `pipeline` key for the pipeline definition and an optional `redis` key for Redis configuration.
+This file defines the entire pipeline logic. The root object contains a `pipeline` key for the pipeline definition, optional `redis` key for Redis configuration, and optional `postgresql` key for database integration.
 
 ### Top-Level Object Structure
 
 | Key | Type | Required | Description |
 | ------------ | ------ | -------- | ------------------------------------------------------- |
 | `pipeline` | Object | Yes | The root node object of the pipeline. |
 | `redis` | Object | No | Configuration for connecting to a Redis server. |
+| `postgresql` | Object | No | Configuration for connecting to a PostgreSQL database. |
 
 ### Redis Configuration (`redis`)
@@ -50,6 +51,16 @@
 | `password` | String | No | The password for Redis authentication. |
 | `db` | Number | No | The Redis database number to use. Defaults to `0`. |
 
+### PostgreSQL Configuration (`postgresql`)
+
+| Key | Type | Required | Description |
+| ---------- | ------ | -------- | ------------------------------------------------------- |
+| `host` | String | Yes | The hostname or IP address of the PostgreSQL server. |
+| `port` | Number | Yes | The port number of the PostgreSQL server. |
+| `database` | String | Yes | The database name to connect to. |
+| `username` | String | Yes | The username for database authentication. |
+| `password` | String | Yes | The password for database authentication. |
+
 ### Node Object Structure
 
 | Key | Type | Required | Description |
@@ -59,12 +70,17 @@
 | `minConfidence` | Float | Yes | The minimum confidence score (0.0 to 1.0) required for a detection to be considered valid and potentially trigger a branch. |
 | `triggerClasses` | Array<String> | Yes | A list of class names that, when detected by the parent, can trigger this node. For the root node, this lists all classes of interest. |
 | `crop` | Boolean | No | If `true`, the image is cropped to the parent's detection bounding box before being passed to this node's model. Defaults to `false`. |
+| `cropClass` | String | No | The specific class to use for cropping (e.g., "Frontal" for frontal view cropping). |
+| `multiClass` | Boolean | No | If `true`, enables multi-class detection mode where multiple classes can be detected simultaneously. |
+| `expectedClasses` | Array<String> | No | When `multiClass` is true, defines which classes are expected. At least one must be detected for processing to continue. |
+| `parallel` | Boolean | No | If `true`, this branch will be processed in parallel with other parallel branches. |
 | `branches` | Array<Node> | No | A list of child node objects that can be triggered by this node's detections. |
 | `actions` | Array<Action> | No | A list of actions to execute upon a successful detection in this node. |
+| `parallelActions` | Array<Action> | No | A list of actions to execute after all specified branches have completed. |
 
 ### Action Object Structure
 
-Actions allow the pipeline to interact with Redis. They are executed sequentially for a given detection.
+Actions allow the pipeline to interact with Redis and PostgreSQL databases. They are executed sequentially for a given detection.
 
 #### Action Context & Dynamic Keys
@@ -72,7 +88,12 @@ All actions have access to a dynamic context for formatting keys and messages.
 
 - All key-value pairs from the detection result (e.g., `class`, `confidence`, `id`).
 - `{timestamp_ms}`: The current Unix timestamp in milliseconds.
+- `{timestamp}`: Formatted timestamp string (YYYY-MM-DDTHH-MM-SS).
 - `{uuid}`: A unique identifier (UUID4) for the detection event.
+- `{filename}`: Generated filename with UUID.
+- `{camera_id}`: Full camera subscription identifier.
+- `{display_id}`: Display identifier extracted from subscription.
+- `{session_id}`: Session ID for database operations.
 - `{image_key}`: If a `redis_save_image` action has already been executed for this event, this placeholder will be replaced with the key where the image was stored.
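As an illustration, a `redis_save_image` key template expands against this context with plain `str.format` semantics (the values below are made up):

```python
key_template = "inference:{display_id}:{timestamp}:{session_id}:{filename}"
context = {
    "display_id": "display-001",
    "timestamp": "2025-08-11T00-57-21",
    "session_id": "a1b2c3d4-0000-0000-0000-000000000000",  # illustrative UUID
    "filename": "e5f6a7b8-0000-0000-0000-000000000000.jpg",
}
# -> inference:display-001:2025-08-11T00-57-21:<session_id>:<filename>
print(key_template.format(**context))
```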
 
 #### `redis_save_image`
@@ -83,6 +104,9 @@ Saves the current image frame (or cropped sub-image) to a Redis key.
 | ---------------- | ------ | -------- | ------------------------------------------------------------------------------------------------------- |
 | `type` | String | Yes | Must be `"redis_save_image"`. |
 | `key` | String | Yes | The Redis key to save the image to. Can contain any of the dynamic placeholders. |
+| `region` | String | No | Specific detected region to crop and save (e.g., "Frontal"). |
+| `format` | String | No | Image format: "jpeg" or "png". Defaults to "jpeg". |
+| `quality` | Number | No | JPEG quality (1-100). Defaults to 90. |
 | `expire_seconds` | Number | No | If provided, sets an expiration time (in seconds) for the Redis key. |
 
 #### `redis_publish`
@@ -95,35 +119,98 @@ Publishes a message to a Redis channel.
 | `channel` | String | Yes | The Redis channel to publish the message to. |
 | `message` | String | Yes | The message to publish. Can contain any of the dynamic placeholders, including `{image_key}`. |
 
+#### `postgresql_update_combined`
+
+Updates PostgreSQL database with results from multiple branches after they complete.
+
+| Key | Type | Required | Description |
+| ------------------ | ------------- | -------- | ------------------------------------------------------------------------------------------------------- |
+| `type` | String | Yes | Must be `"postgresql_update_combined"`. |
+| `table` | String | Yes | The database table name (will be prefixed with `gas_station_1.` schema). |
+| `key_field` | String | Yes | The field to use as the update key (typically "session_id"). |
+| `key_value` | String | Yes | Template for the key value (e.g., "{session_id}"). |
+| `waitForBranches` | Array<String> | Yes | List of branch model IDs to wait for completion before executing update. |
+| `fields` | Object | Yes | Field mapping object where keys are database columns and values are templates (e.g., "{branch.field}"). |
+
-### Example `pipeline.json` with Redis
+### Complete Example `pipeline.json`
 
-This example demonstrates a pipeline that detects vehicles, saves a uniquely named image of each detection that expires in one hour, and then publishes a notification with the image key.
+This example demonstrates a comprehensive pipeline for vehicle detection with parallel classification and database integration:
 
 ```json
 {
   "redis": {
-    "host": "redis.local",
+    "host": "10.100.1.3",
     "port": 6379,
-    "password": "your-super-secret-password"
+    "password": "your-redis-password",
+    "db": 0
+  },
+  "postgresql": {
+    "host": "10.100.1.3",
+    "port": 5432,
+    "database": "inference",
+    "username": "root",
+    "password": "your-db-password"
   },
   "pipeline": {
-    "modelId": "vehicle-detector",
-    "modelFile": "vehicle_model.pt",
-    "minConfidence": 0.6,
-    "triggerClasses": ["car", "truck"],
+    "modelId": "car_frontal_detection_v1",
+    "modelFile": "car_frontal_detection_v1.pt",
+    "crop": false,
+    "triggerClasses": ["Car", "Frontal"],
+    "minConfidence": 0.8,
+    "multiClass": true,
+    "expectedClasses": ["Car", "Frontal"],
     "actions": [
       {
         "type": "redis_save_image",
-        "key": "detections:{class}:{timestamp_ms}:{uuid}",
-        "expire_seconds": 3600
+        "region": "Frontal",
+        "key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
+        "expire_seconds": 600,
+        "format": "jpeg",
+        "quality": 90
       },
       {
         "type": "redis_publish",
-        "channel": "vehicle_events",
-        "message": "{\"event\":\"new_detection\",\"class\":\"{class}\",\"confidence\":{confidence},\"image_key\":\"{image_key}\"}"
+        "channel": "car_detections",
+        "message": "{\"event\":\"frontal_detected\"}"
       }
     ],
-    "branches": []
+    "branches": [
+      {
+        "modelId": "car_brand_cls_v1",
+        "modelFile": "car_brand_cls_v1.pt",
+        "crop": true,
+        "cropClass": "Frontal",
+        "resizeTarget": [224, 224],
+        "triggerClasses": ["Frontal"],
+        "minConfidence": 0.85,
+        "parallel": true,
+        "branches": []
+      },
+      {
+        "modelId": "car_bodytype_cls_v1",
+        "modelFile": "car_bodytype_cls_v1.pt",
+        "crop": true,
+        "cropClass": "Car",
+        "resizeTarget": [224, 224],
+        "triggerClasses": ["Car"],
+        "minConfidence": 0.85,
+        "parallel": true,
+        "branches": []
+      }
+    ],
+    "parallelActions": [
+      {
+        "type": "postgresql_update_combined",
+        "table": "car_frontal_info",
+        "key_field": "session_id",
+        "key_value": "{session_id}",
+        "waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
+        "fields": {
+          "car_brand": "{car_brand_cls_v1.brand}",
+          "car_body_type": "{car_bodytype_cls_v1.body_type}"
+        }
+      }
+    ]
   }
 }
 ```
@@ -134,7 +221,7 @@ The `pympta` module exposes two main functions.
 
 ### `load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict`
 
-Loads, extracts, and parses an `.mpta` file to build a pipeline tree in memory. It also establishes a Redis connection if configured in `pipeline.json`.
+Loads, extracts, and parses an `.mpta` file to build a pipeline tree in memory. It also establishes Redis and PostgreSQL connections if configured in `pipeline.json`.
 
 - **Parameters:**
   - `zip_source` (str): The file path to the local `.mpta` zip archive.
@@ -142,7 +229,7 @@
 - **Returns:**
   - A dictionary representing the root node of the pipeline, ready to be used with `run_pipeline`. Returns `None` if loading fails.
 
-### `run_pipeline(frame, node: dict, return_bbox: bool = False)`
+### `run_pipeline(frame, node: dict, return_bbox: bool = False, context: dict = None)`
 
 Executes the inference pipeline on a single image frame.
@@ -150,12 +237,43 @@
 - **Parameters:**
   - `frame`: The input image frame (e.g., a NumPy array from OpenCV).
   - `node` (dict): The pipeline node to execute (typically the root node returned by `load_pipeline_from_zip`).
   - `return_bbox` (bool): If `True`, the function returns a tuple `(detection, bounding_box)`. Otherwise, it returns only the `detection`.
+  - `context` (dict): Optional context dictionary containing camera_id, display_id, and session_id for action formatting.
 
 - **Returns:**
   - The final detection result from the last executed node in the chain. A detection is a dictionary like `{'class': 'car', 'confidence': 0.95, 'id': 1}`. If no detection meets the criteria, it returns `None` (or `(None, None)` if `return_bbox` is `True`).
 
+## Database Integration
+
+The pipeline system includes automatic PostgreSQL database management:
+
+### Table Schema (`gas_station_1.car_frontal_info`)
+
+The system automatically creates and manages the following table structure:
+
+```sql
+CREATE TABLE IF NOT EXISTS gas_station_1.car_frontal_info (
+    display_id VARCHAR(255),
+    captured_timestamp VARCHAR(255),
+    session_id VARCHAR(255) PRIMARY KEY,
+    license_character VARCHAR(255) DEFAULT NULL,
+    license_type VARCHAR(255) DEFAULT 'No model available',
+    car_brand VARCHAR(255) DEFAULT NULL,
+    car_model VARCHAR(255) DEFAULT NULL,
+    car_body_type VARCHAR(255) DEFAULT NULL,
+    created_at TIMESTAMP DEFAULT NOW(),
+    updated_at TIMESTAMP DEFAULT NOW()
+);
+```
+
+### Workflow
+
+1. **Initial Record Creation**: When both "Car" and "Frontal" are detected, an initial database record is created with a UUID session_id.
+2. **Redis Storage**: Vehicle images are stored in Redis with keys containing the session_id.
+3. **Parallel Classification**: Brand and body type classification run concurrently.
+4. **Database Update**: After all branches complete, the database record is updated with classification results.
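A condensed sketch of the branch-synchronization step described above (hypothetical helper; the real logic inside `run_pipeline` uses `concurrent.futures.ThreadPoolExecutor` and then feeds the collected results into `postgresql_update_combined`):

```python
import concurrent.futures

def run_parallel_branches(frame, branches, context):
    """Run parallel classification branches and collect results by modelId.
    Hypothetical condensation of the waitForBranches behavior."""
    results = {}
    with concurrent.futures.ThreadPoolExecutor() as pool:
        futures = {
            pool.submit(run_pipeline, frame, branch, False, context): branch["modelId"]
            for branch in branches if branch.get("parallel")
        }
        for future in concurrent.futures.as_completed(futures):
            model_id = futures[future]
            # e.g. a classification dict consumed by the field mapping templates
            results[model_id] = future.result()
    return results
```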
 
 ## Usage Example
 
-This snippet, inspired by `pipeline_webcam.py`, shows how to use `pympta` to load a pipeline and process an image from a webcam.
+This snippet shows how to use `pympta` with the enhanced features:
 
 ```python
 import cv2
@@ -181,9 +299,14 @@ while True:
     if not ret:
         break
 
-    # 4. Run the pipeline on the current frame
-    # The function will handle the entire logic tree (e.g., find a car, then find its license plate).
-    detection_result, bounding_box = run_pipeline(frame, model_tree, return_bbox=True)
+    # 4. Run the pipeline on the current frame with context
+    context = {
+        "camera_id": "display-001;cam-001",
+        "display_id": "display-001",
+        "session_id": None  # Will be generated automatically
+    }
+    detection_result, bounding_box = run_pipeline(frame, model_tree, return_bbox=True, context=context)
 
     # 5. Display the results
     if detection_result:

requirements.base.txt (new file)

@@ -0,0 +1,7 @@
torch
torchvision
ultralytics
opencv-python
scipy
filterpy
psycopg2-binary

requirements.txt

@@ -1,66 +1,6 @@
 fastapi
 uvicorn
-# torch
-# torchvision
-# ultralytics
-# opencv-python
 websockets
 fastapi[standard]
 redis
+urllib3<2.0.0
-
-# Trackers Environment
-# pip install -r requirements.txt
-ultralytics==8.0.20
-
-# Base ----------------------------------------
-gitpython
-ipython  # interactive notebook
-matplotlib>=3.2.2
-numpy==1.23.1
-opencv-python>=4.1.1
-Pillow>=7.1.2
-psutil  # system resources
-PyYAML>=5.3.1
-requests>=2.23.0
-scipy>=1.4.1
-thop>=0.1.1  # FLOPs computation
-torch>=1.7.0,<=2.5.1  # see https://pytorch.org/get-started/locally (recommended)
-torchvision>=0.8.1,<=0.20.1
-tqdm>=4.64.0
-# protobuf<=3.20.1  # https://github.com/ultralytics/yolov5/issues/8012
-
-# Logging ---------------------------------------------------------------------
-tensorboard>=2.4.1
-# clearml>=1.2.0
-# comet
-
-# Plotting --------------------------------------------------------------------
-pandas>=1.1.4
-seaborn>=0.11.0
-
-# StrongSORT ------------------------------------------------------------------
-easydict
-
-# torchreid -------------------------------------------------------------------
-gdown
-
-# ByteTrack -------------------------------------------------------------------
-lap
-
-# OCSORT ----------------------------------------------------------------------
-filterpy
-
-# Export ----------------------------------------------------------------------
-# onnx>=1.9.0  # ONNX export
-# onnx-simplifier>=0.4.1  # ONNX simplifier
-# nvidia-pyindex  # TensorRT export
-# nvidia-tensorrt  # TensorRT export
-# openvino-dev  # OpenVINO export
-
-# Hyperparam search -----------------------------------------------------------
-# optuna
-# plotly  # for hp importance and pareto front plots
-# kaleido
-# joblib
-
-pyzmq
-loguru

siwatsystem/database.py (new file)

@@ -0,0 +1,211 @@
import psycopg2
import psycopg2.extras
from typing import Optional, Dict, Any
import logging
import uuid

logger = logging.getLogger(__name__)

class DatabaseManager:
    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.connection: Optional[psycopg2.extensions.connection] = None

    def connect(self) -> bool:
        try:
            self.connection = psycopg2.connect(
                host=self.config['host'],
                port=self.config['port'],
                database=self.config['database'],
                user=self.config['username'],
                password=self.config['password']
            )
            logger.info("PostgreSQL connection established successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to PostgreSQL: {e}")
            return False

    def disconnect(self):
        if self.connection:
            self.connection.close()
            self.connection = None
            logger.info("PostgreSQL connection closed")

    def is_connected(self) -> bool:
        try:
            if self.connection and not self.connection.closed:
                cur = self.connection.cursor()
                cur.execute("SELECT 1")
                cur.fetchone()
                cur.close()
                return True
        except:
            pass
        return False

    def update_car_info(self, session_id: str, brand: str, model: str, body_type: str) -> bool:
        if not self.is_connected():
            if not self.connect():
                return False

        try:
            cur = self.connection.cursor()
            query = """
                INSERT INTO car_frontal_info (session_id, car_brand, car_model, car_body_type, updated_at)
                VALUES (%s, %s, %s, %s, NOW())
                ON CONFLICT (session_id)
                DO UPDATE SET
                    car_brand = EXCLUDED.car_brand,
                    car_model = EXCLUDED.car_model,
                    car_body_type = EXCLUDED.car_body_type,
                    updated_at = NOW()
            """
            cur.execute(query, (session_id, brand, model, body_type))
            self.connection.commit()
            cur.close()
            logger.info(f"Updated car info for session {session_id}: {brand} {model} ({body_type})")
            return True
        except Exception as e:
            logger.error(f"Failed to update car info: {e}")
            if self.connection:
                self.connection.rollback()
            return False

    def execute_update(self, table: str, key_field: str, key_value: str, fields: Dict[str, str]) -> bool:
        if not self.is_connected():
            if not self.connect():
                return False

        try:
            cur = self.connection.cursor()

            # Build the UPDATE query dynamically
            set_clauses = []
            values = []
            for field, value in fields.items():
                if value == "NOW()":
                    set_clauses.append(f"{field} = NOW()")
                else:
                    set_clauses.append(f"{field} = %s")
                    values.append(value)

            # Add schema prefix if table doesn't already have it
            full_table_name = table if '.' in table else f"gas_station_1.{table}"

            query = f"""
                INSERT INTO {full_table_name} ({key_field}, {', '.join(fields.keys())})
                VALUES (%s, {', '.join(['%s'] * len(fields))})
                ON CONFLICT ({key_field})
                DO UPDATE SET {', '.join(set_clauses)}
            """

            # Add key_value to the beginning of values list
            all_values = [key_value] + list(fields.values()) + values

            cur.execute(query, all_values)
            self.connection.commit()
            cur.close()
            logger.info(f"Updated {table} for {key_field}={key_value}")
            return True
        except Exception as e:
            logger.error(f"Failed to execute update on {table}: {e}")
            if self.connection:
                self.connection.rollback()
            return False

    def create_car_frontal_info_table(self) -> bool:
        """Create the car_frontal_info table in gas_station_1 schema if it doesn't exist."""
        if not self.is_connected():
            if not self.connect():
                return False

        try:
            cur = self.connection.cursor()

            # Create schema if it doesn't exist
            cur.execute("CREATE SCHEMA IF NOT EXISTS gas_station_1")

            # Create table if it doesn't exist
            create_table_query = """
                CREATE TABLE IF NOT EXISTS gas_station_1.car_frontal_info (
                    display_id VARCHAR(255),
                    captured_timestamp VARCHAR(255),
                    session_id VARCHAR(255) PRIMARY KEY,
                    license_character VARCHAR(255) DEFAULT NULL,
                    license_type VARCHAR(255) DEFAULT 'No model available',
                    car_brand VARCHAR(255) DEFAULT NULL,
                    car_model VARCHAR(255) DEFAULT NULL,
                    car_body_type VARCHAR(255) DEFAULT NULL,
                    updated_at TIMESTAMP DEFAULT NOW()
                )
            """
            cur.execute(create_table_query)

            # Add columns if they don't exist (for existing tables)
            alter_queries = [
                "ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS car_brand VARCHAR(255) DEFAULT NULL",
                "ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS car_model VARCHAR(255) DEFAULT NULL",
                "ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS car_body_type VARCHAR(255) DEFAULT NULL",
                "ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT NOW()"
            ]

            for alter_query in alter_queries:
                try:
                    cur.execute(alter_query)
                    logger.debug(f"Executed: {alter_query}")
                except Exception as e:
                    # Ignore errors if column already exists (for older PostgreSQL versions)
                    if "already exists" in str(e).lower():
                        logger.debug(f"Column already exists, skipping: {alter_query}")
                    else:
                        logger.warning(f"Error in ALTER TABLE: {e}")

            self.connection.commit()
            cur.close()
            logger.info("Successfully created/verified car_frontal_info table with all required columns")
            return True
        except Exception as e:
            logger.error(f"Failed to create car_frontal_info table: {e}")
            if self.connection:
                self.connection.rollback()
            return False

    def insert_initial_detection(self, display_id: str, captured_timestamp: str, session_id: str = None) -> str:
        """Insert initial detection record and return the session_id."""
        if not self.is_connected():
            if not self.connect():
                return None

        # Generate session_id if not provided
        if not session_id:
            session_id = str(uuid.uuid4())

        try:
            # Ensure table exists
            if not self.create_car_frontal_info_table():
                logger.error("Failed to create/verify table before insertion")
                return None

            cur = self.connection.cursor()
            insert_query = """
                INSERT INTO gas_station_1.car_frontal_info
                (display_id, captured_timestamp, session_id, license_character, license_type, car_brand, car_model, car_body_type)
                VALUES (%s, %s, %s, NULL, 'No model available', NULL, NULL, NULL)
                ON CONFLICT (session_id) DO NOTHING
            """
            cur.execute(insert_query, (display_id, captured_timestamp, session_id))
            self.connection.commit()
            cur.close()
            logger.info(f"Inserted initial detection record with session_id: {session_id}")
            return session_id
        except Exception as e:
            logger.error(f"Failed to insert initial detection record: {e}")
            if self.connection:
                self.connection.rollback()
            return None
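For reference, this is roughly how the `postgresql_update_combined` action would drive `execute_update` once all branches have reported (values are illustrative):

```python
# Illustrative: fields as produced by resolving the parallelActions templates.
# pg_config would come from the "postgresql" block of pipeline.json.
db = DatabaseManager(pg_config)
db.execute_update(
    table="car_frontal_info",  # prefixed to gas_station_1.car_frontal_info internally
    key_field="session_id",
    key_value="0f9c1c3e-0000-0000-0000-000000000000",  # illustrative UUID
    fields={"car_brand": "Honda", "car_body_type": "Sedan"},
)
```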

siwatsystem/pympta.py

@@ -3,20 +3,72 @@ import json
 import logging
 import torch
 import cv2
+import requests
 import zipfile
 import shutil
 import traceback
 import redis
 import time
 import uuid
+import concurrent.futures
 from ultralytics import YOLO
 from urllib.parse import urlparse
+from .database import DatabaseManager
 
 # Create a logger specifically for this module
 logger = logging.getLogger("detector_worker.pympta")
 
-def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client) -> dict:
+def validate_redis_config(redis_config: dict) -> bool:
+    """Validate Redis configuration parameters."""
+    required_fields = ["host", "port"]
+    for field in required_fields:
+        if field not in redis_config:
+            logger.error(f"Missing required Redis config field: {field}")
+            return False
+    if not isinstance(redis_config["port"], int) or redis_config["port"] <= 0:
+        logger.error(f"Invalid Redis port: {redis_config['port']}")
+        return False
+    return True
+
+def validate_postgresql_config(pg_config: dict) -> bool:
+    """Validate PostgreSQL configuration parameters."""
+    required_fields = ["host", "port", "database", "username", "password"]
+    for field in required_fields:
+        if field not in pg_config:
+            logger.error(f"Missing required PostgreSQL config field: {field}")
+            return False
+    if not isinstance(pg_config["port"], int) or pg_config["port"] <= 0:
+        logger.error(f"Invalid PostgreSQL port: {pg_config['port']}")
+        return False
+    return True
+
+def crop_region_by_class(frame, regions_dict, class_name):
+    """Crop a specific region from frame based on detected class."""
+    if class_name not in regions_dict:
+        logger.warning(f"Class '{class_name}' not found in detected regions")
+        return None
+    bbox = regions_dict[class_name]['bbox']
+    x1, y1, x2, y2 = bbox
+    cropped = frame[y1:y2, x1:x2]
+    if cropped.size == 0:
+        logger.warning(f"Empty crop for class '{class_name}' with bbox {bbox}")
+        return None
+    return cropped
+
+def format_action_context(base_context, additional_context=None):
+    """Format action context with dynamic values."""
+    context = {**base_context}
+    if additional_context:
+        context.update(additional_context)
+    return context
+
+def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client, db_manager=None) -> dict:
     # Recursively load a model node from configuration.
     model_path = os.path.join(mpta_dir, node_config["modelFile"])
     if not os.path.exists(model_path):
@@ -46,16 +98,22 @@ def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client) -> dict:
         "triggerClasses": trigger_classes,
         "triggerClassIndices": trigger_class_indices,
         "crop": node_config.get("crop", False),
+        "cropClass": node_config.get("cropClass"),
         "minConfidence": node_config.get("minConfidence", None),
+        "multiClass": node_config.get("multiClass", False),
+        "expectedClasses": node_config.get("expectedClasses", []),
+        "parallel": node_config.get("parallel", False),
         "actions": node_config.get("actions", []),
+        "parallelActions": node_config.get("parallelActions", []),
         "model": model,
         "branches": [],
-        "redis_client": redis_client
+        "redis_client": redis_client,
+        "db_manager": db_manager
     }
     logger.debug(f"Configured node {node_config['modelId']} with trigger classes: {node['triggerClasses']}")
     for child in node_config.get("branches", []):
         logger.debug(f"Loading branch for parent node {node_config['modelId']}")
-        node["branches"].append(load_pipeline_node(child, mpta_dir, redis_client))
+        node["branches"].append(load_pipeline_node(child, mpta_dir, redis_client, db_manager))
     return node
 
 def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
@@ -183,6 +241,9 @@
         redis_client = None
         if "redis" in pipeline_config:
             redis_config = pipeline_config["redis"]
+            if not validate_redis_config(redis_config):
+                logger.error("Invalid Redis configuration, skipping Redis connection")
+            else:
                 try:
                     redis_client = redis.Redis(
                         host=redis_config["host"],
@@ -197,7 +258,25 @@
                     logger.error(f"Failed to connect to Redis: {e}")
                     redis_client = None
 
-        return load_pipeline_node(pipeline_config["pipeline"], mpta_dir, redis_client)
+        # Establish PostgreSQL connection if configured
+        db_manager = None
+        if "postgresql" in pipeline_config:
+            pg_config = pipeline_config["postgresql"]
+            if not validate_postgresql_config(pg_config):
+                logger.error("Invalid PostgreSQL configuration, skipping database connection")
+            else:
+                try:
+                    db_manager = DatabaseManager(pg_config)
+                    if db_manager.connect():
+                        logger.info(f"Successfully connected to PostgreSQL at {pg_config['host']}:{pg_config['port']}")
+                    else:
+                        logger.error("Failed to connect to PostgreSQL")
+                        db_manager = None
+                except Exception as e:
+                    logger.error(f"Error initializing PostgreSQL connection: {e}")
+                    db_manager = None
+
+        return load_pipeline_node(pipeline_config["pipeline"], mpta_dir, redis_client, db_manager)
     except json.JSONDecodeError as e:
         logger.error(f"Error parsing pipeline.json: {str(e)}", exc_info=True)
         return None
@ -208,22 +287,53 @@ def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True) logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True)
return None return None
def execute_actions(node, frame, detection_result): def execute_actions(node, frame, detection_result, regions_dict=None):
if not node["redis_client"] or not node["actions"]: if not node["redis_client"] or not node["actions"]:
return return
# Create a dynamic context for this detection event # Create a dynamic context for this detection event
from datetime import datetime
action_context = { action_context = {
**detection_result, **detection_result,
"timestamp_ms": int(time.time() * 1000), "timestamp_ms": int(time.time() * 1000),
"uuid": str(uuid.uuid4()), "uuid": str(uuid.uuid4()),
"timestamp": datetime.now().strftime("%Y-%m-%dT%H-%M-%S"),
"filename": f"{uuid.uuid4()}.jpg"
} }
for action in node["actions"]: for action in node["actions"]:
try: try:
if action["type"] == "redis_save_image": if action["type"] == "redis_save_image":
key = action["key"].format(**action_context) key = action["key"].format(**action_context)
_, buffer = cv2.imencode('.jpg', frame)
# Check if we need to crop a specific region
region_name = action.get("region")
image_to_save = frame
if region_name and regions_dict:
cropped_image = crop_region_by_class(frame, regions_dict, region_name)
if cropped_image is not None:
image_to_save = cropped_image
logger.debug(f"Cropped region '{region_name}' for redis_save_image")
else:
logger.warning(f"Could not crop region '{region_name}', saving full frame instead")
# Encode image with specified format and quality (default to JPEG)
img_format = action.get("format", "jpeg").lower()
quality = action.get("quality", 90)
if img_format == "jpeg":
encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
success, buffer = cv2.imencode('.jpg', image_to_save, encode_params)
elif img_format == "png":
success, buffer = cv2.imencode('.png', image_to_save)
else:
success, buffer = cv2.imencode('.jpg', image_to_save, [cv2.IMWRITE_JPEG_QUALITY, quality])
if not success:
logger.error(f"Failed to encode image for redis_save_image")
continue
expire_seconds = action.get("expire_seconds")
if expire_seconds:
node["redis_client"].setex(key, expire_seconds, buffer.tobytes())
@@ -231,60 +341,244 @@ def execute_actions(node, frame, detection_result):
else:
node["redis_client"].set(key, buffer.tobytes())
logger.info(f"Saved image to Redis with key: {key}")
# Add the generated key to the context for subsequent actions
action_context["image_key"] = key
elif action["type"] == "redis_publish":
channel = action["channel"]
try:
# Handle JSON message format by creating it programmatically
message_template = action["message"]
# Check if the message is JSON-like (starts and ends with braces)
if message_template.strip().startswith('{') and message_template.strip().endswith('}'):
# Create JSON data programmatically to avoid formatting issues
json_data = {}
# Add common fields
json_data["event"] = "frontal_detected"
json_data["display_id"] = action_context.get("display_id", "unknown")
json_data["session_id"] = action_context.get("session_id")
json_data["timestamp"] = action_context.get("timestamp", "")
json_data["image_key"] = action_context.get("image_key", "")
# Convert to JSON string
message = json.dumps(json_data)
else:
# Use regular string formatting for non-JSON messages
message = message_template.format(**action_context)
# Publish to Redis
if not node["redis_client"]:
logger.error("Redis client is None, cannot publish message")
continue
# Test Redis connection
try:
node["redis_client"].ping()
logger.debug("Redis connection is active")
except Exception as ping_error:
logger.error(f"Redis connection test failed: {ping_error}")
continue
result = node["redis_client"].publish(channel, message)
logger.info(f"Published message to Redis channel '{channel}': {message}") logger.info(f"Published message to Redis channel '{channel}': {message}")
logger.info(f"Redis publish result (subscribers count): {result}")
# Additional debug info
if result == 0:
logger.warning(f"No subscribers listening to channel '{channel}'")
else:
logger.info(f"Message delivered to {result} subscriber(s)")
except KeyError as e:
logger.error(f"Missing key in redis_publish message template: {e}")
logger.debug(f"Available context keys: {list(action_context.keys())}")
except Exception as e:
logger.error(f"Error in redis_publish action: {e}")
logger.debug(f"Message template: {action['message']}")
logger.debug(f"Available context keys: {list(action_context.keys())}")
import traceback
logger.debug(f"Full traceback: {traceback.format_exc()}")
except Exception as e:
logger.error(f"Error executing action {action['type']}: {e}")
def execute_parallel_actions(node, frame, detection_result, regions_dict):
"""Execute parallel actions after all required branches have completed."""
if not node.get("parallelActions"):
return
logger.debug("Executing parallel actions...")
branch_results = detection_result.get("branch_results", {})
for action in node["parallelActions"]:
try:
action_type = action.get("type")
logger.debug(f"Processing parallel action: {action_type}")
if action_type == "postgresql_update_combined":
# Check if all required branches have completed
wait_for_branches = action.get("waitForBranches", [])
missing_branches = [branch for branch in wait_for_branches if branch not in branch_results]
if missing_branches:
logger.warning(f"Cannot execute postgresql_update_combined: missing branch results for {missing_branches}")
continue
logger.info(f"All required branches completed: {wait_for_branches}")
# Execute the database update
execute_postgresql_update_combined(node, action, detection_result, branch_results)
else:
logger.warning(f"Unknown parallel action type: {action_type}")
except Exception as e:
logger.error(f"Error executing parallel action {action.get('type', 'unknown')}: {e}")
import traceback
logger.debug(f"Full traceback: {traceback.format_exc()}")
def execute_postgresql_update_combined(node, action, detection_result, branch_results):
"""Execute a PostgreSQL update with combined branch results."""
if not node.get("db_manager"):
logger.error("No database manager available for postgresql_update_combined action")
return
try:
table = action["table"]
key_field = action["key_field"]
key_value_template = action["key_value"]
fields = action["fields"]
# Create context for key value formatting
action_context = {**detection_result}
key_value = key_value_template.format(**action_context)
logger.info(f"Executing database update: table={table}, {key_field}={key_value}")
# Process field mappings
mapped_fields = {}
for db_field, value_template in fields.items():
try:
mapped_value = resolve_field_mapping(value_template, branch_results, action_context)
if mapped_value is not None:
mapped_fields[db_field] = mapped_value
logger.debug(f"Mapped field: {db_field} = {mapped_value}")
else:
logger.warning(f"Could not resolve field mapping for {db_field}: {value_template}")
except Exception as e:
logger.error(f"Error mapping field {db_field} with template '{value_template}': {e}")
if not mapped_fields:
logger.warning("No fields mapped successfully, skipping database update")
return
# Execute the database update
success = node["db_manager"].execute_update(table, key_field, key_value, mapped_fields)
if success:
logger.info(f"Successfully updated database: {table} with {len(mapped_fields)} fields")
else:
logger.error(f"Failed to update database: {table}")
except KeyError as e:
logger.error(f"Missing required field in postgresql_update_combined action: {e}")
except Exception as e:
logger.error(f"Error in postgresql_update_combined action: {e}")
import traceback
logger.debug(f"Full traceback: {traceback.format_exc()}")
def resolve_field_mapping(value_template, branch_results, action_context):
"""Resolve field mapping templates like {car_brand_cls_v1.brand}."""
try:
# Handle simple context variables first (non-branch references)
if not '.' in value_template:
return value_template.format(**action_context)
# Handle branch result references like {model_id.field}
import re
branch_refs = re.findall(r'\{([^}]+\.[^}]+)\}', value_template)
resolved_template = value_template
for ref in branch_refs:
try:
model_id, field_name = ref.split('.', 1)
if model_id in branch_results:
branch_data = branch_results[model_id]
if field_name in branch_data:
field_value = branch_data[field_name]
resolved_template = resolved_template.replace(f'{{{ref}}}', str(field_value))
logger.debug(f"Resolved {ref} to {field_value}")
else:
logger.warning(f"Field '{field_name}' not found in branch '{model_id}' results. Available fields: {list(branch_data.keys())}")
return None
else:
logger.warning(f"Branch '{model_id}' not found in results. Available branches: {list(branch_results.keys())}")
return None
except ValueError as e:
logger.error(f"Invalid branch reference format: {ref}")
return None
# Format any remaining simple variables
try:
final_value = resolved_template.format(**action_context)
return final_value
except KeyError as e:
logger.warning(f"Could not resolve context variable in template: {e}")
return resolved_template
except Exception as e:
logger.error(f"Error resolving field mapping '{value_template}': {e}")
return None
def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
""" """
- For detection nodes (task != 'classify'): Enhanced pipeline that supports:
runs `track(..., classes=triggerClassIndices)` - Multi-class detection (detecting multiple classes simultaneously)
picks top box minConfidence - Parallel branch processing
optionally crops & resizes recurse into child - Region-based actions and cropping
else returns (det_dict, bbox) - Context passing for session/camera information
- For classify nodes:
runs `predict()`
returns top (class,confidence) and no bbox
""" """
try: try:
task = getattr(node["model"], "task", None) task = getattr(node["model"], "task", None)
# ─── Classification stage ─────────────────────────────────── # ─── Classification stage ───────────────────────────────────
if task == "classify": if task == "classify":
# run the classifier and grab its top-1 directly via the Probs API
results = node["model"].predict(frame, stream=False) results = node["model"].predict(frame, stream=False)
# nothing returned?
if not results: if not results:
return (None, None) if return_bbox else None return (None, None) if return_bbox else None
# take the first result's probs object
r = results[0] r = results[0]
probs = r.probs probs = r.probs
if probs is None: if probs is None:
return (None, None) if return_bbox else None return (None, None) if return_bbox else None
# get the top-1 class index and its confidence
top1_idx = int(probs.top1) top1_idx = int(probs.top1)
top1_conf = float(probs.top1conf) top1_conf = float(probs.top1conf)
class_name = node["model"].names[top1_idx]
det = {
"class": class_name,
"confidence": top1_conf,
"id": None,
class_name: class_name  # Add class name as key for backward compatibility
}
# Add specific field mappings for database operations based on model type
model_id = node.get("modelId", "").lower()
if "brand" in model_id or "brand_cls" in model_id:
det["brand"] = class_name
elif "bodytype" in model_id or "body" in model_id:
det["body_type"] = class_name
elif "color" in model_id:
det["color"] = class_name
execute_actions(node, frame, det)
return (det, None) if return_bbox else det
# ─── Detection stage - Multi-class support ──────────────────
tk = node["triggerClassIndices"]
logger.debug(f"Running detection for node {node['modelId']} with trigger classes: {node.get('triggerClasses', [])} (indices: {tk})")
logger.debug(f"Node configuration: minConfidence={node['minConfidence']}, multiClass={node.get('multiClass', False)}")
res = node["model"].track(
frame,
stream=False,
@@ -292,48 +586,228 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False):
**({"classes": tk} if tk else {})
)[0]
# Collect all detections above confidence threshold
all_detections = []
all_boxes = []
regions_dict = {}
logger.debug(f"Raw detection results from model: {len(res.boxes) if res.boxes is not None else 0} detections")
for i, box in enumerate(res.boxes):
conf = float(box.cpu().conf[0])
cid = int(box.cpu().cls[0])
name = node["model"].names[cid]
logger.debug(f"Detection {i}: class='{name}' (id={cid}), confidence={conf:.3f}, threshold={node['minConfidence']}")
if conf < node["minConfidence"]:
logger.debug(f" -> REJECTED: confidence {conf:.3f} < threshold {node['minConfidence']}")
continue
xy = box.cpu().xyxy[0]
x1, y1, x2, y2 = map(int, xy)
bbox = (x1, y1, x2, y2)
detection = {
"class": name,
"confidence": conf,
"id": box.id.item() if hasattr(box, "id") else None,
"bbox": bbox
}
all_detections.append(detection)
all_boxes.append(bbox)
logger.debug(f" -> ACCEPTED: {name} with confidence {conf:.3f}, bbox={bbox}")
# Store highest confidence detection for each class
if name not in regions_dict or conf > regions_dict[name]["confidence"]:
regions_dict[name] = {
"bbox": bbox,
"confidence": conf,
"detection": detection
}
logger.debug(f" -> Updated regions_dict['{name}'] with confidence {conf:.3f}")
logger.info(f"Detection summary: {len(all_detections)} accepted detections from {len(res.boxes) if res.boxes is not None else 0} total")
logger.info(f"Detected classes: {list(regions_dict.keys())}")
if not all_detections:
logger.warning("No detections above confidence threshold - returning null")
return (None, None) if return_bbox else None
# ─── Multi-class validation ─────────────────────────────────
if node.get("multiClass", False) and node.get("expectedClasses"):
expected_classes = node["expectedClasses"]
detected_classes = list(regions_dict.keys())
logger.info(f"Multi-class validation: expected={expected_classes}, detected={detected_classes}")
# Check if at least one expected class is detected (flexible mode)
matching_classes = [cls for cls in expected_classes if cls in detected_classes]
missing_classes = [cls for cls in expected_classes if cls not in detected_classes]
logger.debug(f"Matching classes: {matching_classes}, Missing classes: {missing_classes}")
if not matching_classes:
# No expected classes found at all
logger.warning(f"PIPELINE REJECTED: No expected classes detected. Expected: {expected_classes}, Detected: {detected_classes}")
return (None, None) if return_bbox else None
if missing_classes:
logger.info(f"Partial multi-class detection: {matching_classes} found, {missing_classes} missing")
else:
logger.info(f"Complete multi-class detection success: {detected_classes}")
else:
logger.debug("No multi-class validation - proceeding with all detections")
# ─── Execute actions with region information ────────────────
detection_result = {
"detections": all_detections,
"regions": regions_dict,
**(context or {})
}
# ─── Create initial database record when Car+Frontal detected ────
if node.get("db_manager") and node.get("multiClass", False):
# Only create database record if we have both Car and Frontal
has_car = "Car" in regions_dict
has_frontal = "Frontal" in regions_dict
if has_car and has_frontal:
# Generate UUID session_id since client session is None for now
import uuid as uuid_lib
from datetime import datetime
generated_session_id = str(uuid_lib.uuid4())
# Insert initial detection record
display_id = detection_result.get("display_id", "unknown")
timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
inserted_session_id = node["db_manager"].insert_initial_detection(
display_id=display_id,
captured_timestamp=timestamp,
session_id=generated_session_id
)
if inserted_session_id:
# Update detection_result with the generated session_id for actions and branches
detection_result["session_id"] = inserted_session_id
detection_result["timestamp"] = timestamp # Update with proper timestamp
logger.info(f"Created initial database record with session_id: {inserted_session_id}")
else:
logger.debug(f"Database record not created - missing required classes. Has Car: {has_car}, Has Frontal: {has_frontal}")
execute_actions(node, frame, detection_result, regions_dict)
# ─── Parallel branch processing ─────────────────────────────
if node["branches"]:
branch_results = {}
# Filter branches that should be triggered
active_branches = []
for br in node["branches"]:
trigger_classes = br.get("triggerClasses", [])
min_conf = br.get("minConfidence", 0)
logger.debug(f"Evaluating branch {br['modelId']}: trigger_classes={trigger_classes}, min_conf={min_conf}")
# Check if any detected class matches branch trigger
branch_triggered = False
for det_class in regions_dict:
det_confidence = regions_dict[det_class]["confidence"]
logger.debug(f" Checking detected class '{det_class}' (confidence={det_confidence:.3f}) against triggers {trigger_classes}")
if (det_class in trigger_classes and det_confidence >= min_conf):
active_branches.append(br)
branch_triggered = True
logger.info(f"Branch {br['modelId']} activated by class '{det_class}' (conf={det_confidence:.3f} >= {min_conf})")
break
if not branch_triggered:
logger.debug(f"Branch {br['modelId']} not triggered - no matching classes or insufficient confidence")
if active_branches:
if node.get("parallel", False) or any(br.get("parallel", False) for br in active_branches):
# Run branches in parallel
with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_branches)) as executor:
futures = {}
for br in active_branches:
crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None)
sub_frame = frame
logger.info(f"Starting parallel branch: {br['modelId']}, crop_class: {crop_class}")
if br.get("crop", False) and crop_class:
cropped = crop_region_by_class(frame, regions_dict, crop_class)
if cropped is not None:
sub_frame = cv2.resize(cropped, (224, 224))
logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}")
else:
logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
continue
future = executor.submit(run_pipeline, sub_frame, br, True, context)
futures[future] = br
# Collect results
for future in concurrent.futures.as_completed(futures):
br = futures[future]
try:
result, _ = future.result()
if result:
branch_results[br["modelId"]] = result
logger.info(f"Branch {br['modelId']} completed: {result}")
except Exception as e:
logger.error(f"Branch {br['modelId']} failed: {e}")
else:
# Run branches sequentially
for br in active_branches:
crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None)
sub_frame = frame
logger.info(f"Starting sequential branch: {br['modelId']}, crop_class: {crop_class}")
if br.get("crop", False) and crop_class:
cropped = crop_region_by_class(frame, regions_dict, crop_class)
if cropped is not None:
sub_frame = cv2.resize(cropped, (224, 224))
logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}")
else:
logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
continue
try:
result, _ = run_pipeline(sub_frame, br, True, context)
if result:
branch_results[br["modelId"]] = result
logger.info(f"Branch {br['modelId']} completed: {result}")
else:
logger.warning(f"Branch {br['modelId']} returned no result")
except Exception as e:
logger.error(f"Error in sequential branch {br['modelId']}: {e}")
import traceback
logger.debug(f"Branch error traceback: {traceback.format_exc()}")
# Store branch results in detection_result for parallel actions
detection_result["branch_results"] = branch_results
# ─── Execute Parallel Actions ───────────────────────────────
if node.get("parallelActions") and "branch_results" in detection_result:
execute_parallel_actions(node, frame, detection_result, regions_dict)
# ─── Return detection result ────────────────────────────────
primary_detection = max(all_detections, key=lambda x: x["confidence"])
primary_bbox = primary_detection["bbox"]
# Add branch results to primary detection for compatibility
if "branch_results" in detection_result:
primary_detection["branch_results"] = detection_result["branch_results"]
return (primary_detection, primary_bbox) if return_bbox else primary_detection
except Exception as e:
logger.error(f"Error in node {node.get('modelId')}: {e}")
traceback.print_exc()
return (None, None) if return_bbox else None

View file

@@ -2,6 +2,12 @@
This document outlines the WebSocket-based communication protocol between the CMS backend and a detector worker. As a worker developer, your primary responsibility is to implement a WebSocket server that adheres to this protocol.
The current Python Detector Worker implementation supports advanced computer vision pipelines with:
- Multi-class YOLO detection with parallel processing
- PostgreSQL database integration with automatic schema management
- Redis integration for image storage and pub/sub messaging
- Hierarchical pipeline execution with detection → classification branching
## 1. Connection
The worker must run a WebSocket server, preferably on port `8000`. The backend system, which is managed by a container orchestration service, will automatically discover and establish a WebSocket connection to your worker.
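For orientation, a minimal sketch of such a server is shown below, assuming the worker's FastAPI/uvicorn stack; the endpoint path, dispatch logic, and the acknowledgement reply are illustrative placeholders, not part of the protocol:
```python
# Minimal sketch only — assumes FastAPI/uvicorn; path and reply shape are placeholders.
from fastapi import FastAPI, WebSocket

app = FastAPI()

@app.websocket("/")
async def detector_ws(websocket: WebSocket):
    await websocket.accept()                      # the backend connects to us
    while True:
        msg = await websocket.receive_json()      # command from the CMS backend
        # ... dispatch on msg["type"] and run the configured pipeline ...
        await websocket.send_json({"ack": msg.get("type")})  # placeholder reply
```
Started with e.g. `uvicorn worker:app --port 8000`, this satisfies the discovery assumption above while the real message handling is filled in per the sections that follow.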
@@ -25,14 +31,34 @@ To enable modularity and dynamic configuration, the backend will send you a URL
2. Extracting its contents.
3. Interpreting the contents to configure its internal pipeline.
**The current implementation supports comprehensive pipeline configurations including:**
- **AI/ML Models**: YOLO models (.pt files) for detection and classification
- **Pipeline Configuration**: `pipeline.json` defining hierarchical detection→classification workflows
- **Multi-class Detection**: Simultaneous detection of multiple object classes (e.g., Car + Frontal)
- **Parallel Processing**: Concurrent execution of classification branches with ThreadPoolExecutor
- **Database Integration**: PostgreSQL configuration for automatic table creation and updates
- **Redis Actions**: Image storage with region cropping and pub/sub messaging
- **Dynamic Field Mapping**: Template-based field resolution for database operations
**Enhanced MPTA Structure:**
```
pipeline.mpta/
├── pipeline.json # Main configuration with redis/postgresql settings
├── car_detection.pt # Primary YOLO detection model
├── brand_classifier.pt # Classification model for car brands
├── bodytype_classifier.pt # Classification model for body types
└── ...
```
The `pipeline.json` now supports advanced features like:
- Multi-class detection with `expectedClasses` validation
- Parallel branch processing with `parallel: true`
- Database actions with `postgresql_update_combined`
- Redis actions with region-specific image cropping
- Branch synchronization with `waitForBranches`
Essentially, the `.mpta` file is a self-contained package that tells your worker *how* to process the video stream for a given subscription, including complex multi-stage AI pipelines with database persistence.
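To make these features concrete, a trimmed `pipeline.json` might look like the sketch below. The key names mirror those consumed by the worker code above (`multiClass`, `expectedClasses`, `parallelActions`, `waitForBranches`, the `{model_id.field}` templates); hostnames, model IDs, thresholds, and column names are illustrative placeholders, not a canonical schema:
```json
{
  "redis": {"host": "redis.local", "port": 6379},
  "postgresql": {"host": "db.local", "port": 5432, "database": "cms", "username": "worker", "password": "..."},
  "pipeline": {
    "modelId": "car_frontal_detection_v1",
    "multiClass": true,
    "expectedClasses": ["Car", "Frontal"],
    "triggerClasses": ["Car", "Frontal"],
    "minConfidence": 0.6,
    "actions": [
      {"type": "redis_save_image", "region": "Frontal",
       "key": "inference:{display_id}:{timestamp}:{session_id}", "expire_seconds": 600}
    ],
    "branches": [
      {"modelId": "car_brand_cls_v1", "parallel": true,
       "crop": true, "cropClass": "Frontal",
       "triggerClasses": ["Frontal"], "minConfidence": 0.85}
    ],
    "parallelActions": [
      {"type": "postgresql_update_combined",
       "table": "car_frontal_info", "key_field": "session_id", "key_value": "{session_id}",
       "waitForBranches": ["car_brand_cls_v1"],
       "fields": {"car_brand": "{car_brand_cls_v1.brand}"}}
    ]
  }
}
```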
## 4. Messages from Worker to Backend
@@ -79,6 +105,15 @@ Sent when the worker detects a relevant object. The `detection` object should be
- **Type:** `imageDetection`
**Enhanced Detection Capabilities:**
The current implementation supports multi-class detection with parallel classification processing. When a vehicle is detected, the system:
1. **Multi-Class Detection**: Simultaneously detects "Car" and "Frontal" classes
2. **Parallel Processing**: Runs brand and body type classification concurrently
3. **Database Integration**: Automatically creates and updates PostgreSQL records
4. **Redis Storage**: Saves cropped frontal images with expiration
**Payload Example:**
```json
@@ -88,19 +123,38 @@ Sent when the worker detects a relevant object. The `detection` object should be
"timestamp": "2025-07-14T12:34:56.789Z",
"data": {
"detection": {
"class": "Car",
"confidence": 0.92,
"carBrand": "Honda",
"carModel": "Civic",
"bodyType": "Sedan",
"branch_results": {
"car_brand_cls_v1": {
"class": "Honda",
"confidence": 0.89,
"brand": "Honda"
},
"car_bodytype_cls_v1": {
"class": "Sedan",
"confidence": 0.85,
"body_type": "Sedan"
}
}
},
"modelId": 101,
"modelName": "Car Frontal Detection V1"
}
}
```
**Database Integration:**
Each detection automatically:
- Creates a record in `gas_station_1.car_frontal_info` table
- Generates a unique `session_id` for tracking
- Updates the record with classification results after parallel processing completes
- Stores cropped frontal images in Redis with the session_id as key
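For reference, the backing table can be pictured roughly as follows; only `session_id`, `display_id`, and `captured_timestamp` appear explicitly in the worker code, so the classification columns and all types below are assumptions inferred from the branch results:
```sql
-- Rough sketch; column types and the classification columns are assumptions.
CREATE TABLE IF NOT EXISTS gas_station_1.car_frontal_info (
    session_id         VARCHAR(255) PRIMARY KEY,  -- generated UUID, also used as the Redis image key
    display_id         VARCHAR(255),
    captured_timestamp VARCHAR(255),
    car_brand          VARCHAR(100),              -- updated after car_brand_cls_v1 completes
    car_bodytype       VARCHAR(100)               -- updated after car_bodytype_cls_v1 completes
);
```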
### 4.3. Patch Session
> **Note:** Patch messages are only used when the worker can't keep up and needs to retroactively send detections. Normally, detections should be sent in real-time using `imageDetection` messages. Use `patchSession` only to update session data after the fact.