Compare commits

7 commits

| Author | SHA1 | Date |
|---|---|---|
|  | 48db3234ed |  |
|  | aa883e4bfa |  |
|  | b7d8b3266f |  |
|  | ffc2e99678 |  |
|  | e471ab03e9 |  |
|  | 3d0aaab8b3 |  |
|  | 7085a6e00f |  |

152 changed files with 21461 additions and 14561 deletions
@@ -1,11 +0,0 @@
{
  "permissions": {
    "allow": [
      "Bash(dir:*)",
      "WebSearch",
      "Bash(mkdir:*)"
    ],
    "deny": [],
    "ask": []
  }
}
@@ -51,7 +51,7 @@ jobs:
          registry: git.siwatsystem.com
          username: ${{ github.actor }}
          password: ${{ secrets.RUNNER_TOKEN }}

      - name: Build and push base Docker image
        uses: docker/build-push-action@v4
        with:

@@ -79,7 +79,7 @@ jobs:
          registry: git.siwatsystem.com
          username: ${{ github.actor }}
          password: ${{ secrets.RUNNER_TOKEN }}

      - name: Build and push Docker image
        uses: docker/build-push-action@v4
        with:
@@ -103,4 +103,10 @@ jobs:
      - name: Deploy stack
        run: |
          echo "Pulling and starting containers on server..."
          ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.staging.yml -f docker-compose.production.yml pull && docker compose -f docker-compose.staging.yml -f docker-compose.production.yml up -d"
          if [ "${{ github.ref_name }}" = "main" ]; then
            echo "Deploying production stack..."
            ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.production.yml pull && docker compose -f docker-compose.production.yml up -d"
          else
            echo "Deploying staging stack..."
            ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.staging.yml pull && docker compose -f docker-compose.staging.yml up -d"
          fi
.gitignore (vendored): 13 changes
@@ -1,8 +1,11 @@
/models
# Do not know how to use
archive/
Dockerfile

# /models
app.log
*.pt

images
.venv/

# All pycache directories
__pycache__/

@@ -12,3 +15,7 @@ mptas
detector_worker.log
.gitignore
no_frame_debug.log


# Result from tracker
feeder/runs/
Dockerfile.base: 127 changes
@@ -1,130 +1,15 @@
# Base image with complete ML and hardware acceleration stack
FROM pytorch/pytorch:2.8.0-cuda12.6-cudnn9-runtime
# Base image with all ML dependencies
FROM python:3.13-bookworm

# Install build dependencies and system libraries
RUN apt-get update && apt-get install -y \
    # Build tools
    build-essential \
    cmake \
    git \
    pkg-config \
    wget \
    unzip \
    yasm \
    nasm \
    # Additional dependencies for FFmpeg/NVIDIA build
    libtool \
    libc6 \
    libc6-dev \
    libnuma1 \
    libnuma-dev \
    # Essential compilation libraries
    gcc \
    g++ \
    libc6-dev \
    linux-libc-dev \
    # System libraries
    libgl1-mesa-glx \
    libglib2.0-0 \
    libgomp1 \
    # Core media libraries (essential ones only)
    libjpeg-dev \
    libpng-dev \
    libx264-dev \
    libx265-dev \
    libvpx-dev \
    libmp3lame-dev \
    libv4l-dev \
    # TurboJPEG for fast JPEG encoding
    libturbojpeg0-dev \
    # Python development
    python3-dev \
    python3-numpy \
    && rm -rf /var/lib/apt/lists/*
# Install system dependencies
RUN apt update && apt install -y libgl1 && rm -rf /var/lib/apt/lists/*

# Add NVIDIA CUDA repository and install minimal development tools
RUN apt-get update && apt-get install -y wget gnupg && \
    wget -O - https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub | apt-key add - && \
    echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
    apt-get update && \
    apt-get install -y \
    cuda-nvcc-12-6 \
    cuda-cudart-dev-12-6 \
    libnpp-dev-12-6 \
    && apt-get remove -y wget gnupg && \
    apt-get autoremove -y && \
    rm -rf /var/lib/apt/lists/*

# Ensure CUDA paths are available
ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"

# Install NVIDIA Video Codec SDK headers (official method)
RUN cd /tmp && \
    git clone https://git.videolan.org/git/ffmpeg/nv-codec-headers.git && \
    cd nv-codec-headers && \
    make install && \
    cd / && rm -rf /tmp/*

# Build FFmpeg from source with NVIDIA CUDA support
RUN cd /tmp && \
    echo "Building FFmpeg with NVIDIA CUDA support..." && \
    # Download FFmpeg source (official method)
    git clone https://git.ffmpeg.org/ffmpeg.git ffmpeg/ && \
    cd ffmpeg && \
    # Configure with NVIDIA support (simplified to avoid configure issues)
    ./configure \
        --prefix=/usr/local \
        --enable-shared \
        --disable-static \
        --enable-nonfree \
        --enable-gpl \
        --enable-cuda-nvcc \
        --enable-cuvid \
        --enable-nvdec \
        --enable-nvenc \
        --enable-libnpp \
        --extra-cflags=-I/usr/local/cuda/include \
        --extra-ldflags=-L/usr/local/cuda/lib64 \
        --enable-libx264 \
        --enable-libx265 \
        --enable-libvpx \
        --enable-libmp3lame && \
    # Build and install
    make -j$(nproc) && \
    make install && \
    ldconfig && \
    # Verify CUVID decoders are available
    echo "=== Verifying FFmpeg CUVID Support ===" && \
    (ffmpeg -hide_banner -decoders 2>/dev/null | grep cuvid || echo "No CUVID decoders found") && \
    echo "=== Verifying FFmpeg NVENC Support ===" && \
    (ffmpeg -hide_banner -encoders 2>/dev/null | grep nvenc || echo "No NVENC encoders found") && \
    cd / && rm -rf /tmp/*

# Set environment variables for maximum hardware acceleration
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/lib:${LD_LIBRARY_PATH}"
ENV PKG_CONFIG_PATH="/usr/local/lib/pkgconfig:${PKG_CONFIG_PATH}"
ENV PYTHONPATH="/usr/local/lib/python3.10/dist-packages:${PYTHONPATH}"

# Optimized environment variables for hardware acceleration
ENV OPENCV_FFMPEG_CAPTURE_OPTIONS="rtsp_transport;tcp|hwaccel;cuda|hwaccel_device;0|video_codec;h264_cuvid|hwaccel_output_format;cuda"
ENV OPENCV_FFMPEG_WRITER_OPTIONS="video_codec;h264_nvenc|preset;fast|tune;zerolatency|gpu;0"
ENV CUDA_VISIBLE_DEVICES=0
ENV NVIDIA_VISIBLE_DEVICES=all
ENV NVIDIA_DRIVER_CAPABILITIES=compute,video,utility

# Copy and install base requirements (exclude opencv-python since we built from source)
# Copy and install base requirements (ML dependencies that rarely change)
COPY requirements.base.txt .
RUN grep -v opencv-python requirements.base.txt > requirements.tmp && \
    mv requirements.tmp requirements.base.txt && \
    pip install --no-cache-dir -r requirements.base.txt
RUN pip install --no-cache-dir -r requirements.base.txt

# Set working directory
WORKDIR /app

# Create images directory for bind mount
RUN mkdir -p /app/images && \
    chmod 755 /app/images

# This base image will be reused for all worker builds
CMD ["python3", "-m", "fastapi", "run", "--host", "0.0.0.0", "--port", "8000"]
REFACTOR_PLAN.md: 545 changes
@@ -1,545 +0,0 @@
# Detector Worker Refactoring Plan

## Project Overview

Transform the current monolithic structure (~4000 lines across `app.py` and `siwatsystem/pympta.py`) into a modular, maintainable system with clear separation of concerns. The goal is to make the sophisticated computer vision pipeline easily understandable for other engineers while maintaining all existing functionality.

## Current System Flow Understanding

### Validated System Flow
1. **WebSocket Connection** → Backend connects and sends `setSubscriptionList` (a minimal routing sketch follows this list)
2. **Model Management** → Download unique `.mpta` files to `models/` and extract
3. **Tracking Phase** → Continuous tracking with `front_rear_detection_v1.pt`
4. **Validation Phase** → Validate stable car (not just passing by)
5. **Pipeline Execution** →
   - Detect car with `yolo11m.pt`
   - **Branch 1**: Front/rear detection → crop frontal → save to Redis + brand classification
   - **Branch 2**: Body type classification from car crop
6. **Communication** → Send `imageDetection` → Backend generates `sessionId` → Fueling starts
7. **Post-Fueling** → Backend clears `sessionId` → Continue tracking same car to avoid re-pipeline
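To make step 1 concrete, here is a minimal sketch of WebSocket message routing; the endpoint path matches `archive/app.py`, but the handler body and the `subscriptions` payload field are illustrative assumptions, not the worker's actual code:

```python
import json
from fastapi import FastAPI, WebSocket

app = FastAPI()

async def handle_set_subscription_list(ws: WebSocket, msg: dict) -> None:
    # Hypothetical handler: reconcile the requested subscriptions
    # against the currently active streams.
    for sub in msg.get("subscriptions", []):
        print(f"would subscribe to {sub}")

# Route incoming messages by their "type" field; unknown types are ignored.
HANDLERS = {"setSubscriptionList": handle_set_subscription_list}

@app.websocket("/")
async def ws_endpoint(ws: WebSocket):
    await ws.accept()
    while True:
        msg = json.loads(await ws.receive_text())
        handler = HANDLERS.get(msg.get("type"))
        if handler:
            await handler(ws, msg)
```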
### Core Responsibilities Identified
1. **WebSocket Communication** - Message handling and protocol compliance
2. **Stream Management** - RTSP/HTTP frame processing and buffering
3. **Model Management** - MPTA download, extraction, and loading
4. **Pipeline Configuration** - Parse `pipeline.json` and set up the execution flow
5. **Vehicle Tracking** - Continuous tracking and car identification
6. **Validation Logic** - Stable car detection vs. passing-by cars
7. **Detection Pipeline** - Main ML pipeline with parallel branches
8. **Data Persistence** - Redis/PostgreSQL operations
9. **Session Management** - Handle session IDs and lifecycle

## Proposed Directory Structure

```
core/
├── communication/
│   ├── __init__.py
│   ├── websocket.py      # WebSocket message handling & protocol
│   ├── messages.py       # Message types and validation
│   ├── models.py         # Message data structures
│   └── state.py          # Worker state management
├── streaming/
│   ├── __init__.py
│   ├── manager.py        # Stream coordination and lifecycle
│   ├── readers.py        # RTSP/HTTP frame readers
│   └── buffers.py        # Frame buffering and caching
├── models/
│   ├── __init__.py
│   ├── manager.py        # MPTA download and model loading
│   ├── pipeline.py       # Pipeline.json parser and config
│   └── inference.py      # YOLO model wrapper and optimization
├── tracking/
│   ├── __init__.py
│   ├── tracker.py        # Vehicle tracking with front_rear_detection_v1
│   ├── validator.py      # Stable car validation logic
│   └── integration.py    # Tracking-pipeline integration
├── detection/
│   ├── __init__.py
│   ├── pipeline.py       # Main detection pipeline orchestration
│   └── branches.py       # Parallel branch processing (brand/bodytype)
└── storage/
    ├── __init__.py
    ├── redis.py          # Redis operations and image storage
    └── database.py       # PostgreSQL operations (existing - will be moved)
```

## Implementation Strategy (Feature-by-Feature Testing)

### Phase 1: Communication Layer
- WebSocket message handling (setSubscriptionList, sessionId management)
- HTTP API endpoints (camera image retrieval)
- Worker state reporting

### Phase 2: Pipeline Configuration Reader
- Parse `pipeline.json`
- Model dependency resolution
- Branch configuration setup

### Phase 3: Tracking System
- Continuous vehicle tracking
- Car identification and persistence

### Phase 4: Tracking Validator
- Stable car detection logic
- Passing-by vs. fueling car differentiation

### Phase 5: Model Pipeline Execution
- Main detection pipeline
- Parallel branch processing
- Redis/DB integration

### Phase 6: Post-Session Tracking Validation
- Same-car validation after sessionId cleared
- Prevent duplicate pipeline execution

## Key Preservation Requirements
- **HTTP Endpoint**: `/camera/{camera_id}/image` must remain unchanged
- **WebSocket Protocol**: Full compliance with `worker.md` specification
- **MPTA Format**: Maintain compatibility with existing model archives
- **Database Schema**: Keep existing PostgreSQL structure
- **Redis Integration**: Preserve image storage and pub/sub functionality
- **Configuration**: Maintain `config.json` compatibility
- **Logging**: Preserve structured logging format

## Expected Benefits
- **Maintainability**: Single-responsibility modules (~200-400 lines each)
- **Testability**: Independent testing of each component
- **Readability**: Clear separation of concerns
- **Scalability**: Easy to extend and modify individual components
- **Documentation**: Self-documenting code structure

---

# Comprehensive TODO List

## ✅ Phase 1: Project Setup & Communication Layer - COMPLETED

### 1.1 Project Structure Setup
- ✅ Create `core/` directory structure
- ✅ Create all module directories and `__init__.py` files
- ✅ Set up logging configuration for new modules
- ✅ Update imports in existing files to prepare for migration

### 1.2 Communication Module (`core/communication/`)
- ✅ **Create `models.py`** - Message data structures
  - ✅ Define WebSocket message models (SubscriptionList, StateReport, etc.)
  - ✅ Add validation schemas for incoming messages
  - ✅ Create response models for outgoing messages

- ✅ **Create `messages.py`** - Message types and validation
  - ✅ Implement message type constants
  - ✅ Add message validation functions
  - ✅ Create message builders for common responses

- ✅ **Create `websocket.py`** - WebSocket message handling
  - ✅ Extract WebSocket connection management from `app.py`
  - ✅ Implement message routing and dispatching
  - ✅ Add connection lifecycle management (connect, disconnect, reconnect)
  - ✅ Handle `setSubscriptionList` message processing
  - ✅ Handle `setSessionId` and `setProgressionStage` messages
  - ✅ Handle `requestState` and `patchSessionResult` messages

- ✅ **Create `state.py`** - Worker state management
  - ✅ Extract state reporting logic from `app.py`
  - ✅ Implement system metrics collection (CPU, memory, GPU)
  - ✅ Manage active subscriptions state
  - ✅ Handle session ID mapping and storage

### 1.3 HTTP API Preservation
- ✅ **Preserve `/camera/{camera_id}/image` endpoint**
  - ✅ Extract REST API logic from `app.py`
  - ✅ Ensure frame caching mechanism works with new structure
  - ✅ Maintain exact same response format and error handling

### 1.4 Testing Phase 1
- ✅ Test WebSocket connection and message handling
- ✅ Test HTTP API endpoint functionality
- ✅ Verify state reporting works correctly
- ✅ Test session management functionality

### 1.5 Phase 1 Results
- ✅ **Modular Architecture**: Transformed ~900 lines into 4 focused modules (~200 lines each)
- ✅ **WebSocket Protocol**: Full compliance with worker.md specification
- ✅ **System Metrics**: Real-time CPU, memory, GPU monitoring
- ✅ **State Management**: Thread-safe subscription and session tracking
- ✅ **Backward Compatibility**: All existing endpoints preserved
- ✅ **Modern FastAPI**: Lifespan events, Pydantic v2 compatibility

## ✅ Phase 2: Pipeline Configuration & Model Management - COMPLETED

### 2.1 Models Module (`core/models/`)
- ✅ **Create `pipeline.py`** - Pipeline.json parser
  - ✅ Extract pipeline configuration parsing from `pympta.py`
  - ✅ Implement pipeline validation
  - ✅ Add configuration schema validation
  - ✅ Handle Redis and PostgreSQL configuration parsing

- ✅ **Create `manager.py`** - MPTA download and model loading
  - ✅ Extract MPTA download logic from `pympta.py`
  - ✅ Implement ZIP extraction and validation
  - ✅ Add model file management and caching
  - ✅ Handle model loading with GPU optimization
  - ✅ Implement model dependency resolution

- ✅ **Create `inference.py`** - YOLO model wrapper
  - ✅ Create unified YOLO model interface
  - ✅ Add inference optimization and caching
  - ✅ Implement batch processing capabilities
  - ✅ Handle model switching and memory management

### 2.2 Testing Phase 2
- ✅ Test MPTA file download and extraction
- ✅ Test pipeline.json parsing and validation
- ✅ Test model loading with different configurations
- ✅ Verify GPU optimization works correctly

### 2.3 Phase 2 Results
- ✅ **ModelManager**: Downloads, extracts, and manages MPTA files with a model-ID-based directory structure
- ✅ **PipelineParser**: Parses and validates pipeline.json with full support for Redis, PostgreSQL, tracking, and branches
- ✅ **YOLOWrapper**: Unified interface for YOLO models with caching, tracking, and classification support
- ✅ **Model Caching**: Shared model cache across instances to optimize memory usage (a sketch of the pattern follows)
- ✅ **Dependency Resolution**: Automatically identifies and tracks all model file dependencies
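The shared model cache called out above is essentially a class-level dictionary keyed by model path. The following is a sketch of that pattern, not the actual `YOLOWrapper` implementation; the class and method names are assumptions:

```python
import threading

class YOLOWrapper:
    # Hypothetical class-level cache shared by all wrapper instances,
    # so each weights file is only loaded into memory once.
    _model_cache: dict = {}
    _cache_lock = threading.Lock()

    def __init__(self, model_path: str):
        self.model_path = model_path

    def get_model(self):
        with self._cache_lock:
            if self.model_path not in self._model_cache:
                from ultralytics import YOLO  # heavy import, loaded lazily
                self._model_cache[self.model_path] = YOLO(self.model_path)
            return self._model_cache[self.model_path]
```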
## ✅ Phase 3: Streaming System - COMPLETED

### 3.1 Streaming Module (`core/streaming/`)
- ✅ **Create `readers.py`** - RTSP/HTTP frame readers
  - ✅ Extract `frame_reader` function from `app.py`
  - ✅ Extract `snapshot_reader` function from `app.py`
  - ✅ Add connection management and retry logic
  - ✅ Implement frame rate control and optimization

- ✅ **Create `buffers.py`** - Frame buffering and caching
  - ✅ Extract frame buffer management from `app.py`
  - ✅ Implement efficient frame caching for the REST API
  - ✅ Add buffer size management and memory optimization

- ✅ **Create `manager.py`** - Stream coordination
  - ✅ Extract stream lifecycle management from `app.py`
  - ✅ Implement shared stream optimization
  - ✅ Add subscription reconciliation logic
  - ✅ Handle stream sharing across multiple subscriptions

### 3.2 Testing Phase 3
- ✅ Test RTSP stream reading and buffering
- ✅ Test HTTP snapshot capture functionality
- ✅ Test shared stream optimization
- ✅ Verify frame caching for REST API access

### 3.3 Phase 3 Results
- ✅ **RTSPReader**: OpenCV-based RTSP stream reader with automatic reconnection and frame callbacks
- ✅ **HTTPSnapshotReader**: Periodic HTTP snapshot capture with HTTPBasicAuth and HTTPDigestAuth support
- ✅ **FrameBuffer**: Thread-safe frame storage with automatic aging and cleanup
- ✅ **CacheBuffer**: Enhanced frame cache with cropping support and highest-quality JPEG encoding (default quality=100)
- ✅ **StreamManager**: Complete stream lifecycle management with shared optimization and subscription reconciliation
- ✅ **Authentication Support**: Proper handling of credentials in URLs with automatic auth type detection
- ✅ **Real Camera Testing**: Verified with authenticated RTSP (1280x720) and HTTP snapshot (2688x1520) cameras
- ✅ **Production Ready**: Stable concurrent streaming from multiple camera sources
- ✅ **Dependencies**: Added opencv-python, numpy, and requests to requirements.txt

### 3.4 Recent Streaming Enhancements (Post-Phase 3)
- ✅ **Format-Specific Optimization**: Tailored for 1280x720@6fps RTSP streams and 2560x1440 HTTP snapshots
- ✅ **H.264 Error Recovery**: Enhanced error handling for corrupted frames with automatic stream recovery
- ✅ **Frame Validation**: Implemented corruption detection using edge-density analysis (see the sketch after this list)
- ✅ **Buffer Size Optimization**: Adjusted buffer limits to 3MB for RTSP frames (1280x720x3 bytes)
- ✅ **FFMPEG Integration**: Added environment variables to suppress verbose H.264 decoder errors
- ✅ **URL Preservation**: Maintained clean RTSP URLs without parameter injection
- ✅ **Type Detection**: Automatic stream type detection based on frame dimensions
- ✅ **Quality Settings**: Format-specific JPEG quality (90% for RTSP, 95% for HTTP)
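As a rough illustration of the edge-density idea: corrupted H.264 frames tend to decode as large flat or smeared regions with very few edges, so a frame with almost no edge pixels is suspect. A check along these lines (the threshold is a made-up value, not the one used in `readers.py`) might look like:

```python
import cv2
import numpy as np

def frame_looks_corrupted(frame: np.ndarray, min_edge_density: float = 0.01) -> bool:
    """Heuristic corruption check: a real scene should contain some edges."""
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 100, 200)
    # Fraction of pixels that the edge detector marked as edges.
    edge_density = float(np.count_nonzero(edges)) / edges.size
    return edge_density < min_edge_density
```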
## ✅ Phase 4: Vehicle Tracking System - COMPLETED

### 4.1 Tracking Module (`core/tracking/`)
- ✅ **Create `tracker.py`** - Vehicle tracking implementation (305 lines)
  - ✅ Implement continuous tracking with a configurable model (front_rear_detection_v1.pt)
  - ✅ Add vehicle identification and persistence with a TrackedVehicle dataclass
  - ✅ Implement tracking state management with thread-safe operations
  - ✅ Add bounding box tracking and motion analysis with position history
  - ✅ Multi-class tracking support for complex detection scenarios

- ✅ **Create `validator.py`** - Stable car validation (417 lines)
  - ✅ Implement stable car detection algorithm with multiple validation criteria
  - ✅ Add passing-by vs. fueling car differentiation using velocity and position analysis
  - ✅ Implement validation thresholds and timing with configurable parameters
  - ✅ Add confidence scoring for validation decisions with state history
  - ✅ Advanced motion analysis with velocity smoothing and position variance

- ✅ **Create `integration.py`** - Tracking-pipeline integration (547 lines)
  - ✅ Connect tracking system with main pipeline through TrackingPipelineIntegration
  - ✅ Handle tracking state transitions and session management
  - ✅ Implement post-session tracking validation with cooldown periods
  - ✅ Add same-car validation after sessionId is cleared, with a 30-second cooldown
  - ✅ Car abandonment detection with automatic timeout monitoring
  - ✅ Mock detection system for backend communication
  - ✅ Async pipeline execution with proper error handling

### 4.2 Testing Phase 4
- ✅ Test continuous vehicle tracking functionality
- ✅ Test stable car validation logic
- ✅ Test integration with existing pipeline
- ✅ Verify tracking performance and accuracy
- ✅ Test car abandonment detection with null detection messages
- ✅ Verify session management and progression stage handling

### 4.3 Phase 4 Results
- ✅ **VehicleTracker**: Complete tracking implementation with YOLO tracking integration, position history, and stability calculations
- ✅ **StableCarValidator**: Sophisticated validation logic using velocity, position variance, and state consistency
- ✅ **TrackingPipelineIntegration**: Full integration with the pipeline system, including session management and async processing
- ✅ **StreamManager Integration**: Updated streaming manager to process tracking on every frame with proper threading
- ✅ **Thread-Safe Operations**: All tracking operations are thread-safe with proper locking mechanisms
- ✅ **Configurable Parameters**: All tracking parameters are configurable through pipeline.json
- ✅ **Session Management**: Complete session lifecycle management with post-fueling validation
- ✅ **Statistics and Monitoring**: Comprehensive statistics collection for tracking performance
- ✅ **Car Abandonment Detection**: Automatic detection when cars leave without fueling; sends `detection: null` to the backend
- ✅ **Message Protocol**: Fixed JSON serialization to include `detection: null` for abandonment notifications
- ✅ **Streaming Optimization**: Enhanced RTSP/HTTP readers for 1280x720@6fps RTSP and 2560x1440 HTTP snapshots
- ✅ **Error Recovery**: Improved H.264 error handling and corrupted frame detection

## ✅ Phase 5: Detection Pipeline System - COMPLETED

### 5.1 Detection Module (`core/detection/`) ✅
- ✅ **Create `pipeline.py`** - Main detection orchestration (574 lines)
  - ✅ Extracted main pipeline execution from `pympta.py` with full orchestration
  - ✅ Implemented detection flow coordination with async execution
  - ✅ Added pipeline state management with comprehensive statistics
  - ✅ Handled pipeline result aggregation with branch synchronization
  - ✅ Redis and database integration with error handling
  - ✅ Immediate and parallel action execution with template resolution

- ✅ **Create `branches.py`** - Parallel branch processing (442 lines)
  - ✅ Extracted parallel branch execution from `pympta.py`
  - ✅ Implemented ThreadPoolExecutor-based parallel processing
  - ✅ Added branch synchronization and result collection
  - ✅ Handled branch failure and retry logic with graceful degradation
  - ✅ Support for nested branches and model caching
  - ✅ Both detection and classification model support

### 5.2 Storage Module (`core/storage/`) ✅
- ✅ **Create `redis.py`** - Redis operations (410 lines)
  - ✅ Extracted Redis action execution from `pympta.py`
  - ✅ Implemented async image storage with region cropping
  - ✅ Added pub/sub messaging functionality with JSON support
  - ✅ Handled Redis connection management and retry logic
  - ✅ Added statistics tracking and health monitoring
  - ✅ Support for various image formats (JPEG, PNG) with quality control

- ✅ **Move `database.py`** - PostgreSQL operations (339 lines)
  - ✅ Moved existing `archive/siwatsystem/database.py` to `core/storage/`
  - ✅ Updated imports and integration points
  - ✅ Ensured compatibility with new module structure
  - ✅ Added session management and statistics methods
  - ✅ Enhanced error handling and connection management

### 5.3 Integration Updates ✅
- ✅ **Updated `core/tracking/integration.py`**
  - ✅ Added DetectionPipeline integration
  - ✅ Replaced placeholder `_execute_pipeline` with real implementation
  - ✅ Added detection pipeline initialization and cleanup
  - ✅ Integrated with existing tracking system flow
  - ✅ Maintained backward compatibility with test mode

### 5.4 Testing Phase 5 ✅
- ✅ Verified module imports work correctly
- ✅ All new modules follow established coding patterns
- ✅ Integration points properly connected
- ✅ Error handling and cleanup methods implemented
- ✅ Statistics and monitoring capabilities added

### 5.5 Phase 5 Results ✅
- ✅ **DetectionPipeline**: Complete detection orchestration with Redis/PostgreSQL integration, async execution, and comprehensive error handling
- ✅ **BranchProcessor**: Parallel branch execution with ThreadPoolExecutor, model caching, and nested branch support (a sketch of the pattern follows)
- ✅ **RedisManager**: Async Redis operations with image storage, pub/sub messaging, and connection management
- ✅ **DatabaseManager**: Enhanced PostgreSQL operations with session management and statistics
- ✅ **Module Integration**: Seamless integration with existing tracking system while maintaining compatibility
- ✅ **Error Handling**: Comprehensive error handling and graceful degradation throughout all components
- ✅ **Performance**: Optimized parallel processing and caching for high-performance pipeline execution
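As a rough sketch of the ThreadPoolExecutor pattern described for `branches.py` (function names and the timeout value are illustrative, not the actual BranchProcessor API):

```python
from concurrent.futures import ThreadPoolExecutor

def run_branches(crop, branches: dict) -> dict:
    """Run each branch classifier on the same car crop in parallel."""
    results = {}
    with ThreadPoolExecutor(max_workers=len(branches) or 1) as pool:
        futures = {name: pool.submit(fn, crop) for name, fn in branches.items()}
        for name, future in futures.items():
            try:
                results[name] = future.result(timeout=10)
            except Exception as exc:
                # Graceful degradation: one failed branch does not sink the others.
                results[name] = {"error": str(exc)}
    return results

# Example: run_branches(car_crop, {"brand": brand_model, "bodytype": bodytype_model})
```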
## ✅ Additional Implemented Features (Not in Original Plan)

### License Plate Recognition Integration (`core/storage/license_plate.py`) ✅
- ✅ **LicensePlateManager**: Subscribes to Redis channel `license_results` for the external LPR service
- ✅ **Multi-format Support**: Handles various message formats from the LPR service
- ✅ **Result Caching**: 5-minute TTL for license plate results
- ✅ **WebSocket Integration**: Sends combined `imageDetection` messages with license data
- ✅ **Asynchronous Processing**: Non-blocking Redis pub/sub listener (a listener sketch follows)
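A minimal sketch of a non-blocking `license_results` listener using redis-py's asyncio support; only the channel name comes from the plan above, and the payload handling is an assumption:

```python
import json
import redis.asyncio as redis

async def listen_license_results(url: str = "redis://localhost:6379") -> None:
    # Subscribe to the channel the external LPR service publishes to.
    client = redis.from_url(url)
    pubsub = client.pubsub()
    await pubsub.subscribe("license_results")
    async for message in pubsub.listen():
        if message["type"] != "message":
            continue  # skip subscribe confirmations
        payload = json.loads(message["data"])
        print("license plate result:", payload)

# Run with: asyncio.run(listen_license_results())
```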
### Advanced Session State Management (`core/communication/state.py`) ✅
- ✅ **Session ID Mapping**: Per-display session identifier tracking
- ✅ **Progression Stage Tracking**: Workflow state per display (welcome, car_wait_staff, finished, cleared)
- ✅ **Thread-Safe Operations**: RLock-based synchronization for concurrent access (a minimal sketch follows)
- ✅ **Comprehensive State Reporting**: Full system state for debugging
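A minimal sketch of the RLock-based pattern described here; the class and method names are illustrative, not the actual `state.py` API:

```python
import threading

class WorkerState:
    """Illustrative thread-safe per-display session and progression state."""

    def __init__(self):
        self._lock = threading.RLock()
        self._session_ids: dict[str, int] = {}
        self._progression: dict[str, str] = {}

    def set_session(self, display_id: str, session_id: int | None) -> None:
        with self._lock:
            if session_id is None:
                # Backend cleared the session for this display.
                self._session_ids.pop(display_id, None)
            else:
                self._session_ids[display_id] = session_id

    def set_progression(self, display_id: str, stage: str) -> None:
        # Stages per the plan: welcome, car_wait_staff, finished, cleared.
        with self._lock:
            self._progression[display_id] = stage
```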
### Car Abandonment Detection (`core/tracking/integration.py`) ✅
- ✅ **Abandonment Monitoring**: Detects cars leaving without completing fueling
- ✅ **Timeout Configuration**: 3-second abandonment timeout
- ✅ **Null Detection Messages**: Sends `detection: null` to the backend for abandoned cars (an example payload follows this list)
- ✅ **Automatic Cleanup**: Removes abandoned sessions from tracking
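For reference, an abandonment notification plausibly looks like the payload below. The envelope mirrors the `imageDetection` message built in `archive/app.py`, and `detection: null` is the documented abandonment signal; the subscription identifier and model fields are placeholder values:

```python
import json
import time

abandonment_message = {
    "type": "imageDetection",
    "subscriptionIdentifier": "display-001;cam-001",  # placeholder
    "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "data": {
        "detection": None,  # serialized as JSON null: the abandonment signal
        "modelId": 52,                           # placeholder
        "modelName": "front_rear_detection_v1",  # placeholder
    },
}
print(json.dumps(abandonment_message))
```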
### Enhanced Message Protocol (`core/communication/models.py`) ✅
- ✅ **PatchSessionResult**: Session data patching support
- ✅ **SetProgressionStage**: Workflow stage management messages
- ✅ **Null Detection Handling**: Support for abandonment notifications
- ✅ **Complex Detection Structure**: Supports both classification and null states

### Comprehensive Timeout and Cooldown Systems ✅
- ✅ **Post-Session Cooldown**: 30-second cooldown after session clearing
- ✅ **Processing Cooldown**: 10-second cooldown for repeated processing
- ✅ **Abandonment Timeout**: 3-second timeout for car abandonment detection
- ✅ **Vehicle Expiration**: 2-second timeout for tracking cleanup
- ✅ **Stream Timeouts**: 30-second connection timeout management

These values are collected in the configuration sketch below.
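The timeouts above could plausibly be grouped in a single configuration block; the key names below are invented for illustration and do not reflect the actual `pipeline.json` schema:

```python
# Hypothetical grouping of the timeouts listed above; only the numeric
# values come from the plan, the key names are illustrative.
tracking_timeouts = {
    "post_session_cooldown_s": 30,      # cooldown after session clearing
    "processing_cooldown_s": 10,        # cooldown for repeated processing
    "abandonment_timeout_s": 3,         # car abandonment detection
    "vehicle_expiration_s": 2,          # tracking cleanup
    "stream_connection_timeout_s": 30,  # stream connection management
}
```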
## 📋 Phase 6: Integration & Final Testing

### 6.1 Main Application Refactoring
- [ ] **Refactor `app.py`**
  - [ ] Remove extracted functionality
  - [ ] Update to use new modular structure
  - [ ] Maintain FastAPI application structure
  - [ ] Update imports and dependencies

- [ ] **Clean up `siwatsystem/pympta.py`**
  - [ ] Remove extracted functionality
  - [ ] Keep only necessary legacy compatibility code
  - [ ] Update imports to use new modules

### 6.2 Post-Session Tracking Validation
- [ ] Implement same-car validation after sessionId is cleared
- [ ] Add logic to prevent duplicate pipeline execution
- [ ] Test tracking persistence through the session lifecycle
- [ ] Verify correct behavior during edge cases

### 6.3 Configuration & Documentation
- [ ] Update configuration handling for the new structure
- [ ] Ensure `config.json` compatibility is maintained
- [ ] Update logging configuration for all modules
- [ ] Add module-level documentation

### 6.4 Comprehensive Testing
- [ ] **Integration Testing**
  - [ ] Test complete system flow end-to-end
  - [ ] Test all WebSocket message types
  - [ ] Test HTTP API endpoints
  - [ ] Test error handling and recovery

- [ ] **Performance Testing**
  - [ ] Verify system performance is maintained
  - [ ] Test memory usage optimization
  - [ ] Test GPU utilization efficiency
  - [ ] Benchmark against the original implementation

- [ ] **Edge Case Testing**
  - [ ] Test connection failures and reconnection
  - [ ] Test model loading failures
  - [ ] Test stream interruption handling
  - [ ] Test concurrent subscription management

### 6.5 Logging Optimization & Cleanup ✅
- ✅ **Removed Debug Frame Saving**
  - ✅ Removed hard-coded debug frame saving in `core/detection/pipeline.py`
  - ✅ Removed hard-coded debug frame saving in `core/detection/branches.py`
  - ✅ Eliminated absolute debug paths for production use

- ✅ **Eliminated Test/Mock Functionality**
  - ✅ Removed `save_frame_for_testing` function from `core/streaming/buffers.py`
  - ✅ Removed `save_test_frames` configuration from `StreamConfig`
  - ✅ Cleaned up test frame saving calls in the stream manager
  - ✅ Updated module exports to remove test functions

- ✅ **Reduced Verbose Logging**
  - ✅ Commented out verbose frame storage logging (every frame)
  - ✅ Converted debug-level info logs to proper debug level
  - ✅ Reduced repetitive frame dimension logging
  - ✅ Maintained important model results and detection confidence logging
  - ✅ Kept critical pipeline execution and error messages

- ✅ **Production-Ready Logging**
  - ✅ Clean startup and initialization messages
  - ✅ Clear model loading and pipeline status
  - ✅ Preserved detection results with confidence scores
  - ✅ Maintained session management and tracking messages
  - ✅ Kept important error and warning messages

### 6.6 Final Cleanup
- [ ] Remove any remaining duplicate code
- [ ] Optimize imports across all modules
- [ ] Clean up temporary files and debugging code
- [ ] Update project documentation

## 📋 Post-Refactoring Tasks

### Documentation Updates
- [ ] Update `CLAUDE.md` with the new architecture
- [ ] Create module-specific documentation
- [ ] Update installation and deployment guides
- [ ] Add a troubleshooting guide for the new structure

### Code Quality
- [ ] Add type hints to all new modules
- [ ] Implement proper error handling patterns
- [ ] Add logging consistency across modules
- [ ] Ensure proper resource cleanup

### Future Enhancements (Optional)
- [ ] Add unit tests for each module
- [ ] Implement monitoring and metrics collection
- [ ] Add configuration validation
- [ ] Consider adding a dependency injection container

---

## Success Criteria

✅ **Modularity**: Each module has a single, clear responsibility
✅ **Testability**: Each phase can be tested independently
✅ **Maintainability**: Code is easy to understand and modify
✅ **Compatibility**: All existing functionality preserved
✅ **Performance**: System performance is maintained or improved
✅ **Documentation**: Clear documentation for the new architecture

## Risk Mitigation

- **Feature-by-feature testing** ensures functionality is preserved at each step
- **Gradual migration** minimizes the risk of breaking existing functionality
- **Preserve critical interfaces** (WebSocket protocol, HTTP endpoints)
- **Maintain backward compatibility** with existing configurations
- **Comprehensive testing** at each phase before proceeding

---

## 🎯 Current Status Summary

### ✅ Completed Phases (95% Complete)
- **Phase 1**: Communication Layer - ✅ COMPLETED
- **Phase 2**: Pipeline Configuration & Model Management - ✅ COMPLETED
- **Phase 3**: Streaming System - ✅ COMPLETED
- **Phase 4**: Vehicle Tracking System - ✅ COMPLETED
- **Phase 5**: Detection Pipeline System - ✅ COMPLETED
- **Additional Features**: License Plate Recognition, Car Abandonment, Session Management - ✅ COMPLETED

### 📋 Remaining Work (5%)
- **Phase 6**: Final Integration & Testing
  - Main application cleanup (`app.py` and `pympta.py`)
  - Comprehensive integration testing
  - Performance benchmarking
  - Documentation updates

### 🚀 Production-Ready Features
- ✅ **Modular Architecture**: ~4000 lines refactored into 20+ focused modules
- ✅ **WebSocket Protocol**: Full compliance with all message types
- ✅ **License Plate Recognition**: External LPR service integration via Redis
- ✅ **Car Abandonment Detection**: Automatic detection and notification
- ✅ **Session Management**: Complete lifecycle with progression stages
- ✅ **Parallel Processing**: ThreadPoolExecutor for branch execution
- ✅ **Redis Integration**: Pub/sub, image storage, LPR subscription
- ✅ **PostgreSQL Integration**: Automatic schema management, combined updates
- ✅ **Stream Optimization**: Shared streams, format-specific handling
- ✅ **Error Recovery**: H.264 corruption detection, automatic reconnection
- ✅ **Production Logging**: Clean, informative logging without debug clutter

### 📊 Metrics
- **Modules Created**: 20+ specialized modules
- **Lines Per Module**: ~200-500 (highly maintainable)
- **Test Coverage**: Feature-by-feature validation completed
- **Performance**: Maintained or improved from the original implementation
- **Backward Compatibility**: 100% preserved
archive/app.py: 903 changes
@ -1,903 +0,0 @@
|
|||
from typing import Any, Dict
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import queue
|
||||
import torch
|
||||
import cv2
|
||||
import numpy as np
|
||||
import base64
|
||||
import logging
|
||||
import threading
|
||||
import requests
|
||||
import asyncio
|
||||
import psutil
|
||||
import zipfile
|
||||
from urllib.parse import urlparse
|
||||
from fastapi import FastAPI, WebSocket, HTTPException
|
||||
from fastapi.websockets import WebSocketDisconnect
|
||||
from fastapi.responses import Response
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Import shared pipeline functions
|
||||
from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Global dictionaries to keep track of models and streams
|
||||
# "models" now holds a nested dict: { camera_id: { modelId: model_tree } }
|
||||
models: Dict[str, Dict[str, Any]] = {}
|
||||
streams: Dict[str, Dict[str, Any]] = {}
|
||||
# Store session IDs per display
|
||||
session_ids: Dict[str, int] = {}
|
||||
# Track shared camera streams by camera URL
|
||||
camera_streams: Dict[str, Dict[str, Any]] = {}
|
||||
# Map subscriptions to their camera URL
|
||||
subscription_to_camera: Dict[str, str] = {}
|
||||
# Store latest frames for REST API access (separate from processing buffer)
|
||||
latest_frames: Dict[str, Any] = {}
|
||||
|
||||
with open("config.json", "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
poll_interval = config.get("poll_interval_ms", 100)
|
||||
reconnect_interval = config.get("reconnect_interval_sec", 5)
|
||||
TARGET_FPS = config.get("target_fps", 10)
|
||||
poll_interval = 1000 / TARGET_FPS
|
||||
logging.info(f"Poll interval: {poll_interval}ms")
|
||||
max_streams = config.get("max_streams", 5)
|
||||
max_retries = config.get("max_retries", 3)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, # Set to INFO level for less verbose output
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler("detector_worker.log"), # Write logs to a file
|
||||
logging.StreamHandler() # Also output to console
|
||||
]
|
||||
)
|
||||
|
||||
# Create a logger specifically for this application
|
||||
logger = logging.getLogger("detector_worker")
|
||||
logger.setLevel(logging.DEBUG) # Set app-specific logger to DEBUG level
|
||||
|
||||
# Ensure all other libraries (including root) use at least INFO level
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
|
||||
logger.info("Starting detector worker application")
|
||||
logger.info(f"Configuration: Target FPS: {TARGET_FPS}, Max streams: {max_streams}, Max retries: {max_retries}")
|
||||
|
||||
# Ensure the models directory exists
|
||||
os.makedirs("models", exist_ok=True)
|
||||
logger.info("Ensured models directory exists")
|
||||
|
||||
# Constants for heartbeat and timeouts
|
||||
HEARTBEAT_INTERVAL = 2 # seconds
|
||||
WORKER_TIMEOUT_MS = 10000
|
||||
logger.debug(f"Heartbeat interval set to {HEARTBEAT_INTERVAL} seconds")
|
||||
|
||||
# Locks for thread-safe operations
|
||||
streams_lock = threading.Lock()
|
||||
models_lock = threading.Lock()
|
||||
logger.debug("Initialized thread locks")
|
||||
|
||||
# Add helper to download mpta ZIP file from a remote URL
|
||||
def download_mpta(url: str, dest_path: str) -> str:
|
||||
try:
|
||||
logger.info(f"Starting download of model from {url} to {dest_path}")
|
||||
os.makedirs(os.path.dirname(dest_path), exist_ok=True)
|
||||
response = requests.get(url, stream=True)
|
||||
if response.status_code == 200:
|
||||
file_size = int(response.headers.get('content-length', 0))
|
||||
logger.info(f"Model file size: {file_size/1024/1024:.2f} MB")
|
||||
downloaded = 0
|
||||
with open(dest_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if file_size > 0 and downloaded % (file_size // 10) < 8192: # Log approximately every 10%
|
||||
logger.debug(f"Download progress: {downloaded/file_size*100:.1f}%")
|
||||
logger.info(f"Successfully downloaded mpta file from {url} to {dest_path}")
|
||||
return dest_path
|
||||
else:
|
||||
logger.error(f"Failed to download mpta file (status code {response.status_code}): {response.text}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Exception downloading mpta file from {url}: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
# Add helper to fetch snapshot image from HTTP/HTTPS URL
|
||||
def fetch_snapshot(url: str):
|
||||
try:
|
||||
from requests.auth import HTTPBasicAuth, HTTPDigestAuth
|
||||
|
||||
# Parse URL to extract credentials
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Prepare headers - some cameras require User-Agent
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (compatible; DetectorWorker/1.0)'
|
||||
}
|
||||
|
||||
# Reconstruct URL without credentials
|
||||
clean_url = f"{parsed.scheme}://{parsed.hostname}"
|
||||
if parsed.port:
|
||||
clean_url += f":{parsed.port}"
|
||||
clean_url += parsed.path
|
||||
if parsed.query:
|
||||
clean_url += f"?{parsed.query}"
|
||||
|
||||
auth = None
|
||||
if parsed.username and parsed.password:
|
||||
# Try HTTP Digest authentication first (common for IP cameras)
|
||||
try:
|
||||
auth = HTTPDigestAuth(parsed.username, parsed.password)
|
||||
response = requests.get(clean_url, auth=auth, headers=headers, timeout=10)
|
||||
if response.status_code == 200:
|
||||
logger.debug(f"Successfully authenticated using HTTP Digest for {clean_url}")
|
||||
elif response.status_code == 401:
|
||||
# If Digest fails, try Basic auth
|
||||
logger.debug(f"HTTP Digest failed, trying Basic auth for {clean_url}")
|
||||
auth = HTTPBasicAuth(parsed.username, parsed.password)
|
||||
response = requests.get(clean_url, auth=auth, headers=headers, timeout=10)
|
||||
if response.status_code == 200:
|
||||
logger.debug(f"Successfully authenticated using HTTP Basic for {clean_url}")
|
||||
except Exception as auth_error:
|
||||
logger.debug(f"Authentication setup error: {auth_error}")
|
||||
# Fallback to original URL with embedded credentials
|
||||
response = requests.get(url, headers=headers, timeout=10)
|
||||
else:
|
||||
# No credentials in URL, make request as-is
|
||||
response = requests.get(url, headers=headers, timeout=10)
|
||||
|
||||
if response.status_code == 200:
|
||||
# Convert response content to numpy array
|
||||
nparr = np.frombuffer(response.content, np.uint8)
|
||||
# Decode image
|
||||
frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
if frame is not None:
|
||||
logger.debug(f"Successfully fetched snapshot from {clean_url}, shape: {frame.shape}")
|
||||
return frame
|
||||
else:
|
||||
logger.error(f"Failed to decode image from snapshot URL: {clean_url}")
|
||||
return None
|
||||
else:
|
||||
logger.error(f"Failed to fetch snapshot (status code {response.status_code}): {clean_url}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Exception fetching snapshot from {url}: {str(e)}")
|
||||
return None
|
||||
|
||||
# Helper to get crop coordinates from stream
|
||||
def get_crop_coords(stream):
|
||||
return {
|
||||
"cropX1": stream.get("cropX1"),
|
||||
"cropY1": stream.get("cropY1"),
|
||||
"cropX2": stream.get("cropX2"),
|
||||
"cropY2": stream.get("cropY2")
|
||||
}
|
||||
|
||||
####################################################
|
||||
# REST API endpoint for image retrieval
|
||||
####################################################
|
||||
@app.get("/camera/{camera_id}/image")
|
||||
async def get_camera_image(camera_id: str):
|
||||
"""
|
||||
Get the current frame from a camera as JPEG image
|
||||
"""
|
||||
try:
|
||||
# URL decode the camera_id to handle encoded characters like %3B for semicolon
|
||||
from urllib.parse import unquote
|
||||
original_camera_id = camera_id
|
||||
camera_id = unquote(camera_id)
|
||||
logger.debug(f"REST API request: original='{original_camera_id}', decoded='{camera_id}'")
|
||||
|
||||
with streams_lock:
|
||||
if camera_id not in streams:
|
||||
logger.warning(f"Camera ID '{camera_id}' not found in streams. Current streams: {list(streams.keys())}")
|
||||
raise HTTPException(status_code=404, detail=f"Camera {camera_id} not found or not active")
|
||||
|
||||
# Check if we have a cached frame for this camera
|
||||
if camera_id not in latest_frames:
|
||||
logger.warning(f"No cached frame available for camera '{camera_id}'.")
|
||||
raise HTTPException(status_code=404, detail=f"No frame available for camera {camera_id}")
|
||||
|
||||
frame = latest_frames[camera_id]
|
||||
logger.debug(f"Retrieved cached frame for camera '{camera_id}', frame shape: {frame.shape}")
|
||||
# Encode frame as JPEG
|
||||
success, buffer_img = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to encode image as JPEG")
|
||||
|
||||
# Return image as binary response
|
||||
return Response(content=buffer_img.tobytes(), media_type="image/jpeg")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving image for camera {camera_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
####################################################
|
||||
# Detection and frame processing functions
|
||||
####################################################
|
||||
@app.websocket("/")
|
||||
async def detect(websocket: WebSocket):
|
||||
logger.info("WebSocket connection accepted")
|
||||
persistent_data_dict = {}
|
||||
|
||||
async def handle_detection(camera_id, stream, frame, websocket, model_tree, persistent_data):
|
||||
try:
|
||||
# Apply crop if specified
|
||||
cropped_frame = frame
|
||||
if all(coord is not None for coord in [stream.get("cropX1"), stream.get("cropY1"), stream.get("cropX2"), stream.get("cropY2")]):
|
||||
cropX1, cropY1, cropX2, cropY2 = stream["cropX1"], stream["cropY1"], stream["cropX2"], stream["cropY2"]
|
||||
cropped_frame = frame[cropY1:cropY2, cropX1:cropX2]
|
||||
logger.debug(f"Applied crop coordinates ({cropX1}, {cropY1}, {cropX2}, {cropY2}) to frame for camera {camera_id}")
|
||||
|
||||
logger.debug(f"Processing frame for camera {camera_id} with model {stream['modelId']}")
|
||||
start_time = time.time()
|
||||
|
||||
# Extract display identifier for session ID lookup
|
||||
subscription_parts = stream["subscriptionIdentifier"].split(';')
|
||||
display_identifier = subscription_parts[0] if subscription_parts else None
|
||||
session_id = session_ids.get(display_identifier) if display_identifier else None
|
||||
|
||||
# Create context for pipeline execution
|
||||
pipeline_context = {
|
||||
"camera_id": camera_id,
|
||||
"display_id": display_identifier,
|
||||
"session_id": session_id
|
||||
}
|
||||
|
||||
detection_result = run_pipeline(cropped_frame, model_tree, context=pipeline_context)
|
||||
process_time = (time.time() - start_time) * 1000
|
||||
logger.debug(f"Detection for camera {camera_id} completed in {process_time:.2f}ms")
|
||||
|
||||
# Log the raw detection result for debugging
|
||||
logger.debug(f"Raw detection result for camera {camera_id}:\n{json.dumps(detection_result, indent=2, default=str)}")
|
||||
|
||||
# Direct class result (no detections/classifications structure)
|
||||
if detection_result and isinstance(detection_result, dict) and "class" in detection_result and "confidence" in detection_result:
|
||||
highest_confidence_detection = {
|
||||
"class": detection_result.get("class", "none"),
|
||||
"confidence": detection_result.get("confidence", 1.0),
|
||||
"box": [0, 0, 0, 0] # Empty bounding box for classifications
|
||||
}
|
||||
# Handle case when no detections found or result is empty
|
||||
elif not detection_result or not detection_result.get("detections"):
|
||||
# Check if we have classification results
|
||||
if detection_result and detection_result.get("classifications"):
|
||||
# Get the highest confidence classification
|
||||
classifications = detection_result.get("classifications", [])
|
||||
highest_confidence_class = max(classifications, key=lambda x: x.get("confidence", 0)) if classifications else None
|
||||
|
||||
if highest_confidence_class:
|
||||
highest_confidence_detection = {
|
||||
"class": highest_confidence_class.get("class", "none"),
|
||||
"confidence": highest_confidence_class.get("confidence", 1.0),
|
||||
"box": [0, 0, 0, 0] # Empty bounding box for classifications
|
||||
}
|
||||
else:
|
||||
highest_confidence_detection = {
|
||||
"class": "none",
|
||||
"confidence": 1.0,
|
||||
"box": [0, 0, 0, 0]
|
||||
}
|
||||
else:
|
||||
highest_confidence_detection = {
|
||||
"class": "none",
|
||||
"confidence": 1.0,
|
||||
"box": [0, 0, 0, 0]
|
||||
}
|
||||
else:
|
||||
# Find detection with highest confidence
|
||||
detections = detection_result.get("detections", [])
|
||||
highest_confidence_detection = max(detections, key=lambda x: x.get("confidence", 0)) if detections else {
|
||||
"class": "none",
|
||||
"confidence": 1.0,
|
||||
"box": [0, 0, 0, 0]
|
||||
}
|
||||
|
||||
# Convert detection format to match protocol - flatten detection attributes
|
||||
detection_dict = {}
|
||||
|
||||
# Handle different detection result formats
|
||||
if isinstance(highest_confidence_detection, dict):
|
||||
# Copy all fields from the detection result
|
||||
for key, value in highest_confidence_detection.items():
|
||||
if key not in ["box", "id"]: # Skip internal fields
|
||||
detection_dict[key] = value
|
||||
|
||||
detection_data = {
|
||||
"type": "imageDetection",
|
||||
"subscriptionIdentifier": stream["subscriptionIdentifier"],
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.%fZ", time.gmtime()),
|
||||
"data": {
|
||||
"detection": detection_dict,
|
||||
"modelId": stream["modelId"],
|
||||
"modelName": stream["modelName"]
|
||||
}
|
||||
}
|
||||
|
||||
# Add session ID if available
|
||||
if session_id is not None:
|
||||
detection_data["sessionId"] = session_id
|
||||
|
||||
if highest_confidence_detection["class"] != "none":
|
||||
logger.info(f"Camera {camera_id}: Detected {highest_confidence_detection['class']} with confidence {highest_confidence_detection['confidence']:.2f} using model {stream['modelName']}")
|
||||
|
||||
# Log session ID if available
|
||||
if session_id:
|
||||
logger.debug(f"Detection associated with session ID: {session_id}")
|
||||
|
||||
await websocket.send_json(detection_data)
|
||||
logger.debug(f"Sent detection data to client for camera {camera_id}")
|
||||
return persistent_data
|
||||
except Exception as e:
|
||||
logger.error(f"Error in handle_detection for camera {camera_id}: {str(e)}", exc_info=True)
|
||||
return persistent_data
|
||||
|
||||
def frame_reader(camera_id, cap, buffer, stop_event):
|
||||
retries = 0
|
||||
logger.info(f"Starting frame reader thread for camera {camera_id}")
|
||||
frame_count = 0
|
||||
last_log_time = time.time()
|
||||
|
||||
try:
|
||||
# Log initial camera status and properties
|
||||
if cap.isOpened():
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
logger.info(f"Camera {camera_id} opened successfully with resolution {width}x{height}, FPS: {fps}")
|
||||
else:
|
||||
logger.error(f"Camera {camera_id} failed to open initially")
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
if not cap.isOpened():
|
||||
logger.error(f"Camera {camera_id} is not open before trying to read")
|
||||
# Attempt to reopen
|
||||
cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"])
|
||||
time.sleep(reconnect_interval)
|
||||
continue
|
||||
|
||||
logger.debug(f"Attempting to read frame from camera {camera_id}")
|
||||
ret, frame = cap.read()
|
||||
|
||||
if not ret:
|
||||
logger.warning(f"Connection lost for camera: {camera_id}, retry {retries+1}/{max_retries}")
|
||||
cap.release()
|
||||
time.sleep(reconnect_interval)
|
||||
retries += 1
|
||||
if retries > max_retries and max_retries != -1:
|
||||
logger.error(f"Max retries reached for camera: {camera_id}, stopping frame reader")
|
||||
break
|
||||
# Re-open
|
||||
logger.info(f"Attempting to reopen RTSP stream for camera: {camera_id}")
|
||||
cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"])
|
||||
if not cap.isOpened():
|
||||
logger.error(f"Failed to reopen RTSP stream for camera: {camera_id}")
|
||||
continue
|
||||
logger.info(f"Successfully reopened RTSP stream for camera: {camera_id}")
|
||||
continue
|
||||
|
||||
# Successfully read a frame
|
||||
frame_count += 1
|
||||
current_time = time.time()
|
||||
# Log frame stats every 5 seconds
|
||||
if current_time - last_log_time > 5:
|
||||
logger.info(f"Camera {camera_id}: Read {frame_count} frames in the last {current_time - last_log_time:.1f} seconds")
|
||||
frame_count = 0
|
||||
last_log_time = current_time
|
||||
|
||||
logger.debug(f"Successfully read frame from camera {camera_id}, shape: {frame.shape}")
|
||||
retries = 0
|
||||
|
||||
# Overwrite old frame if buffer is full
|
||||
if not buffer.empty():
|
||||
try:
|
||||
buffer.get_nowait()
|
||||
logger.debug(f"[frame_reader] Removed old frame from buffer for camera {camera_id}")
|
||||
except queue.Empty:
|
||||
pass
|
||||
buffer.put(frame)
|
||||
logger.debug(f"[frame_reader] Added new frame to buffer for camera {camera_id}. Buffer size: {buffer.qsize()}")
|
||||
|
||||
# Short sleep to avoid CPU overuse
|
||||
time.sleep(0.01)
|
||||
|
||||
except cv2.error as e:
|
||||
logger.error(f"OpenCV error for camera {camera_id}: {e}", exc_info=True)
|
||||
cap.release()
|
||||
time.sleep(reconnect_interval)
|
||||
retries += 1
|
||||
if retries > max_retries and max_retries != -1:
|
||||
logger.error(f"Max retries reached after OpenCV error for camera {camera_id}")
|
||||
break
|
||||
logger.info(f"Attempting to reopen RTSP stream after OpenCV error for camera: {camera_id}")
|
||||
cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"])
|
||||
if not cap.isOpened():
|
||||
logger.error(f"Failed to reopen RTSP stream for camera {camera_id} after OpenCV error")
|
||||
continue
|
||||
logger.info(f"Successfully reopened RTSP stream after OpenCV error for camera: {camera_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error for camera {camera_id}: {str(e)}", exc_info=True)
|
||||
cap.release()
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in frame_reader thread for camera {camera_id}: {str(e)}", exc_info=True)
|
||||
finally:
|
||||
logger.info(f"Frame reader thread for camera {camera_id} is exiting")
|
||||
if cap and cap.isOpened():
|
||||
cap.release()
|
||||
|
||||
def snapshot_reader(camera_id, snapshot_url, snapshot_interval, buffer, stop_event):
    """Frame reader that fetches snapshots from HTTP/HTTPS URL at specified intervals"""
    retries = 0
    logger.info(f"Starting snapshot reader thread for camera {camera_id} from {snapshot_url}")
    frame_count = 0
    last_log_time = time.time()

    try:
        interval_seconds = snapshot_interval / 1000.0  # Convert milliseconds to seconds
        logger.info(f"Snapshot interval for camera {camera_id}: {interval_seconds}s")

        while not stop_event.is_set():
            try:
                start_time = time.time()
                frame = fetch_snapshot(snapshot_url)

                if frame is None:
                    logger.warning(f"Failed to fetch snapshot for camera: {camera_id}, retry {retries+1}/{max_retries}")
                    retries += 1
                    if retries > max_retries and max_retries != -1:
                        logger.error(f"Max retries reached for snapshot camera: {camera_id}, stopping reader")
                        break
                    time.sleep(min(interval_seconds, reconnect_interval))
                    continue

                # Successfully fetched a frame
                frame_count += 1
                current_time = time.time()
                # Log frame stats every 5 seconds
                if current_time - last_log_time > 5:
                    logger.info(f"Camera {camera_id}: Fetched {frame_count} snapshots in the last {current_time - last_log_time:.1f} seconds")
                    frame_count = 0
                    last_log_time = current_time

                logger.debug(f"Successfully fetched snapshot from camera {camera_id}, shape: {frame.shape}")
                retries = 0

                # Overwrite old frame if buffer is full
                if not buffer.empty():
                    try:
                        buffer.get_nowait()
                        logger.debug(f"[snapshot_reader] Removed old snapshot from buffer for camera {camera_id}")
                    except queue.Empty:
                        pass
                buffer.put(frame)
                logger.debug(f"[snapshot_reader] Added new snapshot to buffer for camera {camera_id}. Buffer size: {buffer.qsize()}")

                # Wait for the specified interval
                elapsed = time.time() - start_time
                sleep_time = max(interval_seconds - elapsed, 0)
                if sleep_time > 0:
                    time.sleep(sleep_time)

            except Exception as e:
                logger.error(f"Unexpected error fetching snapshot for camera {camera_id}: {str(e)}", exc_info=True)
                retries += 1
                if retries > max_retries and max_retries != -1:
                    logger.error(f"Max retries reached after error for snapshot camera {camera_id}")
                    break
                time.sleep(min(interval_seconds, reconnect_interval))
    except Exception as e:
        logger.error(f"Error in snapshot_reader thread for camera {camera_id}: {str(e)}", exc_info=True)
    finally:
        logger.info(f"Snapshot reader thread for camera {camera_id} is exiting")

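# Both readers above share a latest-frame-wins buffering scheme: the queue is
# created with maxsize=1 and the producer evicts any stale frame before putting
# the new one, so a slow consumer never processes a backlog. A minimal,
# self-contained sketch of that pattern (illustrative only, not part of this file):
#
#     import queue
#     buf = queue.Queue(maxsize=1)
#     for frame in ["f1", "f2", "f3"]:      # producer side
#         if not buf.empty():
#             try:
#                 buf.get_nowait()          # discard the stale frame
#             except queue.Empty:
#                 pass
#         buf.put(frame)
#     assert buf.get() == "f3"              # consumer only ever sees the newest frame
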
    # The coroutines and the connection loop below are nested inside the WebSocket
    # endpoint handler; `websocket` is that endpoint's connection object.
    async def process_streams():
        logger.info("Started processing streams")
        try:
            while True:
                start_time = time.time()
                with streams_lock:
                    current_streams = list(streams.items())
                    if current_streams:
                        logger.debug(f"Processing {len(current_streams)} active streams")
                    else:
                        logger.debug("No active streams to process")

                for camera_id, stream in current_streams:
                    buffer = stream["buffer"]
                    if buffer.empty():
                        logger.debug(f"Frame buffer is empty for camera {camera_id}")
                        continue

                    logger.debug(f"Got frame from buffer for camera {camera_id}")
                    frame = buffer.get()

                    # Cache the frame for REST API access
                    latest_frames[camera_id] = frame.copy()
                    logger.debug(f"Cached frame for REST API access for camera {camera_id}")

                    with models_lock:
                        model_tree = models.get(camera_id, {}).get(stream["modelId"])
                        if not model_tree:
                            logger.warning(f"Model not found for camera {camera_id}, modelId {stream['modelId']}")
                            continue
                        logger.debug(f"Found model tree for camera {camera_id}, modelId {stream['modelId']}")

                    key = (camera_id, stream["modelId"])
                    persistent_data = persistent_data_dict.get(key, {})
                    logger.debug(f"Starting detection for camera {camera_id} with modelId {stream['modelId']}")
                    updated_persistent_data = await handle_detection(
                        camera_id, stream, frame, websocket, model_tree, persistent_data
                    )
                    persistent_data_dict[key] = updated_persistent_data

                elapsed_time = (time.time() - start_time) * 1000  # ms
                sleep_time = max(poll_interval - elapsed_time, 0)
                logger.debug(f"Frame processing cycle: {elapsed_time:.2f}ms, sleeping for: {sleep_time:.2f}ms")
                await asyncio.sleep(sleep_time / 1000.0)
        except asyncio.CancelledError:
            logger.info("Stream processing task cancelled")
        except Exception as e:
            logger.error(f"Error in process_streams: {str(e)}", exc_info=True)

    async def send_heartbeat():
        while True:
            try:
                cpu_usage = psutil.cpu_percent()
                memory_usage = psutil.virtual_memory().percent
                if torch.cuda.is_available():
                    gpu_usage = torch.cuda.utilization() if hasattr(torch.cuda, 'utilization') else None
                    gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)
                else:
                    gpu_usage = None
                    gpu_memory_usage = None

                camera_connections = [
                    {
                        "subscriptionIdentifier": stream["subscriptionIdentifier"],
                        "modelId": stream["modelId"],
                        "modelName": stream["modelName"],
                        "online": True,
                        **{k: v for k, v in get_crop_coords(stream).items() if v is not None}
                    }
                    for camera_id, stream in streams.items()
                ]

                state_report = {
                    "type": "stateReport",
                    "cpuUsage": cpu_usage,
                    "memoryUsage": memory_usage,
                    "gpuUsage": gpu_usage,
                    "gpuMemoryUsage": gpu_memory_usage,
                    "cameraConnections": camera_connections
                }
                await websocket.send_text(json.dumps(state_report))
                logger.debug(f"Sent stateReport as heartbeat: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, {len(camera_connections)} active cameras")
                await asyncio.sleep(HEARTBEAT_INTERVAL)
            except Exception as e:
                logger.error(f"Error sending stateReport heartbeat: {e}")
                break

    async def on_message():
        while True:
            try:
                msg = await websocket.receive_text()
                logger.debug(f"Received message: {msg}")
                data = json.loads(msg)
                msg_type = data.get("type")

                if msg_type == "subscribe":
                    payload = data.get("payload", {})
                    subscriptionIdentifier = payload.get("subscriptionIdentifier")
                    rtsp_url = payload.get("rtspUrl")
                    snapshot_url = payload.get("snapshotUrl")
                    snapshot_interval = payload.get("snapshotInterval")
                    model_url = payload.get("modelUrl")
                    modelId = payload.get("modelId")
                    modelName = payload.get("modelName")
                    cropX1 = payload.get("cropX1")
                    cropY1 = payload.get("cropY1")
                    cropX2 = payload.get("cropX2")
                    cropY2 = payload.get("cropY2")

                    # Extract camera_id from subscriptionIdentifier (format: displayIdentifier;cameraIdentifier)
                    parts = subscriptionIdentifier.split(';')
                    if len(parts) != 2:
                        logger.error(f"Invalid subscriptionIdentifier format: {subscriptionIdentifier}")
                        continue

                    display_identifier, camera_identifier = parts
                    camera_id = subscriptionIdentifier  # Use full subscriptionIdentifier as camera_id for mapping

                    if model_url:
                        with models_lock:
                            if (camera_id not in models) or (modelId not in models[camera_id]):
                                logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}")
                                extraction_dir = os.path.join("models", camera_identifier, str(modelId))
                                os.makedirs(extraction_dir, exist_ok=True)
                                # If model_url is remote, download it first.
                                parsed = urlparse(model_url)
                                if parsed.scheme in ("http", "https"):
                                    logger.info(f"Downloading remote .mpta file from {model_url}")
                                    filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta"
                                    local_mpta = os.path.join(extraction_dir, filename)
                                    logger.debug(f"Download destination: {local_mpta}")
                                    local_path = download_mpta(model_url, local_mpta)
                                    if not local_path:
                                        logger.error(f"Failed to download the remote .mpta file from {model_url}")
                                        error_response = {
                                            "type": "error",
                                            "subscriptionIdentifier": subscriptionIdentifier,
                                            "error": f"Failed to download model from {model_url}"
                                        }
                                        await websocket.send_json(error_response)
                                        continue
                                    model_tree = load_pipeline_from_zip(local_path, extraction_dir)
                                else:
                                    logger.info(f"Loading local .mpta file from {model_url}")
                                    # Check if file exists before attempting to load
                                    if not os.path.exists(model_url):
                                        logger.error(f"Local .mpta file not found: {model_url}")
                                        logger.debug(f"Current working directory: {os.getcwd()}")
                                        error_response = {
                                            "type": "error",
                                            "subscriptionIdentifier": subscriptionIdentifier,
                                            "error": f"Model file not found: {model_url}"
                                        }
                                        await websocket.send_json(error_response)
                                        continue
                                    model_tree = load_pipeline_from_zip(model_url, extraction_dir)
                                if model_tree is None:
                                    logger.error(f"Failed to load model {modelId} from .mpta file for camera {camera_id}")
                                    error_response = {
                                        "type": "error",
                                        "subscriptionIdentifier": subscriptionIdentifier,
                                        "error": f"Failed to load model {modelId}"
                                    }
                                    await websocket.send_json(error_response)
                                    continue
                                if camera_id not in models:
                                    models[camera_id] = {}
                                models[camera_id][modelId] = model_tree
                                logger.info(f"Successfully loaded model {modelId} for camera {camera_id}")
                                logger.debug(f"Model extraction directory: {extraction_dir}")
                    if camera_id and (rtsp_url or snapshot_url):
                        with streams_lock:
                            # Determine camera URL for shared stream management
                            camera_url = snapshot_url if snapshot_url else rtsp_url

                            if camera_id not in streams and len(streams) < max_streams:
                                # Check if we already have a stream for this camera URL
                                shared_stream = camera_streams.get(camera_url)

                                if shared_stream:
                                    # Reuse existing stream
                                    logger.info(f"Reusing existing stream for camera URL: {camera_url}")
                                    buffer = shared_stream["buffer"]
                                    stop_event = shared_stream["stop_event"]
                                    thread = shared_stream["thread"]
                                    mode = shared_stream["mode"]

                                    # Increment reference count
                                    shared_stream["ref_count"] = shared_stream.get("ref_count", 0) + 1
                                else:
                                    # Create new stream
                                    buffer = queue.Queue(maxsize=1)
                                    stop_event = threading.Event()

                                    if snapshot_url and snapshot_interval:
                                        logger.info(f"Creating new snapshot stream for camera {camera_id}: {snapshot_url}")
                                        thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event))
                                        thread.daemon = True
                                        thread.start()
                                        mode = "snapshot"

                                        # Store shared stream info
                                        shared_stream = {
                                            "buffer": buffer,
                                            "thread": thread,
                                            "stop_event": stop_event,
                                            "mode": mode,
                                            "url": snapshot_url,
                                            "snapshot_interval": snapshot_interval,
                                            "ref_count": 1
                                        }
                                        camera_streams[camera_url] = shared_stream

                                    elif rtsp_url:
                                        logger.info(f"Creating new RTSP stream for camera {camera_id}: {rtsp_url}")
                                        cap = cv2.VideoCapture(rtsp_url)
                                        if not cap.isOpened():
                                            logger.error(f"Failed to open RTSP stream for camera {camera_id}")
                                            continue
                                        thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
                                        thread.daemon = True
                                        thread.start()
                                        mode = "rtsp"

                                        # Store shared stream info
                                        shared_stream = {
                                            "buffer": buffer,
                                            "thread": thread,
                                            "stop_event": stop_event,
                                            "mode": mode,
                                            "url": rtsp_url,
                                            "cap": cap,
                                            "ref_count": 1
                                        }
                                        camera_streams[camera_url] = shared_stream
                                    else:
                                        logger.error(f"No valid URL provided for camera {camera_id}")
                                        continue

                                # Create stream info for this subscription
                                stream_info = {
                                    "buffer": buffer,
                                    "thread": thread,
                                    "stop_event": stop_event,
                                    "modelId": modelId,
                                    "modelName": modelName,
                                    "subscriptionIdentifier": subscriptionIdentifier,
                                    "cropX1": cropX1,
                                    "cropY1": cropY1,
                                    "cropX2": cropX2,
                                    "cropY2": cropY2,
                                    "mode": mode,
                                    "camera_url": camera_url
                                }

                                if mode == "snapshot":
                                    stream_info["snapshot_url"] = snapshot_url
                                    stream_info["snapshot_interval"] = snapshot_interval
                                elif mode == "rtsp":
                                    stream_info["rtsp_url"] = rtsp_url
                                    stream_info["cap"] = shared_stream["cap"]

                                streams[camera_id] = stream_info
                                subscription_to_camera[camera_id] = camera_url

                            elif camera_id and camera_id in streams:
                                # If already subscribed, unsubscribe first
                                logger.info(f"Resubscribing to camera {camera_id}")
                                # Note: Keep models in memory for reuse across subscriptions
                elif msg_type == "unsubscribe":
                    payload = data.get("payload", {})
                    subscriptionIdentifier = payload.get("subscriptionIdentifier")
                    camera_id = subscriptionIdentifier
                    with streams_lock:
                        if camera_id and camera_id in streams:
                            stream = streams.pop(camera_id)
                            camera_url = subscription_to_camera.pop(camera_id, None)

                            if camera_url and camera_url in camera_streams:
                                shared_stream = camera_streams[camera_url]
                                shared_stream["ref_count"] -= 1

                                # If no more references, stop the shared stream
                                if shared_stream["ref_count"] <= 0:
                                    logger.info(f"Stopping shared stream for camera URL: {camera_url}")
                                    shared_stream["stop_event"].set()
                                    shared_stream["thread"].join()
                                    if "cap" in shared_stream:
                                        shared_stream["cap"].release()
                                    del camera_streams[camera_url]
                                else:
                                    logger.info(f"Shared stream for {camera_url} still has {shared_stream['ref_count']} references")

                            # Clean up cached frame
                            latest_frames.pop(camera_id, None)
                            logger.info(f"Unsubscribed from camera {camera_id}")
                            # Note: Keep models in memory for potential reuse
                elif msg_type == "requestState":
                    cpu_usage = psutil.cpu_percent()
                    memory_usage = psutil.virtual_memory().percent
                    if torch.cuda.is_available():
                        gpu_usage = torch.cuda.utilization() if hasattr(torch.cuda, 'utilization') else None
                        gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)
                    else:
                        gpu_usage = None
                        gpu_memory_usage = None

                    camera_connections = [
                        {
                            "subscriptionIdentifier": stream["subscriptionIdentifier"],
                            "modelId": stream["modelId"],
                            "modelName": stream["modelName"],
                            "online": True,
                            **{k: v for k, v in get_crop_coords(stream).items() if v is not None}
                        }
                        for camera_id, stream in streams.items()
                    ]

                    state_report = {
                        "type": "stateReport",
                        "cpuUsage": cpu_usage,
                        "memoryUsage": memory_usage,
                        "gpuUsage": gpu_usage,
                        "gpuMemoryUsage": gpu_memory_usage,
                        "cameraConnections": camera_connections
                    }
                    await websocket.send_text(json.dumps(state_report))

                elif msg_type == "setSessionId":
                    payload = data.get("payload", {})
                    display_identifier = payload.get("displayIdentifier")
                    session_id = payload.get("sessionId")

                    if display_identifier:
                        # Store session ID for this display
                        if session_id is None:
                            session_ids.pop(display_identifier, None)
                            logger.info(f"Cleared session ID for display {display_identifier}")
                        else:
                            session_ids[display_identifier] = session_id
                            logger.info(f"Set session ID {session_id} for display {display_identifier}")

                elif msg_type == "patchSession":
                    session_id = data.get("sessionId")
                    patch_data = data.get("data", {})

                    # For now, just acknowledge the patch - actual implementation depends on backend requirements
                    response = {
                        "type": "patchSessionResult",
                        "payload": {
                            "sessionId": session_id,
                            "success": True,
                            "message": "Session patch acknowledged"
                        }
                    }
                    await websocket.send_json(response)
                    logger.info(f"Acknowledged patch for session {session_id}")

                else:
                    logger.error(f"Unknown message type: {msg_type}")
            except json.JSONDecodeError:
                logger.error("Received invalid JSON message")
            except (WebSocketDisconnect, ConnectionClosedError) as e:
                logger.warning(f"WebSocket disconnected: {e}")
                break
            except Exception as e:
                logger.error(f"Error handling message: {e}")
                break

    try:
        await websocket.accept()
        stream_task = asyncio.create_task(process_streams())
        heartbeat_task = asyncio.create_task(send_heartbeat())
        message_task = asyncio.create_task(on_message())
        await asyncio.gather(heartbeat_task, message_task)
    except Exception as e:
        logger.error(f"Error in detect websocket: {e}")
    finally:
        stream_task.cancel()
        await stream_task
        with streams_lock:
            # Clean up shared camera streams
            for camera_url, shared_stream in camera_streams.items():
                shared_stream["stop_event"].set()
                shared_stream["thread"].join()
                if "cap" in shared_stream:
                    shared_stream["cap"].release()
                while not shared_stream["buffer"].empty():
                    try:
                        shared_stream["buffer"].get_nowait()
                    except queue.Empty:
                        pass
                logger.info(f"Released shared camera stream for {camera_url}")

            streams.clear()
            camera_streams.clear()
            subscription_to_camera.clear()
        with models_lock:
            models.clear()
        latest_frames.clear()
        session_ids.clear()
        logger.info("WebSocket connection closed")

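The subscribe/unsubscribe handlers above implement reference-counted stream sharing: every subscription that points at the same camera URL reuses a single reader thread, and the thread is only stopped once the last subscriber detaches. A minimal sketch of that bookkeeping, with illustrative function names that do not appear in the codebase:

    camera_streams = {}

    def acquire_stream(camera_url, start_reader):
        # Reuse an existing reader for this URL, or start one.
        shared = camera_streams.get(camera_url)
        if shared is None:
            shared = {"reader": start_reader(camera_url), "ref_count": 0}
            camera_streams[camera_url] = shared
        shared["ref_count"] += 1
        return shared

    def release_stream(camera_url, stop_reader):
        # Drop one reference; stop the reader when nobody is left.
        shared = camera_streams.get(camera_url)
        if shared is None:
            return
        shared["ref_count"] -= 1
        if shared["ref_count"] <= 0:
            stop_reader(shared["reader"])  # set stop_event, join thread, release capture
            del camera_streams[camera_url]
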
@@ -1,9 +1,7 @@
 {
     "poll_interval_ms": 100,
-    "max_streams": 20,
+    "max_streams": 5,
     "target_fps": 2,
-    "reconnect_interval_sec": 10,
-    "max_retries": -1,
-    "rtsp_buffer_size": 3,
-    "rtsp_tcp_transport": true
+    "reconnect_interval_sec": 5,
+    "max_retries": -1
 }

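This change halves the stream budget, shortens the reconnect delay, and drops the RTSP tuning knobs. A hedged sketch of how such a file is typically consumed at startup (the loading code itself is not part of this diff; the variable names mirror the globals used in app.py):

    import json

    with open("config.json") as f:
        config = json.load(f)

    poll_interval = config.get("poll_interval_ms", 100)
    max_streams = config.get("max_streams", 5)
    reconnect_interval = config.get("reconnect_interval_sec", 5)
    max_retries = config.get("max_retries", -1)  # -1 means retry forever
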
@@ -1 +0,0 @@
# Core package for detector worker

@@ -1 +0,0 @@
# Communication module for WebSocket and HTTP handling

@@ -1,212 +0,0 @@
"""
Message types, constants, and validation functions for WebSocket communication.
"""
import json
import logging
from typing import Dict, Any, Optional, Union
from .models import (
    IncomingMessage, OutgoingMessage,
    SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage,
    RequestStateMessage, PatchSessionResultMessage,
    StateReportMessage, ImageDetectionMessage, PatchSessionMessage
)

logger = logging.getLogger(__name__)

# Message type constants
class MessageTypes:
    """WebSocket message type constants."""

    # Incoming from backend
    SET_SUBSCRIPTION_LIST = "setSubscriptionList"
    SET_SESSION_ID = "setSessionId"
    SET_PROGRESSION_STAGE = "setProgressionStage"
    REQUEST_STATE = "requestState"
    PATCH_SESSION_RESULT = "patchSessionResult"

    # Outgoing to backend
    STATE_REPORT = "stateReport"
    IMAGE_DETECTION = "imageDetection"
    PATCH_SESSION = "patchSession"


def parse_incoming_message(raw_message: str) -> Optional[IncomingMessage]:
    """
    Parse incoming WebSocket message and validate against known types.

    Args:
        raw_message: Raw JSON string from WebSocket

    Returns:
        Parsed message object or None if invalid
    """
    try:
        data = json.loads(raw_message)
        message_type = data.get("type")

        if not message_type:
            logger.error("Message missing 'type' field")
            return None

        # Route to appropriate message class
        if message_type == MessageTypes.SET_SUBSCRIPTION_LIST:
            return SetSubscriptionListMessage(**data)
        elif message_type == MessageTypes.SET_SESSION_ID:
            return SetSessionIdMessage(**data)
        elif message_type == MessageTypes.SET_PROGRESSION_STAGE:
            return SetProgressionStageMessage(**data)
        elif message_type == MessageTypes.REQUEST_STATE:
            return RequestStateMessage(**data)
        elif message_type == MessageTypes.PATCH_SESSION_RESULT:
            return PatchSessionResultMessage(**data)
        else:
            logger.warning(f"Unknown message type: {message_type}")
            return None

    except json.JSONDecodeError as e:
        logger.error(f"Failed to decode JSON message: {e}")
        return None
    except Exception as e:
        logger.error(f"Failed to parse incoming message: {e}")
        return None


def serialize_outgoing_message(message: OutgoingMessage) -> str:
    """
    Serialize outgoing message to JSON string.

    Args:
        message: Message object to serialize

    Returns:
        JSON string representation
    """
    try:
        # For ImageDetectionMessage, we need to include None values for abandonment detection
        from .models import ImageDetectionMessage
        if isinstance(message, ImageDetectionMessage):
            return message.model_dump_json(exclude_none=False)
        else:
            return message.model_dump_json(exclude_none=True)
    except Exception as e:
        logger.error(f"Failed to serialize outgoing message: {e}")
        raise


def validate_subscription_identifier(identifier: str) -> bool:
    """
    Validate subscription identifier format (displayId;cameraId).

    Args:
        identifier: Subscription identifier to validate

    Returns:
        True if valid format, False otherwise
    """
    if not identifier or not isinstance(identifier, str):
        return False

    parts = identifier.split(';')
    if len(parts) != 2:
        logger.error(f"Invalid subscription identifier format: {identifier}")
        return False

    display_id, camera_id = parts
    if not display_id or not camera_id:
        logger.error(f"Empty display or camera ID in identifier: {identifier}")
        return False

    return True


def extract_display_identifier(subscription_identifier: str) -> Optional[str]:
    """
    Extract display identifier from subscription identifier.

    Args:
        subscription_identifier: Full subscription identifier (displayId;cameraId)

    Returns:
        Display identifier or None if invalid format
    """
    if not validate_subscription_identifier(subscription_identifier):
        return None

    return subscription_identifier.split(';')[0]


def create_state_report(cpu_usage: float, memory_usage: float,
                        gpu_usage: Optional[float] = None,
                        gpu_memory_usage: Optional[float] = None,
                        camera_connections: Optional[list] = None) -> StateReportMessage:
    """
    Create a state report message with system metrics.

    Args:
        cpu_usage: CPU usage percentage
        memory_usage: Memory usage percentage
        gpu_usage: GPU usage percentage (optional)
        gpu_memory_usage: GPU memory usage in MB (optional)
        camera_connections: List of active camera connections

    Returns:
        StateReportMessage object
    """
    return StateReportMessage(
        cpuUsage=cpu_usage,
        memoryUsage=memory_usage,
        gpuUsage=gpu_usage,
        gpuMemoryUsage=gpu_memory_usage,
        cameraConnections=camera_connections or []
    )


def create_image_detection(subscription_identifier: str, detection_data: Union[Dict[str, Any], None],
                           model_id: int, model_name: str) -> ImageDetectionMessage:
    """
    Create an image detection message.

    Args:
        subscription_identifier: Camera subscription identifier
        detection_data: Detection results - Dict for data, {} for empty, None for abandonment
        model_id: Model identifier
        model_name: Model name

    Returns:
        ImageDetectionMessage object
    """
    from .models import DetectionData
    from typing import Union

    # Handle three cases:
    # 1. None = car abandonment (detection: null)
    # 2. {} = empty detection (triggers session creation)
    # 3. {...} = full detection data (updates session)

    data = DetectionData(
        detection=detection_data,
        modelId=model_id,
        modelName=model_name
    )

    return ImageDetectionMessage(
        subscriptionIdentifier=subscription_identifier,
        data=data
    )


def create_patch_session(session_id: int, patch_data: Dict[str, Any]) -> PatchSessionMessage:
    """
    Create a patch session message.

    Args:
        session_id: Session ID to patch
        patch_data: Partial session data to update

    Returns:
        PatchSessionMessage object
    """
    return PatchSessionMessage(
        sessionId=session_id,
        data=patch_data
    )

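Together these helpers form a symmetric parse/serialize path around the Pydantic models. A small round-trip example using only the functions defined above (all identifier and metric values here are invented):

    raw = '{"type": "setSessionId", "payload": {"displayIdentifier": "display-001", "sessionId": 42}}'
    msg = parse_incoming_message(raw)
    assert isinstance(msg, SetSessionIdMessage)
    assert msg.payload.sessionId == 42

    report = create_state_report(cpu_usage=12.5, memory_usage=40.0)
    wire = serialize_outgoing_message(report)  # JSON string ready for websocket.send_text()

    assert validate_subscription_identifier("display-001;camera-001")
    assert extract_display_identifier("display-001;camera-001") == "display-001"
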
@@ -1,150 +0,0 @@
"""
Message data structures for WebSocket communication.
Based on worker.md protocol specification.
"""
from typing import Dict, Any, List, Optional, Union, Literal
from pydantic import BaseModel, Field
from datetime import datetime


class SubscriptionObject(BaseModel):
    """Individual camera subscription configuration."""
    subscriptionIdentifier: str = Field(..., description="Format: displayId;cameraId")
    rtspUrl: Optional[str] = Field(None, description="RTSP stream URL")
    snapshotUrl: Optional[str] = Field(None, description="HTTP snapshot URL")
    snapshotInterval: Optional[int] = Field(None, description="Snapshot interval in milliseconds")
    modelUrl: str = Field(..., description="Pre-signed URL to .mpta file")
    modelId: int = Field(..., description="Unique model identifier")
    modelName: str = Field(..., description="Human-readable model name")
    cropX1: Optional[int] = Field(None, description="Crop region X1 coordinate")
    cropY1: Optional[int] = Field(None, description="Crop region Y1 coordinate")
    cropX2: Optional[int] = Field(None, description="Crop region X2 coordinate")
    cropY2: Optional[int] = Field(None, description="Crop region Y2 coordinate")


class CameraConnection(BaseModel):
    """Camera connection status for state reporting."""
    subscriptionIdentifier: str
    modelId: int
    modelName: str
    online: bool
    cropX1: Optional[int] = None
    cropY1: Optional[int] = None
    cropX2: Optional[int] = None
    cropY2: Optional[int] = None


class DetectionData(BaseModel):
    """
    Detection result data structure.

    Supports three cases:
    1. Empty detection: detection = {} (triggers session creation)
    2. Full detection: detection = {"carBrand": "Honda", ...} (updates session)
    3. Null detection: detection = None (car abandonment)
    """
    model_config = {
        "json_encoders": {type(None): lambda v: None},
        "arbitrary_types_allowed": True
    }

    detection: Union[Dict[str, Any], None] = Field(
        default_factory=dict,
        description="Detection results: {} for empty, {...} for data, None/null for abandonment"
    )
    modelId: int
    modelName: str


# Incoming Messages from Backend to Worker

class SetSubscriptionListMessage(BaseModel):
    """Complete subscription list for declarative state management."""
    type: Literal["setSubscriptionList"] = "setSubscriptionList"
    subscriptions: List[SubscriptionObject]


class SetSessionIdPayload(BaseModel):
    """Session ID association payload."""
    displayIdentifier: str
    sessionId: Optional[int] = None


class SetSessionIdMessage(BaseModel):
    """Associate session ID with display."""
    type: Literal["setSessionId"] = "setSessionId"
    payload: SetSessionIdPayload


class SetProgressionStagePayload(BaseModel):
    """Progression stage payload."""
    displayIdentifier: str
    progressionStage: Optional[str] = None


class SetProgressionStageMessage(BaseModel):
    """Set progression stage for display."""
    type: Literal["setProgressionStage"] = "setProgressionStage"
    payload: SetProgressionStagePayload


class RequestStateMessage(BaseModel):
    """Request current worker state."""
    type: Literal["requestState"] = "requestState"


class PatchSessionResultPayload(BaseModel):
    """Patch session result payload."""
    sessionId: int
    success: bool
    message: str


class PatchSessionResultMessage(BaseModel):
    """Response to patch session request."""
    type: Literal["patchSessionResult"] = "patchSessionResult"
    payload: PatchSessionResultPayload


# Outgoing Messages from Worker to Backend

class StateReportMessage(BaseModel):
    """Periodic heartbeat with system metrics."""
    type: Literal["stateReport"] = "stateReport"
    cpuUsage: float
    memoryUsage: float
    gpuUsage: Optional[float] = None
    gpuMemoryUsage: Optional[float] = None
    cameraConnections: List[CameraConnection]


class ImageDetectionMessage(BaseModel):
    """Detection event message."""
    type: Literal["imageDetection"] = "imageDetection"
    subscriptionIdentifier: str
    timestamp: str = Field(default_factory=lambda: datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ"))
    data: DetectionData


class PatchSessionMessage(BaseModel):
    """Request to modify session data."""
    type: Literal["patchSession"] = "patchSession"
    sessionId: int
    data: Dict[str, Any] = Field(..., description="Partial DisplayPersistentData structure")


# Union type for all incoming messages
IncomingMessage = Union[
    SetSubscriptionListMessage,
    SetSessionIdMessage,
    SetProgressionStageMessage,
    RequestStateMessage,
    PatchSessionResultMessage
]

# Union type for all outgoing messages
OutgoingMessage = Union[
    StateReportMessage,
    ImageDetectionMessage,
    PatchSessionMessage
]

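The three cases in the DetectionData docstring map directly onto message construction; for instance (field values invented):

    # 1. Empty detection -> backend creates a session
    ImageDetectionMessage(
        subscriptionIdentifier="display-001;camera-001",
        data=DetectionData(detection={}, modelId=7, modelName="front-gate"),
    )

    # 2. Full detection -> backend updates the session
    ImageDetectionMessage(
        subscriptionIdentifier="display-001;camera-001",
        data=DetectionData(detection={"carBrand": "Honda"}, modelId=7, modelName="front-gate"),
    )

    # 3. Null detection -> signals car abandonment (serialized as detection: null)
    ImageDetectionMessage(
        subscriptionIdentifier="display-001;camera-001",
        data=DetectionData(detection=None, modelId=7, modelName="front-gate"),
    )

This is also why serialize_outgoing_message keeps exclude_none=False for ImageDetectionMessage: a null detection must survive serialization rather than be dropped.
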
@@ -1,234 +0,0 @@
"""
Worker state management for system metrics and subscription tracking.
"""
import logging
import psutil
import threading
from typing import Dict, Set, Optional, List
from dataclasses import dataclass, field
from .models import CameraConnection, SubscriptionObject

logger = logging.getLogger(__name__)

# Try to import torch and pynvml for GPU monitoring
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    logger.warning("PyTorch not available, GPU metrics will not be collected")

try:
    import pynvml
    PYNVML_AVAILABLE = True
    pynvml.nvmlInit()
    logger.info("NVIDIA ML Python (pynvml) initialized successfully")
except ImportError:
    PYNVML_AVAILABLE = False
    logger.debug("pynvml not available, falling back to PyTorch GPU monitoring")
except Exception as e:
    PYNVML_AVAILABLE = False
    logger.warning(f"Failed to initialize pynvml: {e}")


@dataclass
class WorkerState:
    """Central state management for the detector worker."""

    # Active subscriptions indexed by subscription identifier
    subscriptions: Dict[str, SubscriptionObject] = field(default_factory=dict)

    # Session ID mapping: display_identifier -> session_id
    session_ids: Dict[str, int] = field(default_factory=dict)

    # Progression stage mapping: display_identifier -> stage
    progression_stages: Dict[str, str] = field(default_factory=dict)

    # Active camera connections for state reporting
    camera_connections: List[CameraConnection] = field(default_factory=list)

    # Thread lock for state synchronization
    _lock: threading.RLock = field(default_factory=threading.RLock)

    def set_subscriptions(self, new_subscriptions: List[SubscriptionObject]) -> None:
        """
        Update active subscriptions with declarative list from backend.

        Args:
            new_subscriptions: Complete list of desired subscriptions
        """
        with self._lock:
            # Convert to dict for easy lookup
            new_sub_dict = {sub.subscriptionIdentifier: sub for sub in new_subscriptions}

            # Log changes for debugging
            current_ids = set(self.subscriptions.keys())
            new_ids = set(new_sub_dict.keys())

            added = new_ids - current_ids
            removed = current_ids - new_ids
            updated = current_ids & new_ids

            if added:
                logger.info(f"[State Update] Adding subscriptions: {added}")
            if removed:
                logger.info(f"[State Update] Removing subscriptions: {removed}")
            if updated:
                logger.info(f"[State Update] Updating subscriptions: {updated}")

            # Replace entire subscription dict
            self.subscriptions = new_sub_dict

            # Update camera connections for state reporting
            self._update_camera_connections()

    def get_subscription(self, subscription_identifier: str) -> Optional[SubscriptionObject]:
        """Get subscription by identifier."""
        with self._lock:
            return self.subscriptions.get(subscription_identifier)

    def get_all_subscriptions(self) -> List[SubscriptionObject]:
        """Get all active subscriptions."""
        with self._lock:
            return list(self.subscriptions.values())

    def set_session_id(self, display_identifier: str, session_id: Optional[int]) -> None:
        """
        Set or clear session ID for a display.

        Args:
            display_identifier: Display identifier
            session_id: Session ID to set, or None to clear
        """
        with self._lock:
            if session_id is None:
                self.session_ids.pop(display_identifier, None)
                logger.info(f"[State Update] Cleared session ID for display {display_identifier}")
            else:
                self.session_ids[display_identifier] = session_id
                logger.info(f"[State Update] Set session ID {session_id} for display {display_identifier}")

    def get_session_id(self, display_identifier: str) -> Optional[int]:
        """Get session ID for display identifier."""
        with self._lock:
            return self.session_ids.get(display_identifier)

    def get_session_id_for_subscription(self, subscription_identifier: str) -> Optional[int]:
        """Get session ID for subscription by extracting display identifier."""
        from .messages import extract_display_identifier

        display_id = extract_display_identifier(subscription_identifier)
        if display_id:
            return self.get_session_id(display_id)
        return None

    def set_progression_stage(self, display_identifier: str, stage: Optional[str]) -> None:
        """
        Set or clear progression stage for a display.

        Args:
            display_identifier: Display identifier
            stage: Progression stage to set, or None to clear
        """
        with self._lock:
            if stage is None:
                self.progression_stages.pop(display_identifier, None)
                logger.info(f"[State Update] Cleared progression stage for display {display_identifier}")
            else:
                self.progression_stages[display_identifier] = stage
                logger.info(f"[State Update] Set progression stage '{stage}' for display {display_identifier}")

    def get_progression_stage(self, display_identifier: str) -> Optional[str]:
        """Get progression stage for display identifier."""
        with self._lock:
            return self.progression_stages.get(display_identifier)

    def _update_camera_connections(self) -> None:
        """Update camera connections list for state reporting."""
        connections = []

        for sub in self.subscriptions.values():
            connection = CameraConnection(
                subscriptionIdentifier=sub.subscriptionIdentifier,
                modelId=sub.modelId,
                modelName=sub.modelName,
                online=True,  # TODO: Add actual online status tracking
                cropX1=sub.cropX1,
                cropY1=sub.cropY1,
                cropX2=sub.cropX2,
                cropY2=sub.cropY2
            )
            connections.append(connection)

        self.camera_connections = connections

    def get_camera_connections(self) -> List[CameraConnection]:
        """Get current camera connections for state reporting."""
        with self._lock:
            return self.camera_connections.copy()


class SystemMetrics:
    """System metrics collection for state reporting."""

    @staticmethod
    def get_cpu_usage() -> float:
        """Get current CPU usage percentage."""
        try:
            return psutil.cpu_percent(interval=0.1)
        except Exception as e:
            logger.error(f"Failed to get CPU usage: {e}")
            return 0.0

    @staticmethod
    def get_memory_usage() -> float:
        """Get current memory usage percentage."""
        try:
            return psutil.virtual_memory().percent
        except Exception as e:
            logger.error(f"Failed to get memory usage: {e}")
            return 0.0

    @staticmethod
    def get_gpu_usage() -> Optional[float]:
        """Get current GPU usage percentage."""
        try:
            # Prefer pynvml for accurate GPU utilization
            if PYNVML_AVAILABLE:
                handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # First GPU
                utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
                return float(utilization.gpu)

            # Fallback to PyTorch memory-based estimation
            elif TORCH_AVAILABLE and torch.cuda.is_available():
                if hasattr(torch.cuda, 'utilization'):
                    return torch.cuda.utilization()
                else:
                    # Estimate based on memory usage
                    allocated = torch.cuda.memory_allocated()
                    reserved = torch.cuda.memory_reserved()
                    if reserved > 0:
                        return (allocated / reserved) * 100

            return None
        except Exception as e:
            logger.error(f"Failed to get GPU usage: {e}")
            return None

    @staticmethod
    def get_gpu_memory_usage() -> Optional[float]:
        """Get current GPU memory usage in MB."""
        if not TORCH_AVAILABLE:
            return None

        try:
            if torch.cuda.is_available():
                return torch.cuda.memory_reserved() / (1024 ** 2)  # Convert to MB
            return None
        except Exception as e:
            logger.error(f"Failed to get GPU memory usage: {e}")
            return None


# Global worker state instance
worker_state = WorkerState()

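Since worker_state is a module-level singleton guarded by an RLock, the WebSocket handlers can mutate it from any task without extra coordination. A short usage sketch (the subscription values are invented):

    sub = SubscriptionObject(
        subscriptionIdentifier="display-001;camera-001",
        modelUrl="https://example.com/model.mpta",
        modelId=7,
        modelName="front-gate",
    )
    worker_state.set_subscriptions([sub])  # declarative replace; logs adds/removes/updates
    worker_state.set_session_id("display-001", 42)
    assert worker_state.get_session_id_for_subscription("display-001;camera-001") == 42
    assert len(worker_state.get_camera_connections()) == 1
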
@ -1,677 +0,0 @@
|
|||
"""
|
||||
WebSocket message handling and protocol implementation.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import cv2
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
from .messages import (
|
||||
parse_incoming_message, serialize_outgoing_message,
|
||||
MessageTypes, create_state_report
|
||||
)
|
||||
from .models import (
|
||||
SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage,
|
||||
RequestStateMessage, PatchSessionResultMessage
|
||||
)
|
||||
from .state import worker_state, SystemMetrics
|
||||
from ..models import ModelManager
|
||||
from ..streaming.manager import shared_stream_manager
|
||||
from ..tracking.integration import TrackingPipelineIntegration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
HEARTBEAT_INTERVAL = 2.0 # seconds
|
||||
WORKER_TIMEOUT_MS = 10000
|
||||
|
||||
# Global model manager instance
|
||||
model_manager = ModelManager()
|
||||
|
||||
|
||||
class WebSocketHandler:
|
||||
"""
|
||||
Handles WebSocket connection lifecycle and message processing.
|
||||
"""
|
||||
|
||||
def __init__(self, websocket: WebSocket):
|
||||
self.websocket = websocket
|
||||
self.connected = False
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._message_task: Optional[asyncio.Task] = None
|
||||
self._heartbeat_count = 0
|
||||
self._last_processed_models: set = set() # Cache of last processed model IDs
|
||||
|
||||
async def handle_connection(self) -> None:
|
||||
"""
|
||||
Main connection handler that manages the WebSocket lifecycle.
|
||||
Based on the original architecture from archive/app.py
|
||||
"""
|
||||
client_info = f"{self.websocket.client.host}:{self.websocket.client.port}" if self.websocket.client else "unknown"
|
||||
logger.info(f"Starting WebSocket handler for {client_info}")
|
||||
|
||||
stream_task = None
|
||||
try:
|
||||
logger.info(f"Accepting WebSocket connection from {client_info}")
|
||||
await self.websocket.accept()
|
||||
self.connected = True
|
||||
logger.info(f"WebSocket connection accepted and established for {client_info}")
|
||||
|
||||
# Send immediate heartbeat to show connection is alive
|
||||
await self._send_immediate_heartbeat()
|
||||
|
||||
# Start background tasks (matching original architecture)
|
||||
stream_task = asyncio.create_task(self._process_streams())
|
||||
heartbeat_task = asyncio.create_task(self._send_heartbeat())
|
||||
message_task = asyncio.create_task(self._handle_messages())
|
||||
|
||||
logger.info(f"WebSocket background tasks started for {client_info} (stream + heartbeat + message handler)")
|
||||
|
||||
# Wait for heartbeat and message tasks (stream runs independently)
|
||||
await asyncio.gather(heartbeat_task, message_task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in WebSocket connection for {client_info}: {e}", exc_info=True)
|
||||
finally:
|
||||
logger.info(f"Cleaning up connection for {client_info}")
|
||||
# Cancel stream task
|
||||
if stream_task and not stream_task.done():
|
||||
stream_task.cancel()
|
||||
try:
|
||||
await stream_task
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"Stream task cancelled for {client_info}")
|
||||
await self._cleanup()
|
||||
|
||||
async def _send_immediate_heartbeat(self) -> None:
|
||||
"""Send immediate heartbeat on connection to show we're alive."""
|
||||
try:
|
||||
cpu_usage = SystemMetrics.get_cpu_usage()
|
||||
memory_usage = SystemMetrics.get_memory_usage()
|
||||
gpu_usage = SystemMetrics.get_gpu_usage()
|
||||
gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
|
||||
camera_connections = worker_state.get_camera_connections()
|
||||
|
||||
state_report = create_state_report(
|
||||
cpu_usage=cpu_usage,
|
||||
memory_usage=memory_usage,
|
||||
gpu_usage=gpu_usage,
|
||||
gpu_memory_usage=gpu_memory_usage,
|
||||
camera_connections=camera_connections
|
||||
)
|
||||
|
||||
await self._send_message(state_report)
|
||||
logger.info(f"[TX → Backend] Initial stateReport: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, "
|
||||
f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending immediate heartbeat: {e}")
|
||||
|
||||
async def _send_heartbeat(self) -> None:
|
||||
"""Send periodic state reports as heartbeat."""
|
||||
while self.connected:
|
||||
try:
|
||||
# Collect system metrics
|
||||
cpu_usage = SystemMetrics.get_cpu_usage()
|
||||
memory_usage = SystemMetrics.get_memory_usage()
|
||||
gpu_usage = SystemMetrics.get_gpu_usage()
|
||||
gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
|
||||
camera_connections = worker_state.get_camera_connections()
|
||||
|
||||
# Create and send state report
|
||||
state_report = create_state_report(
|
||||
cpu_usage=cpu_usage,
|
||||
memory_usage=memory_usage,
|
||||
gpu_usage=gpu_usage,
|
||||
gpu_memory_usage=gpu_memory_usage,
|
||||
camera_connections=camera_connections
|
||||
)
|
||||
|
||||
await self._send_message(state_report)
|
||||
|
||||
# Only log full details every 10th heartbeat, otherwise just show a dot
|
||||
self._heartbeat_count += 1
|
||||
if self._heartbeat_count % 10 == 0:
|
||||
logger.info(f"[TX → Backend] Heartbeat #{self._heartbeat_count}: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, "
|
||||
f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras")
|
||||
else:
|
||||
print(".", end="", flush=True) # Just show a dot to indicate heartbeat activity
|
||||
|
||||
await asyncio.sleep(HEARTBEAT_INTERVAL)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending heartbeat: {e}")
|
||||
break
|
||||
|
||||
async def _handle_messages(self) -> None:
|
||||
"""Handle incoming WebSocket messages."""
|
||||
while self.connected:
|
||||
try:
|
||||
raw_message = await self.websocket.receive_text()
|
||||
logger.info(f"[RX ← Backend] {raw_message}")
|
||||
|
||||
# Parse incoming message
|
||||
message = parse_incoming_message(raw_message)
|
||||
if not message:
|
||||
logger.warning("Failed to parse incoming message")
|
||||
continue
|
||||
|
||||
# Route message to appropriate handler
|
||||
await self._route_message(message)
|
||||
|
||||
except (WebSocketDisconnect, ConnectionClosedError) as e:
|
||||
logger.warning(f"WebSocket disconnected: {e}")
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Received invalid JSON message")
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}")
|
||||
break
|
||||
|
||||
async def _route_message(self, message) -> None:
|
||||
"""Route parsed message to appropriate handler."""
|
||||
message_type = message.type
|
||||
|
||||
try:
|
||||
if message_type == MessageTypes.SET_SUBSCRIPTION_LIST:
|
||||
await self._handle_set_subscription_list(message)
|
||||
elif message_type == MessageTypes.SET_SESSION_ID:
|
||||
await self._handle_set_session_id(message)
|
||||
elif message_type == MessageTypes.SET_PROGRESSION_STAGE:
|
||||
await self._handle_set_progression_stage(message)
|
||||
elif message_type == MessageTypes.REQUEST_STATE:
|
||||
await self._handle_request_state(message)
|
||||
elif message_type == MessageTypes.PATCH_SESSION_RESULT:
|
||||
await self._handle_patch_session_result(message)
|
||||
else:
|
||||
logger.warning(f"Unknown message type: {message_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling {message_type} message: {e}")
|
||||
|
||||
async def _handle_set_subscription_list(self, message: SetSubscriptionListMessage) -> None:
|
||||
"""Handle setSubscriptionList message for declarative subscription management."""
|
||||
logger.info(f"[RX Processing] setSubscriptionList with {len(message.subscriptions)} subscriptions")
|
||||
|
||||
# Update worker state with new subscriptions
|
||||
worker_state.set_subscriptions(message.subscriptions)
|
||||
|
||||
# Phase 2: Download and manage models
|
||||
await self._ensure_models(message.subscriptions)
|
||||
|
||||
# Phase 3 & 4: Integrate with streaming management and tracking
|
||||
await self._update_stream_subscriptions(message.subscriptions)
|
||||
|
||||
logger.info("Subscription list updated successfully")
|
||||
|
||||
async def _ensure_models(self, subscriptions) -> None:
|
||||
"""Ensure all required models are downloaded and available."""
|
||||
# Extract unique model requirements
|
||||
unique_models = {}
|
||||
for subscription in subscriptions:
|
||||
model_id = subscription.modelId
|
||||
if model_id not in unique_models:
|
||||
unique_models[model_id] = {
|
||||
'model_url': subscription.modelUrl,
|
||||
'model_name': subscription.modelName
|
||||
}
|
||||
|
||||
# Check if model set has changed to avoid redundant processing
|
||||
current_model_ids = set(unique_models.keys())
|
||||
if current_model_ids == self._last_processed_models:
|
||||
logger.debug(f"[Model Management] Model set unchanged {list(current_model_ids)}, skipping checks")
|
||||
return
|
||||
|
||||
logger.info(f"[Model Management] Processing {len(unique_models)} unique models: {list(unique_models.keys())}")
|
||||
self._last_processed_models = current_model_ids
|
||||
|
||||
# Check and download models concurrently
|
||||
download_tasks = []
|
||||
for model_id, model_info in unique_models.items():
|
||||
task = asyncio.create_task(
|
||||
self._ensure_single_model(model_id, model_info['model_url'], model_info['model_name'])
|
||||
)
|
||||
download_tasks.append(task)
|
||||
|
||||
# Wait for all downloads to complete
|
||||
if download_tasks:
|
||||
results = await asyncio.gather(*download_tasks, return_exceptions=True)
|
||||
|
||||
# Log results
|
||||
success_count = 0
|
||||
for i, result in enumerate(results):
|
||||
model_id = list(unique_models.keys())[i]
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"[Model Management] Failed to ensure model {model_id}: {result}")
|
||||
elif result:
|
||||
success_count += 1
|
||||
logger.info(f"[Model Management] Model {model_id} ready for use")
|
||||
else:
|
||||
logger.error(f"[Model Management] Failed to ensure model {model_id}")
|
||||
|
||||
logger.info(f"[Model Management] Successfully ensured {success_count}/{len(unique_models)} models")
|
||||
|
||||
async def _update_stream_subscriptions(self, subscriptions) -> None:
|
||||
"""Update streaming subscriptions with tracking integration."""
|
||||
try:
|
||||
# Convert subscriptions to the format expected by StreamManager
|
||||
subscription_payloads = []
|
||||
for subscription in subscriptions:
|
||||
payload = {
|
||||
'subscriptionIdentifier': subscription.subscriptionIdentifier,
|
||||
'rtspUrl': subscription.rtspUrl,
|
||||
'snapshotUrl': subscription.snapshotUrl,
|
||||
'snapshotInterval': subscription.snapshotInterval,
|
||||
'modelId': subscription.modelId,
|
||||
'modelUrl': subscription.modelUrl,
|
||||
'modelName': subscription.modelName
|
||||
}
|
||||
# Add crop coordinates if present
|
||||
if hasattr(subscription, 'cropX1'):
|
||||
payload.update({
|
||||
'cropX1': subscription.cropX1,
|
||||
'cropY1': subscription.cropY1,
|
||||
'cropX2': subscription.cropX2,
|
||||
'cropY2': subscription.cropY2
|
||||
})
|
||||
subscription_payloads.append(payload)
|
||||
|
||||
# Reconcile subscriptions with StreamManager
|
||||
logger.info("[Streaming] Reconciling stream subscriptions with tracking")
|
||||
reconcile_result = await self._reconcile_subscriptions_with_tracking(subscription_payloads)
|
||||
|
||||
logger.info(f"[Streaming] Subscription reconciliation complete: "
|
||||
f"added={reconcile_result.get('added', 0)}, "
|
||||
f"removed={reconcile_result.get('removed', 0)}, "
|
||||
f"failed={reconcile_result.get('failed', 0)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating stream subscriptions: {e}", exc_info=True)
|
||||
|
||||
async def _reconcile_subscriptions_with_tracking(self, target_subscriptions) -> dict:
|
||||
"""Reconcile subscriptions with tracking integration."""
|
||||
try:
|
||||
# Create separate tracking integrations for each subscription (camera isolation)
|
||||
tracking_integrations = {}
|
||||
|
||||
for subscription_payload in target_subscriptions:
|
||||
subscription_id = subscription_payload['subscriptionIdentifier']
|
||||
model_id = subscription_payload['modelId']
|
||||
|
||||
# Create separate tracking integration per subscription for camera isolation
|
||||
# Get pipeline configuration for this model
|
||||
pipeline_parser = model_manager.get_pipeline_config(model_id)
|
||||
if pipeline_parser:
|
||||
# Create tracking integration with message sender (separate instance per camera)
|
||||
tracking_integration = TrackingPipelineIntegration(
|
||||
pipeline_parser, model_manager, model_id, self._send_message
|
||||
)
|
||||
|
||||
# Initialize tracking model
|
||||
success = await tracking_integration.initialize_tracking_model()
|
||||
if success:
|
||||
tracking_integrations[subscription_id] = tracking_integration
|
||||
logger.info(f"[Tracking] Created isolated tracking integration for subscription {subscription_id} (model {model_id})")
|
||||
else:
|
||||
logger.warning(f"[Tracking] Failed to initialize tracking for subscription {subscription_id} (model {model_id})")
|
||||
else:
|
||||
logger.warning(f"[Tracking] No pipeline config found for model {model_id} in subscription {subscription_id}")
|
||||
|
||||
# Now reconcile with StreamManager, adding tracking integrations
|
||||
current_subscription_ids = set()
|
||||
for subscription_info in shared_stream_manager.get_all_subscriptions():
|
||||
current_subscription_ids.add(subscription_info.subscription_id)
|
||||
|
||||
target_subscription_ids = {sub['subscriptionIdentifier'] for sub in target_subscriptions}
|
||||
|
||||
# Find subscriptions to remove and add
|
||||
to_remove = current_subscription_ids - target_subscription_ids
|
||||
to_add = target_subscription_ids - current_subscription_ids
|
||||
|
||||
# Remove old subscriptions
|
||||
removed_count = 0
|
||||
for subscription_id in to_remove:
|
||||
if shared_stream_manager.remove_subscription(subscription_id):
|
||||
removed_count += 1
|
||||
logger.info(f"[Streaming] Removed subscription {subscription_id}")
|
||||
|
||||
# Add new subscriptions with tracking
|
||||
added_count = 0
|
||||
failed_count = 0
|
||||
for subscription_payload in target_subscriptions:
|
||||
subscription_id = subscription_payload['subscriptionIdentifier']
|
||||
if subscription_id in to_add:
|
||||
success = await self._add_subscription_with_tracking(
|
||||
subscription_payload, tracking_integrations
|
||||
)
|
||||
if success:
|
||||
added_count += 1
|
||||
logger.info(f"[Streaming] Added subscription {subscription_id} with tracking")
|
||||
else:
|
||||
failed_count += 1
|
||||
logger.error(f"[Streaming] Failed to add subscription {subscription_id}")
|
||||
|
||||
return {
|
||||
'removed': removed_count,
|
||||
'added': added_count,
|
||||
'failed': failed_count,
|
||||
'total_active': len(shared_stream_manager.get_all_subscriptions())
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in subscription reconciliation with tracking: {e}", exc_info=True)
|
||||
return {'removed': 0, 'added': 0, 'failed': 0, 'total_active': 0}

    async def _add_subscription_with_tracking(self, payload, tracking_integrations) -> bool:
        """Add a subscription with tracking integration."""
        try:
            from ..streaming.manager import StreamConfig

            subscription_id = payload['subscriptionIdentifier']
            camera_id = subscription_id.split(';')[-1]
            model_id = payload['modelId']

            logger.info(f"[SUBSCRIPTION_MAPPING] subscription_id='{subscription_id}' → camera_id='{camera_id}'")

            # Get tracking integration for this subscription (camera-isolated)
            tracking_integration = tracking_integrations.get(subscription_id)

            # Extract crop coordinates if present
            crop_coords = None
            if all(key in payload for key in ['cropX1', 'cropY1', 'cropX2', 'cropY2']):
                crop_coords = (
                    payload['cropX1'],
                    payload['cropY1'],
                    payload['cropX2'],
                    payload['cropY2']
                )

            # Create stream configuration
            stream_config = StreamConfig(
                camera_id=camera_id,
                rtsp_url=payload.get('rtspUrl'),
                snapshot_url=payload.get('snapshotUrl'),
                snapshot_interval=payload.get('snapshotInterval', 5000),
                max_retries=3,
            )

            # Add subscription to StreamManager with tracking
            success = shared_stream_manager.add_subscription(
                subscription_id=subscription_id,
                stream_config=stream_config,
                crop_coords=crop_coords,
                model_id=model_id,
                model_url=payload.get('modelUrl'),
                tracking_integration=tracking_integration
            )

            if success and tracking_integration:
                logger.info(f"[Tracking] Subscription {subscription_id} configured with isolated tracking for model {model_id}")

            return success

        except Exception as e:
            logger.error(f"Error adding subscription with tracking: {e}", exc_info=True)
            return False
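        # Identifier convention used above (hypothetical value):
        #   'display-001;cam-001'.split(';')[-1]  ->  'cam-001'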

    async def _ensure_single_model(self, model_id: int, model_url: str, model_name: str) -> bool:
        """Ensure a single model is downloaded and available."""
        try:
            # Check if the model is already available
            if model_manager.is_model_downloaded(model_id):
                logger.info(f"[Model Management] Model {model_id} ({model_name}) already available")
                return True

            # Download and extract the model in a thread pool to avoid blocking the event loop
            logger.info(f"[Model Management] Downloading model {model_id} ({model_name}) from {model_url}")

            # asyncio.to_thread (Python 3.9+) would also work; run_in_executor keeps compatibility
            loop = asyncio.get_event_loop()
            model_path = await loop.run_in_executor(
                None,
                model_manager.ensure_model,
                model_id,
                model_url,
                model_name
            )

            if model_path:
                logger.info(f"[Model Management] Successfully prepared model {model_id} at {model_path}")
                return True
            else:
                logger.error(f"[Model Management] Failed to prepare model {model_id}")
                return False

        except Exception as e:
            logger.error(f"[Model Management] Exception ensuring model {model_id}: {str(e)}", exc_info=True)
            return False
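        # Equivalent call on Python 3.9+ (sketch):
        #   model_path = await asyncio.to_thread(
        #       model_manager.ensure_model, model_id, model_url, model_name
        #   )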

    async def _save_snapshot(self, display_identifier: str, session_id: int) -> None:
        """
        Save a snapshot image to the images folder after receiving a sessionId.

        Args:
            display_identifier: Display identifier to match with subscriptionIdentifier
            session_id: Session ID to include in the filename
        """
        try:
            # Find the subscription that matches the displayIdentifier
            from .messages import extract_display_identifier

            matching_subscription = None
            for subscription in worker_state.get_all_subscriptions():
                # Extract display ID from subscriptionIdentifier (format: displayId;cameraId)
                sub_display_id = extract_display_identifier(subscription.subscriptionIdentifier)
                if sub_display_id == display_identifier:
                    matching_subscription = subscription
                    break

            if not matching_subscription:
                logger.error(f"[Snapshot Save] No subscription found for display {display_identifier}")
                return

            if not matching_subscription.snapshotUrl:
                logger.error(f"[Snapshot Save] No snapshotUrl found for display {display_identifier}")
                return

            # Ensure the images directory exists (relative path for Docker bind mount)
            images_dir = Path("images")
            images_dir.mkdir(exist_ok=True)

            # Generate a filename with timestamp and session ID
            timestamp = datetime.now(tz=timezone(timedelta(hours=7))).strftime("%Y%m%d_%H%M%S")
            filename = f"{session_id}_{display_identifier}_{timestamp}.jpg"
            filepath = images_dir / filename

            # Use the existing HTTPSnapshotReader to fetch the snapshot
            logger.info(f"[Snapshot Save] Fetching snapshot from {matching_subscription.snapshotUrl}")

            # Run the snapshot fetch in a thread pool to avoid blocking the async loop
            loop = asyncio.get_event_loop()
            frame = await loop.run_in_executor(None, self._fetch_snapshot_sync, matching_subscription.snapshotUrl)

            if frame is not None:
                # Save the image using OpenCV
                success = cv2.imwrite(str(filepath), frame)
                if success:
                    logger.info(f"[Snapshot Save] Successfully saved snapshot to {filepath}")
                else:
                    logger.error(f"[Snapshot Save] Failed to save image file {filepath}")
            else:
                logger.error(f"[Snapshot Save] Failed to fetch snapshot from {matching_subscription.snapshotUrl}")

        except Exception as e:
            logger.error(f"[Snapshot Save] Error saving snapshot for display {display_identifier}: {e}", exc_info=True)

    def _fetch_snapshot_sync(self, snapshot_url: str):
        """
        Synchronous snapshot fetch using the existing HTTPSnapshotReader infrastructure.

        Args:
            snapshot_url: URL to fetch the snapshot from

        Returns:
            np.ndarray or None: Fetched frame, or None on error
        """
        try:
            from ..streaming.readers import HTTPSnapshotReader

            # Create a temporary snapshot reader for a single fetch
            snapshot_reader = HTTPSnapshotReader(
                camera_id="temp_snapshot",
                snapshot_url=snapshot_url,
                interval_ms=5000  # Not used for a single fetch
            )

            # Use the existing fetch_single_snapshot method
            return snapshot_reader.fetch_single_snapshot()

        except Exception as e:
            logger.error(f"Error in sync snapshot fetch: {e}")
            return None

    async def _handle_set_session_id(self, message: SetSessionIdMessage) -> None:
        """Handle setSessionId message."""
        display_identifier = message.payload.displayIdentifier
        session_id = str(message.payload.sessionId) if message.payload.sessionId is not None else None

        logger.info(f"[RX Processing] setSessionId for display {display_identifier}: {session_id}")

        # Update worker state
        worker_state.set_session_id(display_identifier, session_id)

        # Update tracking integrations with the session ID
        shared_stream_manager.set_session_id(display_identifier, session_id)

    async def _handle_set_progression_stage(self, message: SetProgressionStageMessage) -> None:
        """Handle setProgressionStage message."""
        display_identifier = message.payload.displayIdentifier
        stage = message.payload.progressionStage

        logger.info(f"[RX Processing] setProgressionStage for display {display_identifier}: {stage}")

        # Update worker state
        worker_state.set_progression_stage(display_identifier, stage)

        # Update tracking integration for car abandonment detection
        session_id = worker_state.get_session_id(display_identifier)
        if session_id:
            shared_stream_manager.set_progression_stage(session_id, stage)

        # Save a snapshot image when the progression stage is car_fueling
        if stage == 'car_fueling' and session_id:
            await self._save_snapshot(display_identifier, session_id)

        # If the stage indicates the session is cleared/finished, clear it from tracking
        if stage in ['finished', 'cleared', 'idle']:
            # Get the session ID for this display and clear it
            if session_id:
                shared_stream_manager.clear_session_id(session_id)
                logger.info(f"[Tracking] Cleared session {session_id} due to progression stage: {stage}")

    async def _handle_request_state(self, message: RequestStateMessage) -> None:
        """Handle requestState message by sending an immediate state report."""
        logger.debug("[RX Processing] requestState - sending immediate state report")

        # Collect metrics and send the state report
        cpu_usage = SystemMetrics.get_cpu_usage()
        memory_usage = SystemMetrics.get_memory_usage()
        gpu_usage = SystemMetrics.get_gpu_usage()
        gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
        camera_connections = worker_state.get_camera_connections()

        state_report = create_state_report(
            cpu_usage=cpu_usage,
            memory_usage=memory_usage,
            gpu_usage=gpu_usage,
            gpu_memory_usage=gpu_memory_usage,
            camera_connections=camera_connections
        )

        await self._send_message(state_report)

    async def _handle_patch_session_result(self, message: PatchSessionResultMessage) -> None:
        """Handle patchSessionResult message."""
        payload = message.payload
        logger.info(f"[RX Processing] patchSessionResult for session {payload.sessionId}: "
                    f"success={payload.success}, message='{payload.message}'")

        # TODO: Handle the patch session result if needed
        # For now, just log the response

    async def _send_message(self, message) -> None:
        """Send a message to the backend via WebSocket."""
        if not self.connected:
            logger.warning("Cannot send message: WebSocket not connected")
            return

        try:
            json_message = serialize_outgoing_message(message)
            await self.websocket.send_text(json_message)
            # Log non-heartbeat messages only (heartbeats are logged in their respective functions)
            if not (hasattr(message, 'type') and message.type == 'stateReport'):
                logger.info(f"[TX → Backend] {json_message}")
        except Exception as e:
            logger.error(f"Failed to send WebSocket message: {e}")
            raise

    async def _process_streams(self) -> None:
        """
        Stream processing task that handles frame processing and detection.
        This is a placeholder for Phase 2 - currently it only logs that it is running.
        """
        logger.info("Stream processing task started")
        try:
            while self.connected:
                # Get current subscriptions
                subscriptions = worker_state.get_all_subscriptions()

                # TODO: Phase 2 - Add actual frame processing logic here.
                # This will include:
                # - Frame reading from RTSP/HTTP streams
                # - Model inference using loaded pipelines
                # - Detection result sending via WebSocket

                # Sleep to prevent excessive CPU usage (similar to the old poll_interval)
                await asyncio.sleep(0.1)  # 100ms polling interval

        except asyncio.CancelledError:
            logger.info("Stream processing task cancelled")
        except Exception as e:
            logger.error(f"Error in stream processing: {e}", exc_info=True)

    async def _cleanup(self) -> None:
        """Clean up resources when the connection closes."""
        logger.info("Cleaning up WebSocket connection")
        self.connected = False

        # Cancel background tasks
        if self._heartbeat_task and not self._heartbeat_task.done():
            self._heartbeat_task.cancel()
        if self._message_task and not self._message_task.done():
            self._message_task.cancel()

        # Clear worker state
        worker_state.set_subscriptions([])
        worker_state.session_ids.clear()
        worker_state.progression_stages.clear()

        logger.info("WebSocket connection cleanup completed")


# Factory function for FastAPI integration
async def websocket_endpoint(websocket: WebSocket) -> None:
    """
    FastAPI WebSocket endpoint handler.

    Args:
        websocket: FastAPI WebSocket connection
    """
    handler = WebSocketHandler(websocket)
    await handler.handle_connection()
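
A typical way to wire this factory into an application (a sketch; the route path is an assumption, not taken from this repository):

from fastapi import FastAPI, WebSocket

app = FastAPI()

@app.websocket("/")
async def ws_route(websocket: WebSocket) -> None:
    # Delegate the accepted connection to the factory above
    await websocket_endpoint(websocket)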

@@ -1,10 +0,0 @@
"""
Detection module for the Python Detector Worker.

This module provides the main detection pipeline orchestration and parallel branch processing
for advanced computer vision detection systems.
"""
from .pipeline import DetectionPipeline
from .branches import BranchProcessor

__all__ = ['DetectionPipeline', 'BranchProcessor']

@@ -1,936 +0,0 @@
"""
Parallel Branch Processing Module.
Handles concurrent execution of classification branches and result synchronization.
"""
import logging
import asyncio
import time
from typing import Dict, List, Optional, Any, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import cv2

from ..models.inference import YOLOWrapper

logger = logging.getLogger(__name__)


class BranchProcessor:
    """
    Handles parallel processing of classification branches.
    Manages branch synchronization and result collection.
    """

    def __init__(self, model_manager: Any, model_id: int):
        """
        Initialize the branch processor.

        Args:
            model_manager: Model manager for loading models
            model_id: The model ID to use for loading models
        """
        self.model_manager = model_manager
        self.model_id = model_id

        # Branch models cache
        self.branch_models: Dict[str, YOLOWrapper] = {}

        # Dynamic field mapping: branch_id → output_field_name (e.g., {"car_brand_cls_v3": "brand"})
        self.branch_output_fields: Dict[str, str] = {}

        # Thread pool for parallel execution
        self.executor = ThreadPoolExecutor(max_workers=4)

        # Storage managers (set during initialization)
        self.redis_manager = None
        # self.db_manager = None  # Disabled - PostgreSQL operations moved to microservices

        # Branch execution timeout (seconds)
        self.branch_timeout = 30.0

        # Statistics
        self.stats = {
            'branches_processed': 0,
            'parallel_executions': 0,
            'total_processing_time': 0.0,
            'models_loaded': 0,
            'branches_timed_out': 0,
            'branches_failed': 0
        }

        logger.info("BranchProcessor initialized")

    async def initialize(self, pipeline_config: Any, redis_manager: Any, db_manager: Any = None) -> bool:
        """
        Initialize the branch processor with a pipeline configuration.

        Args:
            pipeline_config: Pipeline configuration object
            redis_manager: Redis manager instance
            db_manager: Database manager instance (deprecated, not used)

        Returns:
            True if successful, False otherwise
        """
        try:
            self.redis_manager = redis_manager
            # self.db_manager = db_manager  # Disabled - PostgreSQL operations moved to microservices

            # Parse field mappings from parallelActions to enable dynamic field extraction
            self._parse_branch_output_fields(pipeline_config)

            # Pre-load branch models if they exist
            branches = getattr(pipeline_config, 'branches', [])
            if branches:
                await self._preload_branch_models(branches)

            logger.info(f"BranchProcessor initialized with {len(self.branch_models)} models")
            return True

        except Exception as e:
            logger.error(f"Error initializing branch processor: {e}", exc_info=True)
            return False

    async def _preload_branch_models(self, branches: List[Any]) -> None:
        """
        Pre-load all branch models for faster execution.

        Args:
            branches: List of branch configurations
        """
        for branch in branches:
            try:
                await self._load_branch_model(branch)

                # Recursively load nested branches
                nested_branches = getattr(branch, 'branches', [])
                if nested_branches:
                    await self._preload_branch_models(nested_branches)

            except Exception as e:
                logger.error(f"Error preloading branch model {getattr(branch, 'model_id', 'unknown')}: {e}")

    async def _load_branch_model(self, branch_config: Any) -> Optional[YOLOWrapper]:
        """
        Load a branch model if it is not already loaded.

        Args:
            branch_config: Branch configuration object

        Returns:
            Loaded YOLO model wrapper, or None
        """
        try:
            model_id = getattr(branch_config, 'model_id', None)
            model_file = getattr(branch_config, 'model_file', None)

            if not model_id or not model_file:
                logger.warning(f"Invalid branch config: model_id={model_id}, model_file={model_file}")
                return None

            # Check if the model is already loaded
            if model_id in self.branch_models:
                logger.debug(f"Branch model {model_id} already loaded")
                return self.branch_models[model_id]

            # Load the model using the proper model ID
            logger.info(f"Loading branch model: {model_id} ({model_file})")
            model = self.model_manager.get_yolo_model(self.model_id, model_file)

            if model:
                self.branch_models[model_id] = model
                self.stats['models_loaded'] += 1
                logger.info(f"Branch model {model_id} loaded successfully")
                return model
            else:
                logger.error(f"Failed to load branch model {model_id}")
                return None

        except Exception as e:
            logger.error(f"Error loading branch model {getattr(branch_config, 'model_id', 'unknown')}: {e}")
            return None

    def _parse_branch_output_fields(self, pipeline_config: Any) -> None:
        """
        Parse parallelActions.fields to determine which output field each branch produces.
        Creates a dynamic mapping from branch_id to output field name.

        Example:
            Input: parallelActions.fields = {"car_brand": "{car_brand_cls_v3.brand}"}
            Output: self.branch_output_fields = {"car_brand_cls_v3": "brand"}

        Args:
            pipeline_config: Pipeline configuration object
        """
        try:
            if not pipeline_config or not hasattr(pipeline_config, 'parallel_actions'):
                logger.debug("[FIELD MAPPING] No parallelActions found in pipeline config")
                return

            for action in pipeline_config.parallel_actions:
                # Skip PostgreSQL actions - they are disabled
                if action.type.value == 'postgresql_update_combined':
                    logger.debug("[FIELD MAPPING] Skipping PostgreSQL action (disabled)")
                    continue  # Skip field parsing for disabled PostgreSQL operations
                # fields = action.params.get('fields', {})
                #
                # # Parse each field template to extract branch_id and field_name
                # for db_field_name, template in fields.items():
                #     # Template format: "{branch_id.field_name}"
                #     if template.startswith('{') and template.endswith('}'):
                #         var_name = template[1:-1]  # Remove { }
                #
                #         if '.' in var_name:
                #             branch_id, field_name = var_name.split('.', 1)
                #
                #             # Store the mapping
                #             self.branch_output_fields[branch_id] = field_name
                #
                #             logger.info(f"[FIELD MAPPING] Branch '{branch_id}' → outputs field '{field_name}'")

            logger.info(f"[FIELD MAPPING] Parsed {len(self.branch_output_fields)} branch output field mappings")

        except Exception as e:
            logger.error(f"[FIELD MAPPING] Error parsing branch output fields: {e}", exc_info=True)

    async def execute_branches(self,
                               frame: np.ndarray,
                               branches: List[Any],
                               detected_regions: Dict[str, Any],
                               detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute all branches in parallel and collect results.

        Args:
            frame: Input frame
            branches: List of branch configurations
            detected_regions: Dictionary of detected regions from main detection
            detection_context: Detection context data

        Returns:
            Dictionary with branch execution results
        """
        start_time = time.time()
        branch_results = {}

        try:
            # Separate parallel and sequential branches
            parallel_branches = []
            sequential_branches = []

            for branch in branches:
                if getattr(branch, 'parallel', False):
                    parallel_branches.append(branch)
                else:
                    sequential_branches.append(branch)

            # Execute parallel branches concurrently
            if parallel_branches:
                logger.info(f"Executing {len(parallel_branches)} branches in parallel")
                parallel_results = await self._execute_parallel_branches(
                    frame, parallel_branches, detected_regions, detection_context
                )
                branch_results.update(parallel_results)
                self.stats['parallel_executions'] += 1

            # Execute sequential branches one by one
            if sequential_branches:
                logger.info(f"Executing {len(sequential_branches)} branches sequentially")
                sequential_results = await self._execute_sequential_branches(
                    frame, sequential_branches, detected_regions, detection_context
                )
                branch_results.update(sequential_results)

            # Update statistics
            self.stats['branches_processed'] += len(branches)
            processing_time = time.time() - start_time
            self.stats['total_processing_time'] += processing_time

            logger.info(f"Branch execution completed in {processing_time:.3f}s with {len(branch_results)} results")

        except Exception as e:
            logger.error(f"Error in branch execution: {e}", exc_info=True)

        return branch_results

    async def _execute_parallel_branches(self,
                                         frame: np.ndarray,
                                         branches: List[Any],
                                         detected_regions: Dict[str, Any],
                                         detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute branches in parallel using the ThreadPoolExecutor.

        Args:
            frame: Input frame
            branches: List of parallel branch configurations
            detected_regions: Dictionary of detected regions
            detection_context: Detection context data

        Returns:
            Dictionary with parallel branch results
        """
        results = {}

        # Submit all branches for parallel execution
        future_to_branch = {}

        for branch in branches:
            branch_id = getattr(branch, 'model_id', 'unknown')
            logger.info(f"[PARALLEL SUBMIT] {branch_id}: Submitting branch to thread pool")

            future = self.executor.submit(
                self._execute_single_branch_sync,
                frame, branch, detected_regions, detection_context
            )
            future_to_branch[future] = branch

        # Collect results as they complete, with a timeout
        try:
            for future in as_completed(future_to_branch, timeout=self.branch_timeout):
                branch = future_to_branch[future]
                branch_id = getattr(branch, 'model_id', 'unknown')

                try:
                    # Get the result with a timeout to prevent indefinite hanging
                    result = future.result(timeout=self.branch_timeout)
                    results[branch_id] = result
                    logger.info(f"[PARALLEL COMPLETE] {branch_id}: Branch completed successfully")
                except TimeoutError:
                    logger.error(f"[TIMEOUT] Branch {branch_id} exceeded timeout of {self.branch_timeout}s")
                    self.stats['branches_timed_out'] += 1
                    results[branch_id] = {
                        'status': 'timeout',
                        'message': f'Branch execution timeout after {self.branch_timeout}s',
                        'processing_time': self.branch_timeout
                    }
                except Exception as e:
                    logger.error(f"[ERROR] Error in parallel branch {branch_id}: {e}", exc_info=True)
                    self.stats['branches_failed'] += 1
                    results[branch_id] = {
                        'status': 'error',
                        'message': str(e),
                        'processing_time': 0.0
                    }
        except TimeoutError:
            # The as_completed iterator timed out - mark remaining futures as timed out
            logger.error(f"[TIMEOUT] Branch execution timeout after {self.branch_timeout}s - some branches did not complete")
            for future, branch in future_to_branch.items():
                branch_id = getattr(branch, 'model_id', 'unknown')
                if branch_id not in results:
                    logger.error(f"[TIMEOUT] Branch {branch_id} did not complete within timeout")
                    self.stats['branches_timed_out'] += 1
                    results[branch_id] = {
                        'status': 'timeout',
                        'message': f'Branch did not complete within {self.branch_timeout}s timeout',
                        'processing_time': self.branch_timeout
                    }

        # Flatten nested branch results to the top level for database access
        flattened_results = {}
        for branch_id, branch_result in results.items():
            # Add the branch result itself
            flattened_results[branch_id] = branch_result

            # If this branch has nested branches, add them to the top level too
            if isinstance(branch_result, dict) and 'nested_branches' in branch_result:
                nested_branches = branch_result['nested_branches']
                for nested_branch_id, nested_result in nested_branches.items():
                    flattened_results[nested_branch_id] = nested_result
                    logger.info(f"[FLATTEN] Added nested branch {nested_branch_id} to top-level results")

        # Log a summary of branch execution results
        succeeded = [bid for bid, res in results.items() if res.get('status') == 'success']
        failed = [bid for bid, res in results.items() if res.get('status') == 'error']
        timed_out = [bid for bid, res in results.items() if res.get('status') == 'timeout']
        skipped = [bid for bid, res in results.items() if res.get('status') == 'skipped']

        summary_parts = []
        if succeeded:
            summary_parts.append(f"{len(succeeded)} succeeded: {', '.join(succeeded)}")
        if failed:
            summary_parts.append(f"{len(failed)} FAILED: {', '.join(failed)}")
        if timed_out:
            summary_parts.append(f"{len(timed_out)} TIMED OUT: {', '.join(timed_out)}")
        if skipped:
            summary_parts.append(f"{len(skipped)} skipped: {', '.join(skipped)}")

        logger.info(f"[PARALLEL SUMMARY] Branch execution completed: {' | '.join(summary_parts) if summary_parts else 'no branches'}")

        return flattened_results
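        # Flattening example (hypothetical results):
        #   {'car_frontal': {'status': 'success', 'nested_branches': {'car_brand_cls_v3': {...}}}}
        # becomes
        #   {'car_frontal': {...}, 'car_brand_cls_v3': {...}}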

    async def _execute_sequential_branches(self,
                                           frame: np.ndarray,
                                           branches: List[Any],
                                           detected_regions: Dict[str, Any],
                                           detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute branches sequentially.

        Args:
            frame: Input frame
            branches: List of sequential branch configurations
            detected_regions: Dictionary of detected regions
            detection_context: Detection context data

        Returns:
            Dictionary with sequential branch results
        """
        results = {}

        for branch in branches:
            branch_id = getattr(branch, 'model_id', 'unknown')

            try:
                result = await asyncio.get_event_loop().run_in_executor(
                    self.executor,
                    self._execute_single_branch_sync,
                    frame, branch, detected_regions, detection_context
                )
                results[branch_id] = result
                logger.debug(f"Sequential branch {branch_id} completed successfully")
            except Exception as e:
                logger.error(f"Error in sequential branch {branch_id}: {e}")
                results[branch_id] = {
                    'status': 'error',
                    'message': str(e),
                    'processing_time': 0.0
                }

        # Flatten nested branch results to the top level for database access
        flattened_results = {}
        for branch_id, branch_result in results.items():
            # Add the branch result itself
            flattened_results[branch_id] = branch_result

            # If this branch has nested branches, add them to the top level too
            if isinstance(branch_result, dict) and 'nested_branches' in branch_result:
                nested_branches = branch_result['nested_branches']
                for nested_branch_id, nested_result in nested_branches.items():
                    flattened_results[nested_branch_id] = nested_result
                    logger.info(f"[FLATTEN] Added nested branch {nested_branch_id} to top-level results")

        return flattened_results

    def _execute_single_branch_sync(self,
                                    frame: np.ndarray,
                                    branch_config: Any,
                                    detected_regions: Dict[str, Any],
                                    detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Synchronous execution of a single branch (for the ThreadPoolExecutor).

        Args:
            frame: Input frame
            branch_config: Branch configuration object
            detected_regions: Dictionary of detected regions
            detection_context: Detection context data

        Returns:
            Dictionary with the branch execution result
        """
        start_time = time.time()
        branch_id = getattr(branch_config, 'model_id', 'unknown')

        logger.info(f"[BRANCH START] {branch_id}: Starting branch execution")
        logger.debug(f"[BRANCH CONFIG] {branch_id}: crop={getattr(branch_config, 'crop', False)}, "
                     f"trigger_classes={getattr(branch_config, 'trigger_classes', [])}, "
                     f"min_confidence={getattr(branch_config, 'min_confidence', 0.6)}")

        # Check whether the branch should execute based on triggerClasses (execution conditions)
        trigger_classes = getattr(branch_config, 'trigger_classes', [])
        logger.info(f"[DETECTED REGIONS] {branch_id}: Available parent detections: {list(detected_regions.keys())}")
        for region_name, region_data in detected_regions.items():
            # Handle both list (new) and single dict (backward compat)
            if isinstance(region_data, list):
                for i, region in enumerate(region_data):
                    logger.debug(f"[REGION DATA] {branch_id}: '{region_name}[{i}]' -> bbox={region.get('bbox')}, conf={region.get('confidence')}")
            else:
                logger.debug(f"[REGION DATA] {branch_id}: '{region_name}' -> bbox={region_data.get('bbox')}, conf={region_data.get('confidence')}")

        if trigger_classes:
            # Check if any parent detection matches our trigger classes (case-insensitive)
            should_execute = False
            for trigger_class in trigger_classes:
                # Case-insensitive comparison for robustness
                if trigger_class.lower() in [k.lower() for k in detected_regions.keys()]:
                    should_execute = True
                    logger.info(f"[TRIGGER CHECK] {branch_id}: Found '{trigger_class}' in parent detections - branch will execute")
                    break

            if not should_execute:
                logger.warning(f"[TRIGGER CHECK] {branch_id}: None of trigger classes {trigger_classes} found in parent detections {list(detected_regions.keys())} - skipping branch")
                return {
                    'status': 'skipped',
                    'branch_id': branch_id,
                    'message': f'No trigger classes {trigger_classes} found in parent detections',
                    'processing_time': time.time() - start_time
                }

        result = {
            'status': 'success',
            'branch_id': branch_id,
            'result': {},
            'processing_time': 0.0,
            'timestamp': time.time()
        }

        try:
            # Get or load the branch model
            if branch_id not in self.branch_models:
                logger.warning(f"Branch model {branch_id} not preloaded, loading now...")
                # This should be rare since models are preloaded
                return {
                    'status': 'error',
                    'message': f'Branch model {branch_id} not available',
                    'processing_time': time.time() - start_time
                }

            model = self.branch_models[branch_id]

            # Get configuration values first
            min_confidence = getattr(branch_config, 'min_confidence', 0.6)

            # Prepare the input frame for this branch
            input_frame = frame

            # Handle cropping if required - pick the biggest bbox among the parent
            # detections (these already passed the confidence threshold at inference time)
            if getattr(branch_config, 'crop', False):
                crop_classes = getattr(branch_config, 'crop_class', [])
                if isinstance(crop_classes, str):
                    crop_classes = [crop_classes]

                best_region = None
                best_class = None
                best_area = 0.0

                for crop_class in crop_classes:
                    if crop_class in detected_regions:
                        regions = detected_regions[crop_class]

                        # Handle both list (new) and single dict (backward compat)
                        if not isinstance(regions, list):
                            regions = [regions]

                        # Find the largest bbox from all detections of this class
                        for region in regions:
                            confidence = region.get('confidence', 0.0)
                            bbox = region['bbox']
                            area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])  # width * height

                            # Choose the biggest bbox among all available detections
                            if area > best_area:
                                best_region = region
                                best_class = crop_class
                                best_area = area
                                logger.debug(f"[CROP] Selected larger bbox for '{crop_class}': area={area:.0f}px², conf={confidence:.3f}")

                if best_region:
                    bbox = best_region['bbox']
                    x1, y1, x2, y2 = [int(coord) for coord in bbox]
                    cropped = frame[y1:y2, x1:x2]
                    if cropped.size > 0:
                        input_frame = cropped
                        confidence = best_region.get('confidence', 0.0)
                        logger.info(f"[CROP SUCCESS] {branch_id}: cropped '{best_class}' region (conf={confidence:.3f}, area={int(best_area)}) -> shape={cropped.shape}")
                    else:
                        logger.warning(f"Branch {branch_id}: empty crop, using full frame")
                else:
                    logger.warning(f"Branch {branch_id}: no valid crop regions found (min_conf={min_confidence})")

            logger.info(f"[INFERENCE START] {branch_id}: Running inference on {'cropped' if input_frame is not frame else 'full'} frame "
                        f"({input_frame.shape[1]}x{input_frame.shape[0]}) with confidence={min_confidence}")

            # Use the .predict() method for both detection and classification models
            inference_start = time.time()
            try:
                detection_results = model.model.predict(input_frame, conf=min_confidence, verbose=False)
                inference_time = time.time() - inference_start
                logger.info(f"[INFERENCE DONE] {branch_id}: Predict completed in {inference_time:.3f}s using .predict() method")
            except Exception as inference_error:
                inference_time = time.time() - inference_start
                logger.error(f"[INFERENCE ERROR] {branch_id}: Model inference failed after {inference_time:.3f}s: {inference_error}", exc_info=True)
                return {
                    'status': 'error',
                    'branch_id': branch_id,
                    'message': f'Model inference failed: {str(inference_error)}',
                    'processing_time': time.time() - start_time
                }

            # Initialize branch_detections outside the conditional
            branch_detections = []

            # Process results using clean, unified logic
            if detection_results and len(detection_results) > 0:
                result_obj = detection_results[0]

                # Handle detection models (have a .boxes attribute)
                if hasattr(result_obj, 'boxes') and result_obj.boxes is not None:
                    logger.info(f"[RAW DETECTIONS] {branch_id}: Found {len(result_obj.boxes)} raw detections")

                    for i, box in enumerate(result_obj.boxes):
                        class_id = int(box.cls[0])
                        confidence = float(box.conf[0])
                        bbox = box.xyxy[0].cpu().numpy().tolist()  # [x1, y1, x2, y2]
                        class_name = model.model.names[class_id]

                        logger.debug(f"[RAW DETECTION {i+1}] {branch_id}: '{class_name}', conf={confidence:.3f}")

                        # All detections are included - no filtering by trigger_classes here
                        branch_detections.append({
                            'class_name': class_name,
                            'confidence': confidence,
                            'bbox': bbox
                        })

                # Handle classification models (have a .probs attribute)
                elif hasattr(result_obj, 'probs') and result_obj.probs is not None:
                    logger.info(f"[RAW CLASSIFICATION] {branch_id}: Processing classification results")

                    probs = result_obj.probs
                    top_indices = probs.top5  # Get top 5 predictions
                    top_conf = probs.top5conf.cpu().numpy()

                    # For classification: take only the TOP-1 prediction (not all top-5).
                    # This prevents empty results when all top-5 predictions are below threshold.
                    if len(top_indices) > 0 and len(top_conf) > 0:
                        top_idx = top_indices[0]
                        top_confidence = float(top_conf[0])

                        # Apply the minConfidence threshold to top-1 only
                        if top_confidence >= min_confidence:
                            class_name = model.model.names[int(top_idx)]
                            logger.info(f"[CLASSIFICATION TOP-1] {branch_id}: '{class_name}', conf={top_confidence:.3f}")

                            # For classification, use the full input frame dimensions as the bbox
                            branch_detections.append({
                                'class_name': class_name,
                                'confidence': top_confidence,
                                'bbox': [0, 0, input_frame.shape[1], input_frame.shape[0]]
                            })
                        else:
                            logger.warning(f"[CLASSIFICATION FILTERED] {branch_id}: Top prediction conf={top_confidence:.3f} < threshold={min_confidence}")
                else:
                    logger.warning(f"[UNKNOWN MODEL] {branch_id}: Model results have no .boxes or .probs")

            result['result'] = {
                'detections': branch_detections,
                'detection_count': len(branch_detections)
            }

            logger.info(f"[FINAL RESULTS] {branch_id}: {len(branch_detections)} detections processed")

            # Determine the output field name from the dynamic mapping (parsed from parallelActions.fields)
            output_field = self.branch_output_fields.get(branch_id)

            # Always initialize the field (even if None) to ensure it exists for the database update
            if output_field:
                result['result'][output_field] = None
                logger.debug(f"[FIELD INIT] {branch_id}: Initialized field '{output_field}' = None")

            # Extract the best detection if available
            if branch_detections:
                best_detection = max(branch_detections, key=lambda x: x['confidence'])
                logger.info(f"[BEST DETECTION] {branch_id}: '{best_detection['class_name']}' with confidence {best_detection['confidence']:.3f}")

                # Set the output field value using the dynamic mapping
                if output_field:
                    result['result'][output_field] = best_detection['class_name']
                    logger.info(f"[FIELD SET] {branch_id}: Set field '{output_field}' = '{best_detection['class_name']}'")
                else:
                    logger.warning(f"[NO MAPPING] {branch_id}: No output field defined in parallelActions.fields")
            else:
                logger.warning(f"[NO RESULTS] {branch_id}: No detections found, field '{output_field}' remains None")

            # Execute branch actions if this branch found valid detections
            actions_executed = []
            branch_actions = getattr(branch_config, 'actions', [])
            if branch_actions and branch_detections:
                logger.info(f"[BRANCH ACTIONS] {branch_id}: Executing {len(branch_actions)} actions")

                # Create detected_regions from THIS branch's detections for actions
                branch_detected_regions = {}
                for detection in branch_detections:
                    branch_detected_regions[detection['class_name']] = {
                        'bbox': detection['bbox'],
                        'confidence': detection['confidence']
                    }

                for action in branch_actions:
                    try:
                        action_type = action.type.value  # Access the enum value
                        logger.info(f"[ACTION EXECUTE] {branch_id}: Executing action '{action_type}'")

                        if action_type == 'redis_save_image':
                            action_result = self._execute_redis_save_image_sync(
                                action, input_frame, branch_detected_regions, detection_context
                            )
                        elif action_type == 'redis_publish':
                            action_result = self._execute_redis_publish_sync(
                                action, detection_context
                            )
                        else:
                            logger.warning(f"[ACTION UNKNOWN] {branch_id}: Unknown action type '{action_type}'")
                            action_result = {'status': 'error', 'message': f'Unknown action type: {action_type}'}

                        actions_executed.append({
                            'action_type': action_type,
                            'result': action_result
                        })

                        logger.info(f"[ACTION COMPLETE] {branch_id}: Action '{action_type}' result: {action_result.get('status')}")

                    except Exception as e:
                        action_type = getattr(action, 'type', None)
                        if action_type:
                            action_type = action_type.value if hasattr(action_type, 'value') else str(action_type)
                        logger.error(f"[ACTION ERROR] {branch_id}: Error executing action '{action_type}': {e}", exc_info=True)
                        actions_executed.append({
                            'action_type': action_type,
                            'result': {'status': 'error', 'message': str(e)}
                        })

            # Add the executed actions to the result
            if actions_executed:
                result['actions_executed'] = actions_executed

            # Handle nested branches ONLY if the parent found valid detections
            nested_branches = getattr(branch_config, 'branches', [])
            if nested_branches:
                if not branch_detections:
                    logger.warning(f"[BRANCH SKIP] {branch_id}: Skipping {len(nested_branches)} nested branches - parent found no valid detections")
                else:
                    logger.debug(f"Branch {branch_id}: executing {len(nested_branches)} nested branches")

                    # Create detected_regions from THIS branch's detections for nested branches.
                    # Nested branches should see their immediate parent's detections, not the root pipeline's.
                    nested_detected_regions = {}
                    for detection in branch_detections:
                        nested_detected_regions[detection['class_name']] = {
                            'bbox': detection['bbox'],
                            'confidence': detection['confidence']
                        }

                    logger.info(f"[NESTED REGIONS] {branch_id}: Passing {list(nested_detected_regions.keys())} to nested branches")

                    # Note: for simplicity, nested branches are executed sequentially in this sync method.
                    # In a full async implementation, these could also be parallelized.
                    nested_results = {}
                    for nested_branch in nested_branches:
                        nested_result = self._execute_single_branch_sync(
                            input_frame, nested_branch, nested_detected_regions, detection_context
                        )
                        nested_branch_id = getattr(nested_branch, 'model_id', 'unknown')
                        nested_results[nested_branch_id] = nested_result

                    result['nested_branches'] = nested_results

        except Exception as e:
            logger.error(f"[BRANCH ERROR] {branch_id}: Error in execution: {e}", exc_info=True)
            result['status'] = 'error'
            result['message'] = str(e)

        result['processing_time'] = time.time() - start_time

        # Summary log
        logger.info(f"[BRANCH COMPLETE] {branch_id}: status={result['status']}, "
                    f"processing_time={result['processing_time']:.3f}s, "
                    f"result_keys={list(result['result'].keys()) if result['result'] else 'none'}")

        return result
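        # Trigger-check example (hypothetical): trigger_classes=['Car'] matches a
        # parent detection keyed 'car', because the comparison above is case-insensitive.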

    def _execute_redis_save_image_sync(self,
                                       action: Dict,
                                       frame: np.ndarray,
                                       detected_regions: Dict[str, Any],
                                       context: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the redis_save_image action synchronously."""
        if not self.redis_manager:
            return {'status': 'error', 'message': 'Redis not available'}

        try:
            # Get the image to save (cropped or full frame)
            image_to_save = frame
            region_name = action.params.get('region')

            bbox = None
            if region_name and region_name in detected_regions:
                # Crop the specified region.
                # Handle both list (new) and single dict (backward compat).
                regions = detected_regions[region_name]
                if isinstance(regions, list):
                    # Multiple detections - select the largest bbox
                    if regions:
                        best_region = max(regions, key=lambda r: (r['bbox'][2] - r['bbox'][0]) * (r['bbox'][3] - r['bbox'][1]))
                        bbox = best_region['bbox']
                else:
                    bbox = regions['bbox']
            elif region_name and region_name.lower() == 'frontal' and 'front_rear' in detected_regions:
                # Special case: the "frontal" region maps to the "front_rear" detection.
                # Handle both list (new) and single dict (backward compat).
                regions = detected_regions['front_rear']
                if isinstance(regions, list):
                    # Multiple detections - select the largest bbox
                    if regions:
                        best_region = max(regions, key=lambda r: (r['bbox'][2] - r['bbox'][0]) * (r['bbox'][3] - r['bbox'][1]))
                        bbox = best_region['bbox']
                else:
                    bbox = regions['bbox']

            if bbox is not None:
                x1, y1, x2, y2 = [int(coord) for coord in bbox]
                cropped = frame[y1:y2, x1:x2]
                if cropped.size > 0:
                    image_to_save = cropped
                    logger.debug(f"Cropped region '{region_name}' for redis_save_image")
                else:
                    logger.warning(f"Empty crop for region '{region_name}', using full frame")

            # Format the key with context
            key = action.params['key'].format(**context)

            # Encoding parameters (the image is encoded exactly once, below)
            image_format = action.params.get('format', 'jpeg')
            quality = action.params.get('quality', 90)

            # Save to Redis synchronously using a sync Redis client
            try:
                import redis

                # Create a synchronous Redis client with the same connection details
                sync_redis = redis.Redis(
                    host=self.redis_manager.host,
                    port=self.redis_manager.port,
                    password=self.redis_manager.password,
                    db=self.redis_manager.db,
                    decode_responses=False,  # We're storing binary data
                    socket_timeout=self.redis_manager.socket_timeout,
                    socket_connect_timeout=self.redis_manager.socket_connect_timeout
                )

                # Encode the image
                if image_format.lower() == 'jpeg':
                    encode_param = [cv2.IMWRITE_JPEG_QUALITY, quality]
                    success, encoded_image = cv2.imencode('.jpg', image_to_save, encode_param)
                else:
                    success, encoded_image = cv2.imencode('.png', image_to_save)

                if not success:
                    return {'status': 'error', 'message': 'Failed to encode image'}

                # Save to Redis with an expiration
                expire_seconds = action.params.get('expire_seconds', 600)
                result = sync_redis.setex(key, expire_seconds, encoded_image.tobytes())

                sync_redis.close()  # Clean up the connection

                if result:
                    # Add image_key to the context for subsequent actions
                    context['image_key'] = key
                    return {'status': 'success', 'key': key}
                else:
                    return {'status': 'error', 'message': 'Failed to save image to Redis'}

            except Exception as redis_error:
                logger.error(f"Error calling Redis from sync context: {redis_error}")
                return {'status': 'error', 'message': f'Redis operation failed: {redis_error}'}

        except Exception as e:
            logger.error(f"Error in redis_save_image action: {e}", exc_info=True)
            return {'status': 'error', 'message': str(e)}

    def _execute_redis_publish_sync(self, action: Dict, context: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the redis_publish action synchronously."""
        if not self.redis_manager:
            return {'status': 'error', 'message': 'Redis not available'}

        try:
            channel = action.params['channel']
            message_template = action.params['message']

            # Debug the message template
            logger.debug(f"Message template: {repr(message_template)}")
            logger.debug(f"Context keys: {list(context.keys())}")

            # Format the message with context - handle JSON string formatting carefully.
            # The message template contains JSON, which breaks str.format(), so use
            # plain string replacement to avoid JSON brace conflicts.
            try:
                # Ensure image_key is available for message formatting
                if 'image_key' not in context:
                    context['image_key'] = ''  # Default empty value if redis_save_image failed

                # Use string replacement to avoid JSON formatting issues
                message = message_template
                for key, value in context.items():
                    placeholder = '{' + key + '}'
                    message = message.replace(placeholder, str(value))

                logger.debug(f"Formatted message using replacement: {message}")
            except Exception as e:
                logger.error(f"Message formatting failed: {e}")
                logger.error(f"Template: {repr(message_template)}")
                logger.error(f"Context: {context}")
                return {'status': 'error', 'message': f'Message formatting failed: {e}'}

            # Publish the message synchronously using a sync Redis client
            try:
                import redis

                # Create a synchronous Redis client with the same connection details
                sync_redis = redis.Redis(
                    host=self.redis_manager.host,
                    port=self.redis_manager.port,
                    password=self.redis_manager.password,
                    db=self.redis_manager.db,
                    decode_responses=True,  # For publishing text messages
                    socket_timeout=self.redis_manager.socket_timeout,
                    socket_connect_timeout=self.redis_manager.socket_connect_timeout
                )

                # Publish the message
                result = sync_redis.publish(channel, message)
                sync_redis.close()  # Clean up the connection

                if result >= 0:  # Redis publish returns the number of subscribers
                    return {'status': 'success', 'subscribers': result, 'channel': channel}
                else:
                    return {'status': 'error', 'message': 'Failed to publish message to Redis'}

            except Exception as redis_error:
                logger.error(f"Error calling Redis from sync context: {redis_error}")
                return {'status': 'error', 'message': f'Redis operation failed: {redis_error}'}

        except Exception as e:
            logger.error(f"Error in redis_publish action: {e}", exc_info=True)
            return {'status': 'error', 'message': str(e)}
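        # Why str.format() is avoided above (illustrative):
        #   '{"event": "detected", "key": "{image_key}"}'.format(image_key='k1')
        # raises KeyError('"event"') because the JSON braces are parsed as format
        # fields; plain placeholder replacement sidesteps that entirely.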

    def get_statistics(self) -> Dict[str, Any]:
        """Get branch processor statistics."""
        return {
            **self.stats,
            'loaded_models': list(self.branch_models.keys()),
            'model_count': len(self.branch_models)
        }

    def cleanup(self):
        """Cleanup resources."""
        if self.executor:
            self.executor.shutdown(wait=False)

        # Clear the model cache
        self.branch_models.clear()

        logger.info("BranchProcessor cleaned up")

File diff suppressed because it is too large
@@ -1,42 +0,0 @@
"""
Models Module - MPTA management, pipeline configuration, and YOLO inference
"""

from .manager import ModelManager
from .pipeline import (
    PipelineParser,
    PipelineConfig,
    TrackingConfig,
    ModelBranch,
    Action,
    ActionType,
    RedisConfig,
    # PostgreSQLConfig  # Disabled - moved to microservices
)
from .inference import (
    YOLOWrapper,
    ModelInferenceManager,
    Detection,
    InferenceResult
)

__all__ = [
    # Manager
    'ModelManager',

    # Pipeline
    'PipelineParser',
    'PipelineConfig',
    'TrackingConfig',
    'ModelBranch',
    'Action',
    'ActionType',
    'RedisConfig',
    # 'PostgreSQLConfig',  # Disabled - moved to microservices

    # Inference
    'YOLOWrapper',
    'ModelInferenceManager',
    'Detection',
    'InferenceResult',
]

@@ -1,447 +0,0 @@
"""
YOLO Model Inference Wrapper - Handles model loading and inference optimization
"""

import logging
import torch
import numpy as np
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple, Union
from threading import Lock
from dataclasses import dataclass
import cv2

logger = logging.getLogger(__name__)


@dataclass
class Detection:
    """Represents a single detection result"""
    bbox: List[float]  # [x1, y1, x2, y2]
    confidence: float
    class_id: int
    class_name: str
    track_id: Optional[int] = None


@dataclass
class InferenceResult:
    """Result from model inference"""
    detections: List[Detection]
    image_shape: Tuple[int, int]  # (height, width)
    inference_time: float
    model_id: str


class YOLOWrapper:
    """Wrapper for YOLO models with caching and optimization"""

    # Class-level model cache shared across all instances
    _model_cache: Dict[str, Any] = {}
    _cache_lock = Lock()

    def __init__(self, model_path: Path, model_id: str, device: Optional[str] = None):
        """
        Initialize the YOLO wrapper.

        Args:
            model_path: Path to the .pt model file
            model_id: Unique identifier for the model
            device: Device to run inference on ('cuda', 'cpu', or None for auto)
        """
        self.model_path = model_path
        self.model_id = model_id

        # Auto-detect the device if not specified
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        self.model = None
        self._class_names = {}  # dict of class_id -> name (matches the .get() usage below)

        self._load_model()

        logger.info(f"Initialized YOLO wrapper for {model_id} on {self.device}")

    def _load_model(self) -> None:
        """Load the YOLO model, with caching"""
        cache_key = str(self.model_path)

        with self._cache_lock:
            # Check if the model is already cached
            if cache_key in self._model_cache:
                logger.info(f"Loading model {self.model_id} from cache")
                self.model = self._model_cache[cache_key]
                self._extract_class_names()
                return

            # Load the model
            try:
                from ultralytics import YOLO

                logger.info(f"Loading YOLO model from {self.model_path}")
                self.model = YOLO(str(self.model_path))

                # Move the model to the device
                if self.device == 'cuda' and torch.cuda.is_available():
                    self.model.to('cuda')
                    logger.info(f"Model {self.model_id} moved to GPU")

                # Cache the model
                self._model_cache[cache_key] = self.model
                self._extract_class_names()

                logger.info(f"Successfully loaded model {self.model_id}")

            except ImportError:
                logger.error("Ultralytics YOLO not installed. Install with: pip install ultralytics")
                raise
            except Exception as e:
                logger.error(f"Failed to load YOLO model {self.model_id}: {str(e)}", exc_info=True)
                raise
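    # Cache behaviour sketch (hypothetical paths): two wrappers built from the
    # same .pt file share one underlying model object.
    #   a = YOLOWrapper(Path('models/det.pt'), 'det_a')
    #   b = YOLOWrapper(Path('models/det.pt'), 'det_b')
    #   assert a.model is b.model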

    def _extract_class_names(self) -> None:
        """Extract class names from the model"""
        try:
            if hasattr(self.model, 'names'):
                self._class_names = self.model.names
            elif hasattr(self.model, 'model') and hasattr(self.model.model, 'names'):
                self._class_names = self.model.model.names
            else:
                logger.warning(f"Could not extract class names from model {self.model_id}")
                self._class_names = {}
        except Exception as e:
            logger.error(f"Failed to extract class names: {str(e)}")
            self._class_names = {}

    def infer(
        self,
        image: np.ndarray,
        confidence_threshold: float = 0.5,
        trigger_classes: Optional[List[str]] = None,
        iou_threshold: float = 0.45
    ) -> InferenceResult:
        """
        Run inference on an image.

        Args:
            image: Input image as a numpy array (BGR format)
            confidence_threshold: Minimum confidence for detections
            trigger_classes: List of class names to filter (None = all classes)
            iou_threshold: IoU threshold for NMS

        Returns:
            InferenceResult containing the detections
        """
        if self.model is None:
            raise RuntimeError(f"Model {self.model_id} not loaded")

        try:
            import time
            start_time = time.time()

            # Run inference
            results = self.model(
                image,
                conf=confidence_threshold,
                iou=iou_threshold,
                verbose=False
            )

            inference_time = time.time() - start_time

            # Parse the results
            detections = self._parse_results(results[0], trigger_classes)

            return InferenceResult(
                detections=detections,
                image_shape=(image.shape[0], image.shape[1]),
                inference_time=inference_time,
                model_id=self.model_id
            )

        except Exception as e:
            logger.error(f"Inference failed for model {self.model_id}: {str(e)}", exc_info=True)
            raise

    def _parse_results(
        self,
        result: Any,
        trigger_classes: Optional[List[str]] = None
    ) -> List[Detection]:
        """
        Parse YOLO results into Detection objects.

        Args:
            result: YOLO result object
            trigger_classes: Optional list of class names to filter

        Returns:
            List of Detection objects
        """
        detections = []

        try:
            if result.boxes is None:
                return detections

            boxes = result.boxes
            for i in range(len(boxes)):
                # Get box coordinates
                box = boxes.xyxy[i].cpu().numpy()
                x1, y1, x2, y2 = box

                # Get confidence and class
                conf = float(boxes.conf[i])
                cls_id = int(boxes.cls[i])

                # Get the class name
                class_name = self._class_names.get(cls_id, f"class_{cls_id}")

                # Filter by trigger classes if specified
                if trigger_classes and class_name not in trigger_classes:
                    continue

                # Get the track ID if available
                track_id = None
                if hasattr(boxes, 'id') and boxes.id is not None:
                    track_id = int(boxes.id[i])

                detection = Detection(
                    bbox=[float(x1), float(y1), float(x2), float(y2)],
                    confidence=conf,
                    class_id=cls_id,
                    class_name=class_name,
                    track_id=track_id
                )
                detections.append(detection)

        except Exception as e:
            logger.error(f"Failed to parse results: {str(e)}", exc_info=True)

        return detections
|
||||
|
||||
|
||||
def track(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
confidence_threshold: float = 0.5,
|
||||
trigger_classes: Optional[List[str]] = None,
|
||||
persist: bool = True,
|
||||
camera_id: Optional[str] = None
|
||||
) -> InferenceResult:
|
||||
"""
|
||||
Run detection (tracking will be handled by external tracker)
|
||||
|
||||
Args:
|
||||
image: Input image as numpy array (BGR format)
|
||||
confidence_threshold: Minimum confidence for detections
|
||||
trigger_classes: List of class names to filter
|
||||
persist: Ignored - tracking handled externally
|
||||
camera_id: Ignored - tracking handled externally
|
||||
|
||||
Returns:
|
||||
InferenceResult containing detections (no track IDs from YOLO)
|
||||
"""
|
||||
# Just do detection - no YOLO tracking
|
||||
return self.infer(image, confidence_threshold, trigger_classes)
|
||||
|
||||
def predict_classification(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
top_k: int = 1
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Run classification on an image
|
||||
|
||||
Args:
|
||||
image: Input image as numpy array (BGR format)
|
||||
top_k: Number of top predictions to return
|
||||
|
||||
Returns:
|
||||
Dictionary of class_name -> confidence scores
|
||||
"""
|
||||
if self.model is None:
|
||||
raise RuntimeError(f"Model {self.model_id} not loaded")
|
||||
|
||||
try:
|
||||
# Run inference
|
||||
results = self.model(image, verbose=False)
|
||||
|
||||
# For classification models, extract probabilities
|
||||
if hasattr(results[0], 'probs'):
|
||||
probs = results[0].probs
|
||||
top_indices = probs.top5[:top_k]
|
||||
top_conf = probs.top5conf[:top_k].cpu().numpy()
|
||||
|
||||
predictions = {}
|
||||
for idx, conf in zip(top_indices, top_conf):
|
||||
class_name = self._class_names.get(int(idx), f"class_{idx}")
|
||||
predictions[class_name] = float(conf)
|
||||
|
||||
return predictions
|
||||
else:
|
||||
logger.warning(f"Model {self.model_id} does not support classification")
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Classification failed for model {self.model_id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def crop_detection(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
detection: Detection,
|
||||
padding: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Crop image to detection bounding box
|
||||
|
||||
Args:
|
||||
image: Original image
|
||||
detection: Detection to crop
|
||||
padding: Additional padding around the box
|
||||
|
||||
Returns:
|
||||
Cropped image region
|
||||
"""
|
||||
h, w = image.shape[:2]
|
||||
x1, y1, x2, y2 = detection.bbox
|
||||
|
||||
# Add padding and clip to image boundaries
|
||||
x1 = max(0, int(x1) - padding)
|
||||
y1 = max(0, int(y1) - padding)
|
||||
x2 = min(w, int(x2) + padding)
|
||||
y2 = min(h, int(y2) + padding)
|
||||
|
||||
return image[y1:y2, x1:x2]
|
||||
|
||||
def get_class_names(self) -> Dict[int, str]:
|
||||
"""Get the class names dictionary"""
|
||||
return self._class_names.copy()
|
||||
|
||||
def get_num_classes(self) -> int:
|
||||
"""Get the number of classes the model can detect"""
|
||||
return len(self._class_names)
|
||||
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the model cache"""
|
||||
with self._cache_lock:
|
||||
cache_key = str(self.model_path)
|
||||
if cache_key in self._model_cache:
|
||||
del self._model_cache[cache_key]
|
||||
logger.info(f"Cleared cache for model {self.model_id}")
|
||||
|
||||
@classmethod
|
||||
def clear_all_cache(cls) -> None:
|
||||
"""Clear all cached models"""
|
||||
with cls._cache_lock:
|
||||
cls._model_cache.clear()
|
||||
logger.info("Cleared all model cache")
|
||||
|
||||
def warmup(self, image_size: Tuple[int, int] = (640, 640)) -> None:
|
||||
"""
|
||||
Warmup the model with a dummy inference
|
||||
|
||||
Args:
|
||||
image_size: Size of dummy image (height, width)
|
||||
"""
|
||||
try:
|
||||
dummy_image = np.zeros((image_size[0], image_size[1], 3), dtype=np.uint8)
|
||||
self.infer(dummy_image, confidence_threshold=0.5)
|
||||
logger.info(f"Model {self.model_id} warmed up")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to warmup model {self.model_id}: {str(e)}")
|
||||
|
||||
|
||||
class ModelInferenceManager:
|
||||
"""Manages multiple YOLO models for a pipeline"""
|
||||
|
||||
def __init__(self, model_dir: Path):
|
||||
"""
|
||||
Initialize the inference manager
|
||||
|
||||
Args:
|
||||
model_dir: Directory containing model files
|
||||
"""
|
||||
self.model_dir = model_dir
|
||||
self.models: Dict[str, YOLOWrapper] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
logger.info(f"Initialized ModelInferenceManager with model directory: {model_dir}")
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
model_id: str,
|
||||
model_file: str,
|
||||
device: Optional[str] = None
|
||||
) -> YOLOWrapper:
|
||||
"""
|
||||
Load a model for inference
|
||||
|
||||
Args:
|
||||
model_id: Unique identifier for the model
|
||||
model_file: Filename of the model
|
||||
device: Device to run on
|
||||
|
||||
Returns:
|
||||
YOLOWrapper instance
|
||||
"""
|
||||
with self._lock:
|
||||
# Check if already loaded
|
||||
if model_id in self.models:
|
||||
logger.debug(f"Model {model_id} already loaded")
|
||||
return self.models[model_id]
|
||||
|
||||
# Load the model
|
||||
model_path = self.model_dir / model_file
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
|
||||
wrapper = YOLOWrapper(model_path, model_id, device)
|
||||
self.models[model_id] = wrapper
|
||||
|
||||
return wrapper
|
||||
|
||||
def get_model(self, model_id: str) -> Optional[YOLOWrapper]:
|
||||
"""
|
||||
Get a loaded model
|
||||
|
||||
Args:
|
||||
model_id: Model identifier
|
||||
|
||||
Returns:
|
||||
YOLOWrapper instance or None if not loaded
|
||||
"""
|
||||
return self.models.get(model_id)
|
||||
|
||||
def unload_model(self, model_id: str) -> bool:
|
||||
"""
|
||||
Unload a model to free memory
|
||||
|
||||
Args:
|
||||
model_id: Model identifier
|
||||
|
||||
Returns:
|
||||
True if unloaded, False if not found
|
||||
"""
|
||||
with self._lock:
|
||||
if model_id in self.models:
|
||||
self.models[model_id].clear_cache()
|
||||
del self.models[model_id]
|
||||
logger.info(f"Unloaded model {model_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def unload_all(self) -> None:
|
||||
"""Unload all models"""
|
||||
with self._lock:
|
||||
for model_id in list(self.models.keys()):
|
||||
self.models[model_id].clear_cache()
|
||||
self.models.clear()
|
||||
logger.info("Unloaded all models")
|
||||
|
|
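
# --- Hedged usage sketch (illustrative, not part of the repository) ---
# A minimal example of how ModelInferenceManager and YOLOWrapper above might
# be driven end to end. The directory "models/", the file "detector.pt", and
# the image path are assumed names for illustration only.

from pathlib import Path

import cv2  # assumed available; the wrappers take BGR numpy arrays


def example_inference_run() -> None:
    manager = ModelInferenceManager(Path("models"))
    detector = manager.load_model(model_id="front_cam_detector",
                                  model_file="detector.pt")  # hypothetical file
    detector.warmup()  # one dummy pass so the first real frame is not slow

    frame = cv2.imread("sample.jpg")  # any BGR image; path is illustrative
    result = detector.infer(frame, confidence_threshold=0.6,
                            trigger_classes=["car", "truck"])
    for det in result.detections:
        # Each Detection carries bbox, class_name and confidence, so callers
        # can crop regions and feed them to downstream classification models.
        crop = detector.crop_detection(frame, det, padding=8)
        print(det.class_name, det.confidence, crop.shape)

    manager.unload_all()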
@@ -1,439 +0,0 @@
"""
Model Manager Module - Handles MPTA download, extraction, and model loading
"""

import os
import logging
import zipfile
import json
import hashlib
import requests
from pathlib import Path
from typing import Dict, Optional, Any, Set
from threading import Lock
from urllib.parse import urlparse, parse_qs

logger = logging.getLogger(__name__)


class ModelManager:
    """Manages MPTA model downloads, extraction, and caching"""

    def __init__(self, models_dir: str = "models"):
        """
        Initialize the Model Manager

        Args:
            models_dir: Base directory for storing models
        """
        self.models_dir = Path(models_dir)
        self.models_dir.mkdir(parents=True, exist_ok=True)

        # Track downloaded models to avoid duplicates
        self._downloaded_models: Set[int] = set()
        self._model_paths: Dict[int, Path] = {}
        self._download_lock = Lock()

        # Scan existing models
        self._scan_existing_models()

        logger.info(f"ModelManager initialized with models directory: {self.models_dir}")
        logger.info(f"Found existing models: {list(self._downloaded_models)}")

    def _scan_existing_models(self) -> None:
        """Scan the models directory for existing downloaded models"""
        if not self.models_dir.exists():
            return

        for model_dir in self.models_dir.iterdir():
            if model_dir.is_dir() and model_dir.name.isdigit():
                model_id = int(model_dir.name)
                # Check if extraction was successful by looking for pipeline.json
                extracted_dirs = list(model_dir.glob("*/pipeline.json"))
                if extracted_dirs:
                    self._downloaded_models.add(model_id)
                    # Store path to the extracted model directory
                    self._model_paths[model_id] = extracted_dirs[0].parent
                    logger.debug(f"Found existing model {model_id} at {extracted_dirs[0].parent}")

    def get_model_path(self, model_id: int) -> Optional[Path]:
        """
        Get the path to an extracted model directory

        Args:
            model_id: The model ID

        Returns:
            Path to the extracted model directory or None if not found
        """
        return self._model_paths.get(model_id)

    def is_model_downloaded(self, model_id: int) -> bool:
        """
        Check if a model has already been downloaded and extracted

        Args:
            model_id: The model ID to check

        Returns:
            True if the model is already available
        """
        return model_id in self._downloaded_models

    def ensure_model(self, model_id: int, model_url: str, model_name: str = None) -> Optional[Path]:
        """
        Ensure a model is downloaded and extracted, downloading if necessary

        Args:
            model_id: The model ID
            model_url: URL to download the MPTA file from
            model_name: Optional model name for logging

        Returns:
            Path to the extracted model directory or None if failed
        """
        # Check if already downloaded
        if self.is_model_downloaded(model_id):
            logger.info(f"Model {model_id} already available at {self._model_paths[model_id]}")
            return self._model_paths[model_id]

        # Download and extract with lock to prevent concurrent downloads of same model
        with self._download_lock:
            # Double-check after acquiring lock
            if self.is_model_downloaded(model_id):
                return self._model_paths[model_id]

            logger.info(f"Model {model_id} not found locally, downloading from {model_url}")

            # Create model directory
            model_dir = self.models_dir / str(model_id)
            model_dir.mkdir(parents=True, exist_ok=True)

            # Extract filename from URL
            mpta_filename = self._extract_filename_from_url(model_url, model_name, model_id)
            mpta_path = model_dir / mpta_filename

            # Download MPTA file
            if not self._download_mpta(model_url, mpta_path):
                logger.error(f"Failed to download model {model_id}")
                return None

            # Extract MPTA file
            extracted_path = self._extract_mpta(mpta_path, model_dir)
            if not extracted_path:
                logger.error(f"Failed to extract model {model_id}")
                return None

            # Mark as downloaded and store path
            self._downloaded_models.add(model_id)
            self._model_paths[model_id] = extracted_path

            logger.info(f"Successfully prepared model {model_id} at {extracted_path}")
            return extracted_path

    def _extract_filename_from_url(self, url: str, model_name: str = None, model_id: int = None) -> str:
        """
        Extract a suitable filename from the URL

        Args:
            url: The URL to extract filename from
            model_name: Optional model name
            model_id: Optional model ID

        Returns:
            A suitable filename for the MPTA file
        """
        parsed = urlparse(url)
        path = parsed.path

        # Try to get filename from path
        if path:
            filename = os.path.basename(path)
            if filename and filename.endswith('.mpta'):
                return filename

        # Fallback to constructed name
        if model_name:
            return f"{model_name}-{model_id}.mpta"
        else:
            return f"model-{model_id}.mpta"

    def _download_mpta(self, url: str, dest_path: Path) -> bool:
        """
        Download an MPTA file from a URL

        Args:
            url: URL to download from
            dest_path: Destination path for the file

        Returns:
            True if successful, False otherwise
        """
        try:
            logger.info(f"Starting download of model from {url}")
            logger.debug(f"Download destination: {dest_path}")

            response = requests.get(url, stream=True, timeout=300)
            if response.status_code != 200:
                logger.error(f"Failed to download MPTA file (status {response.status_code})")
                return False

            file_size = int(response.headers.get('content-length', 0))
            logger.info(f"Model file size: {file_size/1024/1024:.2f} MB")

            downloaded = 0
            last_log_percent = 0

            with open(dest_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)

                        # Log progress every 10%
                        if file_size > 0:
                            percent = int(downloaded * 100 / file_size)
                            if percent >= last_log_percent + 10:
                                logger.debug(f"Download progress: {percent}%")
                                last_log_percent = percent

            logger.info(f"Successfully downloaded MPTA file to {dest_path}")
            return True

        except requests.RequestException as e:
            logger.error(f"Network error downloading MPTA: {str(e)}", exc_info=True)
            # Clean up partial download
            if dest_path.exists():
                dest_path.unlink()
            return False
        except Exception as e:
            logger.error(f"Unexpected error downloading MPTA: {str(e)}", exc_info=True)
            # Clean up partial download
            if dest_path.exists():
                dest_path.unlink()
            return False

    def _extract_mpta(self, mpta_path: Path, target_dir: Path) -> Optional[Path]:
        """
        Extract an MPTA (ZIP) file to the target directory

        Args:
            mpta_path: Path to the MPTA file
            target_dir: Directory to extract to

        Returns:
            Path to the extracted model directory containing pipeline.json, or None if failed
        """
        try:
            if not mpta_path.exists():
                logger.error(f"MPTA file not found: {mpta_path}")
                return None

            logger.info(f"Extracting MPTA file from {mpta_path} to {target_dir}")

            with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
                # Get list of files
                file_list = zip_ref.namelist()
                logger.debug(f"Files in MPTA archive: {len(file_list)} files")

                # Extract all files
                zip_ref.extractall(target_dir)

            logger.info(f"Successfully extracted MPTA file to {target_dir}")

            # Find the directory containing pipeline.json
            pipeline_files = list(target_dir.glob("*/pipeline.json"))
            if not pipeline_files:
                # Check if pipeline.json is in root
                if (target_dir / "pipeline.json").exists():
                    logger.info(f"Found pipeline.json in root of {target_dir}")
                    return target_dir
                logger.error(f"No pipeline.json found after extraction in {target_dir}")
                return None

            # Return the directory containing pipeline.json
            extracted_dir = pipeline_files[0].parent
            logger.info(f"Extracted model to {extracted_dir}")

            # Keep the MPTA file for reference but could delete if space is a concern
            # mpta_path.unlink()
            # logger.debug(f"Removed MPTA file after extraction: {mpta_path}")

            return extracted_dir

        except zipfile.BadZipFile as e:
            logger.error(f"Invalid ZIP/MPTA file {mpta_path}: {str(e)}", exc_info=True)
            return None
        except Exception as e:
            logger.error(f"Failed to extract MPTA file {mpta_path}: {str(e)}", exc_info=True)
            return None

    def load_pipeline_config(self, model_id: int) -> Optional[Dict[str, Any]]:
        """
        Load the pipeline.json configuration for a model

        Args:
            model_id: The model ID

        Returns:
            The pipeline configuration dictionary or None if not found
        """
        model_path = self.get_model_path(model_id)
        if not model_path:
            logger.error(f"Model {model_id} not found")
            return None

        pipeline_path = model_path / "pipeline.json"
        if not pipeline_path.exists():
            logger.error(f"pipeline.json not found for model {model_id}")
            return None

        try:
            with open(pipeline_path, 'r') as f:
                config = json.load(f)
            logger.debug(f"Loaded pipeline config for model {model_id}")
            return config
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in pipeline.json for model {model_id}: {str(e)}")
            return None
        except Exception as e:
            logger.error(f"Failed to load pipeline.json for model {model_id}: {str(e)}")
            return None

    def get_model_file_path(self, model_id: int, filename: str) -> Optional[Path]:
        """
        Get the full path to a model file (e.g., .pt file)

        Args:
            model_id: The model ID
            filename: The filename within the model directory

        Returns:
            Full path to the model file or None if not found
        """
        model_path = self.get_model_path(model_id)
        if not model_path:
            return None

        file_path = model_path / filename
        if not file_path.exists():
            logger.error(f"Model file {filename} not found in model {model_id}")
            return None

        return file_path

    def cleanup_model(self, model_id: int) -> bool:
        """
        Remove a downloaded model to free up space

        Args:
            model_id: The model ID to remove

        Returns:
            True if successful, False otherwise
        """
        if model_id not in self._downloaded_models:
            logger.warning(f"Model {model_id} not in downloaded models")
            return False

        try:
            model_dir = self.models_dir / str(model_id)
            if model_dir.exists():
                import shutil
                shutil.rmtree(model_dir)
                logger.info(f"Removed model directory: {model_dir}")

            self._downloaded_models.discard(model_id)
            self._model_paths.pop(model_id, None)
            return True

        except Exception as e:
            logger.error(f"Failed to cleanup model {model_id}: {str(e)}")
            return False

    def get_all_downloaded_models(self) -> Set[int]:
        """
        Get a set of all downloaded model IDs

        Returns:
            Set of model IDs that are currently downloaded
        """
        return self._downloaded_models.copy()

    def get_pipeline_config(self, model_id: int) -> Optional[Any]:
        """
        Get the pipeline configuration for a model.

        Args:
            model_id: The model ID

        Returns:
            PipelineConfig object if found, None otherwise
        """
        try:
            if model_id not in self._downloaded_models:
                logger.warning(f"Model {model_id} not downloaded")
                return None

            model_path = self._model_paths.get(model_id)
            if not model_path:
                logger.warning(f"Model path not found for model {model_id}")
                return None

            # Import here to avoid circular imports
            from .pipeline import PipelineParser

            # Load pipeline.json
            pipeline_file = model_path / "pipeline.json"
            if not pipeline_file.exists():
                logger.warning(f"No pipeline.json found for model {model_id}")
                return None

            # Create PipelineParser object and parse the configuration
            pipeline_parser = PipelineParser()
            success = pipeline_parser.parse(pipeline_file)

            if success:
                return pipeline_parser
            else:
                logger.error(f"Failed to parse pipeline.json for model {model_id}")
                return None

        except Exception as e:
            logger.error(f"Error getting pipeline config for model {model_id}: {e}", exc_info=True)
            return None

    def get_yolo_model(self, model_id: int, model_filename: str) -> Optional[Any]:
        """
        Create a YOLOWrapper instance for a specific model file.

        Args:
            model_id: The model ID
            model_filename: The .pt model filename

        Returns:
            YOLOWrapper instance if successful, None otherwise
        """
        try:
            # Get the model file path
            model_file_path = self.get_model_file_path(model_id, model_filename)
            if not model_file_path or not model_file_path.exists():
                logger.error(f"Model file {model_filename} not found for model {model_id}")
                return None

            # Import here to avoid circular imports
            from .inference import YOLOWrapper

            # Create YOLOWrapper instance
            yolo_model = YOLOWrapper(
                model_path=model_file_path,
                model_id=f"{model_id}_{model_filename}",
                device=None  # Auto-detect device
            )

            logger.info(f"Created YOLOWrapper for model {model_id}: {model_filename}")
            return yolo_model

        except Exception as e:
            logger.error(f"Error creating YOLO model for {model_id}:{model_filename}: {e}", exc_info=True)
            return None
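
# --- Hedged usage sketch (illustrative, not part of the repository) ---
# How ModelManager above might be combined with the pipeline parser: download
# an .mpta bundle once, then build a YOLO wrapper for its main model file.
# The model id 42 and the URL are placeholders, not real endpoints.


def example_model_bootstrap() -> None:
    manager = ModelManager(models_dir="models")

    extracted = manager.ensure_model(
        model_id=42,                                      # hypothetical id
        model_url="https://example.com/bundles/42.mpta",  # placeholder URL
        model_name="gas-station"
    )
    if extracted is None:
        raise RuntimeError("download or extraction failed")

    pipeline = manager.get_pipeline_config(42)  # returns a PipelineParser
    if pipeline and pipeline.pipeline_config:
        # The parser exposes the main model filename, which the manager can
        # turn into a ready YOLOWrapper.
        main_file = pipeline.pipeline_config.model_file
        yolo = manager.get_yolo_model(42, main_file)
        print("loaded:", yolo is not None)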
@@ -1,373 +0,0 @@
"""
Pipeline Configuration Parser - Handles pipeline.json parsing and validation
"""

import json
import logging
from pathlib import Path
from typing import Dict, List, Any, Optional, Set
from dataclasses import dataclass, field
from enum import Enum

logger = logging.getLogger(__name__)


class ActionType(Enum):
    """Supported action types in pipeline"""
    REDIS_SAVE_IMAGE = "redis_save_image"
    REDIS_PUBLISH = "redis_publish"
    # PostgreSQL actions below are DEPRECATED - kept for backward compatibility only
    # These actions will be silently skipped during pipeline execution
    POSTGRESQL_UPDATE = "postgresql_update"
    POSTGRESQL_UPDATE_COMBINED = "postgresql_update_combined"
    POSTGRESQL_INSERT = "postgresql_insert"


@dataclass
class RedisConfig:
    """Redis connection configuration"""
    host: str
    port: int = 6379
    password: Optional[str] = None
    db: int = 0

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'RedisConfig':
        return cls(
            host=data['host'],
            port=data.get('port', 6379),
            password=data.get('password'),
            db=data.get('db', 0)
        )


@dataclass
class PostgreSQLConfig:
    """
    PostgreSQL connection configuration - DISABLED

    NOTE: This configuration is kept for backward compatibility with existing
    pipeline.json files, but PostgreSQL operations are disabled. All database
    operations have been moved to microservices architecture.

    This config will be parsed but not used for any database connections.
    """
    host: str
    port: int
    database: str
    username: str
    password: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'PostgreSQLConfig':
        """Parse PostgreSQL config from dict (kept for backward compatibility)"""
        return cls(
            host=data['host'],
            port=data.get('port', 5432),
            database=data['database'],
            username=data['username'],
            password=data['password']
        )


@dataclass
class Action:
    """Represents an action in the pipeline"""
    type: ActionType
    params: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Action':
        action_type = ActionType(data['type'])
        params = {k: v for k, v in data.items() if k != 'type'}
        return cls(type=action_type, params=params)


@dataclass
class ModelBranch:
    """Represents a branch in the pipeline with its own model"""
    model_id: str
    model_file: str
    trigger_classes: List[str]
    min_confidence: float = 0.5
    crop: bool = False
    crop_class: Optional[Any] = None  # Can be string or list
    parallel: bool = False
    actions: List[Action] = field(default_factory=list)
    branches: List['ModelBranch'] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelBranch':
        actions = [Action.from_dict(a) for a in data.get('actions', [])]
        branches = [cls.from_dict(b) for b in data.get('branches', [])]

        return cls(
            model_id=data['modelId'],
            model_file=data['modelFile'],
            trigger_classes=data.get('triggerClasses', []),
            min_confidence=data.get('minConfidence', 0.5),
            crop=data.get('crop', False),
            crop_class=data.get('cropClass'),
            parallel=data.get('parallel', False),
            actions=actions,
            branches=branches
        )


@dataclass
class TrackingConfig:
    """Configuration for the tracking phase"""
    model_id: str
    model_file: str
    trigger_classes: List[str]
    min_confidence: float = 0.6

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TrackingConfig':
        return cls(
            model_id=data['modelId'],
            model_file=data['modelFile'],
            trigger_classes=data.get('triggerClasses', []),
            min_confidence=data.get('minConfidence', 0.6)
        )


@dataclass
class PipelineConfig:
    """Main pipeline configuration"""
    model_id: str
    model_file: str
    trigger_classes: List[str]
    min_confidence: float = 0.5
    crop: bool = False
    branches: List[ModelBranch] = field(default_factory=list)
    parallel_actions: List[Action] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'PipelineConfig':
        branches = [ModelBranch.from_dict(b) for b in data.get('branches', [])]
        parallel_actions = [Action.from_dict(a) for a in data.get('parallelActions', [])]

        return cls(
            model_id=data['modelId'],
            model_file=data['modelFile'],
            trigger_classes=data.get('triggerClasses', []),
            min_confidence=data.get('minConfidence', 0.5),
            crop=data.get('crop', False),
            branches=branches,
            parallel_actions=parallel_actions
        )


class PipelineParser:
    """Parser for pipeline.json configuration files"""

    def __init__(self):
        self.redis_config: Optional[RedisConfig] = None
        self.postgresql_config: Optional[PostgreSQLConfig] = None
        self.tracking_config: Optional[TrackingConfig] = None
        self.pipeline_config: Optional[PipelineConfig] = None
        self._model_dependencies: Set[str] = set()

    def parse(self, config_path: Path) -> bool:
        """
        Parse a pipeline.json configuration file

        Args:
            config_path: Path to the pipeline.json file

        Returns:
            True if parsing was successful, False otherwise
        """
        try:
            if not config_path.exists():
                logger.error(f"Pipeline config not found: {config_path}")
                return False

            with open(config_path, 'r') as f:
                data = json.load(f)

            return self.parse_dict(data)

        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in pipeline config: {str(e)}")
            return False
        except Exception as e:
            logger.error(f"Failed to parse pipeline config: {str(e)}", exc_info=True)
            return False

    def parse_dict(self, data: Dict[str, Any]) -> bool:
        """
        Parse a pipeline configuration from a dictionary

        Args:
            data: The configuration dictionary

        Returns:
            True if parsing was successful, False otherwise
        """
        try:
            # Parse Redis configuration
            if 'redis' in data:
                self.redis_config = RedisConfig.from_dict(data['redis'])
                logger.debug(f"Parsed Redis config: {self.redis_config.host}:{self.redis_config.port}")

            # Parse PostgreSQL configuration
            if 'postgresql' in data:
                self.postgresql_config = PostgreSQLConfig.from_dict(data['postgresql'])
                logger.debug(f"Parsed PostgreSQL config: {self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}")

            # Parse tracking configuration
            if 'tracking' in data:
                self.tracking_config = TrackingConfig.from_dict(data['tracking'])
                self._model_dependencies.add(self.tracking_config.model_file)
                logger.debug(f"Parsed tracking config: {self.tracking_config.model_id}")

            # Parse main pipeline configuration
            if 'pipeline' in data:
                self.pipeline_config = PipelineConfig.from_dict(data['pipeline'])
                self._collect_model_dependencies(self.pipeline_config)
                logger.debug(f"Parsed pipeline config: {self.pipeline_config.model_id}")

            logger.info("Successfully parsed pipeline configuration")
            logger.debug(f"Model dependencies: {self._model_dependencies}")
            return True

        except KeyError as e:
            logger.error(f"Missing required field in pipeline config: {str(e)}")
            return False
        except Exception as e:
            logger.error(f"Failed to parse pipeline config: {str(e)}", exc_info=True)
            return False

    def _collect_model_dependencies(self, config: Any) -> None:
        """
        Recursively collect all model file dependencies

        Args:
            config: Pipeline or branch configuration
        """
        if hasattr(config, 'model_file'):
            self._model_dependencies.add(config.model_file)

        if hasattr(config, 'branches'):
            for branch in config.branches:
                self._collect_model_dependencies(branch)

    def get_model_dependencies(self) -> Set[str]:
        """
        Get all model file dependencies from the pipeline

        Returns:
            Set of model filenames required by the pipeline
        """
        return self._model_dependencies.copy()

    def validate(self) -> bool:
        """
        Validate the parsed configuration

        Returns:
            True if configuration is valid, False otherwise
        """
        if not self.pipeline_config:
            logger.error("No pipeline configuration found")
            return False

        # Check that all required model files are specified
        if not self.pipeline_config.model_file:
            logger.error("Main pipeline model file not specified")
            return False

        # Validate action configurations
        if not self._validate_actions(self.pipeline_config):
            return False

        # Validate parallel actions (PostgreSQL actions are skipped)
        for action in self.pipeline_config.parallel_actions:
            if action.type == ActionType.POSTGRESQL_UPDATE_COMBINED:
                logger.warning(f"PostgreSQL parallel action {action.type.value} found but will be SKIPPED (PostgreSQL disabled)")
                # Skip validation for PostgreSQL actions since they won't be executed
                # wait_for = action.params.get('waitForBranches', [])
                # if wait_for:
                #     # Check that referenced branches exist
                #     branch_ids = self._get_all_branch_ids(self.pipeline_config)
                #     for branch_id in wait_for:
                #         if branch_id not in branch_ids:
                #             logger.error(f"Referenced branch '{branch_id}' in waitForBranches not found")
                #             return False

        logger.info("Pipeline configuration validated successfully")
        return True

    def _validate_actions(self, config: Any) -> bool:
        """
        Validate actions in a pipeline or branch configuration

        Args:
            config: Pipeline or branch configuration

        Returns:
            True if valid, False otherwise
        """
        if hasattr(config, 'actions'):
            for action in config.actions:
                # Validate Redis actions need Redis config
                if action.type in [ActionType.REDIS_SAVE_IMAGE, ActionType.REDIS_PUBLISH]:
                    if not self.redis_config:
                        logger.error(f"Action {action.type} requires Redis configuration")
                        return False

                # PostgreSQL actions are disabled - log warning instead of failing
                # Kept for backward compatibility with existing pipeline.json files
                if action.type in [ActionType.POSTGRESQL_UPDATE, ActionType.POSTGRESQL_UPDATE_COMBINED, ActionType.POSTGRESQL_INSERT]:
                    logger.warning(f"PostgreSQL action {action.type.value} found but will be SKIPPED (PostgreSQL disabled)")
                    # Do not fail validation - just skip these actions during execution
                    # if not self.postgresql_config:
                    #     logger.error(f"Action {action.type} requires PostgreSQL configuration")
                    #     return False

        # Recursively validate branches
        if hasattr(config, 'branches'):
            for branch in config.branches:
                if not self._validate_actions(branch):
                    return False

        return True

    def _get_all_branch_ids(self, config: Any, branch_ids: Set[str] = None) -> Set[str]:
        """
        Recursively collect all branch model IDs

        Args:
            config: Pipeline or branch configuration
            branch_ids: Set to collect IDs into

        Returns:
            Set of all branch model IDs
        """
        if branch_ids is None:
            branch_ids = set()

        if hasattr(config, 'branches'):
            for branch in config.branches:
                branch_ids.add(branch.model_id)
                self._get_all_branch_ids(branch, branch_ids)

        return branch_ids

    def get_redis_config(self) -> Optional[RedisConfig]:
        """Get the Redis configuration"""
        return self.redis_config

    def get_postgresql_config(self) -> Optional[PostgreSQLConfig]:
        """Get the PostgreSQL configuration"""
        return self.postgresql_config

    def get_tracking_config(self) -> Optional[TrackingConfig]:
        """Get the tracking configuration"""
        return self.tracking_config

    def get_pipeline_config(self) -> Optional[PipelineConfig]:
        """Get the main pipeline configuration"""
        return self.pipeline_config
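
# --- Hedged usage sketch (illustrative, not part of the repository) ---
# Parsing a pipeline configuration directly from a dict. The shape below is a
# minimal example inferred from the from_dict() readers above; real bundles
# may carry more keys, and all filenames here are placeholders.


def example_parse_pipeline() -> None:
    config = {
        "redis": {"host": "localhost", "port": 6379},
        "tracking": {"modelId": "tracker", "modelFile": "tracker.pt"},
        "pipeline": {
            "modelId": "detector",
            "modelFile": "detector.pt",
            "triggerClasses": ["car"],
            "branches": [{
                "modelId": "brand_cls",
                "modelFile": "brand.pt",
                "triggerClasses": ["car"],
                "crop": True,
                "actions": [{"type": "redis_save_image", "key": "car:{track_id}"}]
            }]
        }
    }

    parser = PipelineParser()
    assert parser.parse_dict(config)
    # redis_save_image passes validation because a Redis config is present.
    assert parser.validate()
    # Expected dependency set (print order may vary):
    # {'tracker.pt', 'detector.pt', 'brand.pt'}
    print(parser.get_model_dependencies())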
@@ -1,18 +0,0 @@
"""
Comprehensive health monitoring system for detector worker.
Tracks stream health, thread responsiveness, and system performance.
"""

from .health import HealthMonitor, HealthStatus, HealthCheck
from .stream_health import StreamHealthTracker
from .thread_health import ThreadHealthMonitor
from .recovery import RecoveryManager

__all__ = [
    'HealthMonitor',
    'HealthStatus',
    'HealthCheck',
    'StreamHealthTracker',
    'ThreadHealthMonitor',
    'RecoveryManager'
]
@@ -1,456 +0,0 @@
"""
Core health monitoring system for comprehensive stream and system health tracking.
Provides centralized health status, alerting, and recovery coordination.
"""
import time
import threading
import logging
import psutil
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass, field
from enum import Enum
from collections import defaultdict, deque


logger = logging.getLogger(__name__)


class HealthStatus(Enum):
    """Health status levels."""
    HEALTHY = "healthy"
    WARNING = "warning"
    CRITICAL = "critical"
    UNKNOWN = "unknown"


@dataclass
class HealthCheck:
    """Individual health check result."""
    name: str
    status: HealthStatus
    message: str
    timestamp: float = field(default_factory=time.time)
    details: Dict[str, Any] = field(default_factory=dict)
    recovery_action: Optional[str] = None


@dataclass
class HealthMetrics:
    """Health metrics for a component."""
    component_id: str
    last_update: float
    frame_count: int = 0
    error_count: int = 0
    warning_count: int = 0
    restart_count: int = 0
    avg_frame_interval: float = 0.0
    last_frame_time: Optional[float] = None
    thread_alive: bool = True
    connection_healthy: bool = True
    memory_usage_mb: float = 0.0
    cpu_usage_percent: float = 0.0


class HealthMonitor:
    """Comprehensive health monitoring system."""

    def __init__(self, check_interval: float = 30.0):
        """
        Initialize health monitor.

        Args:
            check_interval: Interval between health checks in seconds
        """
        self.check_interval = check_interval
        self.running = False
        self.monitor_thread = None
        self._lock = threading.RLock()

        # Health data storage
        self.health_checks: Dict[str, HealthCheck] = {}
        self.metrics: Dict[str, HealthMetrics] = {}
        self.alert_history: deque = deque(maxlen=1000)
        self.recovery_actions: deque = deque(maxlen=500)

        # Thresholds (configurable)
        self.thresholds = {
            'frame_stale_warning_seconds': 120,   # 2 minutes
            'frame_stale_critical_seconds': 300,  # 5 minutes
            'thread_unresponsive_seconds': 60,    # 1 minute
            'memory_warning_mb': 500,             # 500MB per stream
            'memory_critical_mb': 1000,           # 1GB per stream
            'cpu_warning_percent': 80,            # 80% CPU
            'cpu_critical_percent': 95,           # 95% CPU
            'error_rate_warning': 0.1,            # 10% error rate
            'error_rate_critical': 0.3,           # 30% error rate
            'restart_threshold': 3                # Max restarts per hour
        }

        # Health check functions
        self.health_checkers: List[Callable[[], List[HealthCheck]]] = []
        self.recovery_callbacks: Dict[str, Callable[[str, HealthCheck], bool]] = {}

        # System monitoring
        self.process = psutil.Process()
        self.system_start_time = time.time()

    def start(self):
        """Start health monitoring."""
        if self.running:
            logger.warning("Health monitor already running")
            return

        self.running = True
        self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.monitor_thread.start()
        logger.info(f"Health monitor started (check interval: {self.check_interval}s)")

    def stop(self):
        """Stop health monitoring."""
        self.running = False
        if self.monitor_thread:
            self.monitor_thread.join(timeout=5.0)
        logger.info("Health monitor stopped")

    def register_health_checker(self, checker: Callable[[], List[HealthCheck]]):
        """Register a health check function."""
        self.health_checkers.append(checker)
        logger.debug(f"Registered health checker: {checker.__name__}")

    def register_recovery_callback(self, component: str, callback: Callable[[str, HealthCheck], bool]):
        """Register a recovery callback for a component."""
        self.recovery_callbacks[component] = callback
        logger.debug(f"Registered recovery callback for {component}")

    def update_metrics(self, component_id: str, **kwargs):
        """Update metrics for a component."""
        with self._lock:
            if component_id not in self.metrics:
                self.metrics[component_id] = HealthMetrics(
                    component_id=component_id,
                    last_update=time.time()
                )

            metrics = self.metrics[component_id]
            metrics.last_update = time.time()

            # Update provided metrics
            for key, value in kwargs.items():
                if hasattr(metrics, key):
                    setattr(metrics, key, value)

    def report_frame_received(self, component_id: str):
        """Report that a frame was received for a component."""
        current_time = time.time()
        with self._lock:
            if component_id not in self.metrics:
                self.metrics[component_id] = HealthMetrics(
                    component_id=component_id,
                    last_update=current_time
                )

            metrics = self.metrics[component_id]

            # Update frame metrics
            if metrics.last_frame_time:
                interval = current_time - metrics.last_frame_time
                # Moving average of frame intervals
                if metrics.avg_frame_interval == 0:
                    metrics.avg_frame_interval = interval
                else:
                    metrics.avg_frame_interval = (metrics.avg_frame_interval * 0.9) + (interval * 0.1)

            metrics.last_frame_time = current_time
            metrics.frame_count += 1
            metrics.last_update = current_time

    def report_error(self, component_id: str, error_type: str = "general"):
        """Report an error for a component."""
        with self._lock:
            if component_id not in self.metrics:
                self.metrics[component_id] = HealthMetrics(
                    component_id=component_id,
                    last_update=time.time()
                )

            self.metrics[component_id].error_count += 1
            self.metrics[component_id].last_update = time.time()

        logger.debug(f"Error reported for {component_id}: {error_type}")

    def report_warning(self, component_id: str, warning_type: str = "general"):
        """Report a warning for a component."""
        with self._lock:
            if component_id not in self.metrics:
                self.metrics[component_id] = HealthMetrics(
                    component_id=component_id,
                    last_update=time.time()
                )

            self.metrics[component_id].warning_count += 1
            self.metrics[component_id].last_update = time.time()

        logger.debug(f"Warning reported for {component_id}: {warning_type}")

    def report_restart(self, component_id: str):
        """Report that a component was restarted."""
        with self._lock:
            if component_id not in self.metrics:
                self.metrics[component_id] = HealthMetrics(
                    component_id=component_id,
                    last_update=time.time()
                )

            self.metrics[component_id].restart_count += 1
            self.metrics[component_id].last_update = time.time()

        # Log recovery action
        recovery_action = {
            'timestamp': time.time(),
            'component': component_id,
            'action': 'restart',
            'reason': 'manual_restart'
        }

        with self._lock:
            self.recovery_actions.append(recovery_action)

        logger.info(f"Restart reported for {component_id}")

    def get_health_status(self, component_id: Optional[str] = None) -> Dict[str, Any]:
        """Get comprehensive health status."""
        with self._lock:
            if component_id:
                # Get health for specific component
                return self._get_component_health(component_id)
            else:
                # Get overall health status
                return self._get_overall_health()

    def _get_component_health(self, component_id: str) -> Dict[str, Any]:
        """Get health status for a specific component."""
        if component_id not in self.metrics:
            return {
                'component_id': component_id,
                'status': HealthStatus.UNKNOWN.value,
                'message': 'No metrics available',
                'metrics': {}
            }

        metrics = self.metrics[component_id]
        current_time = time.time()

        # Determine health status
        status = HealthStatus.HEALTHY
        issues = []

        # Check frame freshness
        if metrics.last_frame_time:
            frame_age = current_time - metrics.last_frame_time
            if frame_age > self.thresholds['frame_stale_critical_seconds']:
                status = HealthStatus.CRITICAL
                issues.append(f"Frames stale for {frame_age:.1f}s")
            elif frame_age > self.thresholds['frame_stale_warning_seconds']:
                if status == HealthStatus.HEALTHY:
                    status = HealthStatus.WARNING
                issues.append(f"Frames aging ({frame_age:.1f}s)")

        # Check error rates
        if metrics.frame_count > 0:
            error_rate = metrics.error_count / metrics.frame_count
            if error_rate > self.thresholds['error_rate_critical']:
                status = HealthStatus.CRITICAL
                issues.append(f"High error rate ({error_rate:.1%})")
            elif error_rate > self.thresholds['error_rate_warning']:
                if status == HealthStatus.HEALTHY:
                    status = HealthStatus.WARNING
                issues.append(f"Elevated error rate ({error_rate:.1%})")

        # Check restart frequency
        restart_rate = metrics.restart_count / max(1, (current_time - self.system_start_time) / 3600)
        if restart_rate > self.thresholds['restart_threshold']:
            status = HealthStatus.CRITICAL
            issues.append(f"Frequent restarts ({restart_rate:.1f}/hour)")

        # Check thread health
        if not metrics.thread_alive:
            status = HealthStatus.CRITICAL
            issues.append("Thread not alive")

        # Check connection health
        if not metrics.connection_healthy:
            if status == HealthStatus.HEALTHY:
                status = HealthStatus.WARNING
            issues.append("Connection unhealthy")

        return {
            'component_id': component_id,
            'status': status.value,
            'message': '; '.join(issues) if issues else 'All checks passing',
            'metrics': {
                'frame_count': metrics.frame_count,
                'error_count': metrics.error_count,
                'warning_count': metrics.warning_count,
                'restart_count': metrics.restart_count,
                'avg_frame_interval': metrics.avg_frame_interval,
                'last_frame_age': current_time - metrics.last_frame_time if metrics.last_frame_time else None,
                'thread_alive': metrics.thread_alive,
                'connection_healthy': metrics.connection_healthy,
                'memory_usage_mb': metrics.memory_usage_mb,
                'cpu_usage_percent': metrics.cpu_usage_percent,
                'uptime_seconds': current_time - self.system_start_time
            },
            'last_update': metrics.last_update
        }

    def _get_overall_health(self) -> Dict[str, Any]:
        """Get overall system health status."""
        current_time = time.time()
        components = {}
        overall_status = HealthStatus.HEALTHY

        # Get health for all components
        for component_id in self.metrics.keys():
            component_health = self._get_component_health(component_id)
            components[component_id] = component_health

            # Determine overall status
            component_status = HealthStatus(component_health['status'])
            if component_status == HealthStatus.CRITICAL:
                overall_status = HealthStatus.CRITICAL
            elif component_status == HealthStatus.WARNING and overall_status == HealthStatus.HEALTHY:
                overall_status = HealthStatus.WARNING

        # System metrics
        try:
            system_memory = self.process.memory_info()
            system_cpu = self.process.cpu_percent()
        except Exception:
            system_memory = None
            system_cpu = 0.0

        return {
            'overall_status': overall_status.value,
            'timestamp': current_time,
            'uptime_seconds': current_time - self.system_start_time,
            'total_components': len(self.metrics),
            'components': components,
            'system_metrics': {
                'memory_mb': system_memory.rss / (1024 * 1024) if system_memory else 0,
                'cpu_percent': system_cpu,
                'process_id': self.process.pid
            },
            'recent_alerts': list(self.alert_history)[-10:],        # Last 10 alerts
            'recent_recoveries': list(self.recovery_actions)[-10:]  # Last 10 recovery actions
        }

    def _monitor_loop(self):
        """Main health monitoring loop."""
        logger.info("Health monitor loop started")

        while self.running:
            try:
                start_time = time.time()

                # Run all registered health checks
                all_checks = []
                for checker in self.health_checkers:
                    try:
                        checks = checker()
                        all_checks.extend(checks)
                    except Exception as e:
                        logger.error(f"Error in health checker {checker.__name__}: {e}")

                # Process health checks and trigger recovery if needed
                for check in all_checks:
                    self._process_health_check(check)

                # Update system metrics
                self._update_system_metrics()

                # Sleep until next check
                elapsed = time.time() - start_time
                sleep_time = max(0, self.check_interval - elapsed)
                if sleep_time > 0:
                    time.sleep(sleep_time)

            except Exception as e:
                logger.error(f"Error in health monitor loop: {e}")
                time.sleep(5.0)  # Fallback sleep

        logger.info("Health monitor loop ended")

    def _process_health_check(self, check: HealthCheck):
        """Process a health check result and trigger recovery if needed."""
        with self._lock:
            # Store health check
            self.health_checks[check.name] = check

            # Log alerts for non-healthy status
            if check.status != HealthStatus.HEALTHY:
                alert = {
                    'timestamp': check.timestamp,
                    'component': check.name,
                    'status': check.status.value,
                    'message': check.message,
                    'details': check.details
                }
                self.alert_history.append(alert)

                logger.warning(f"Health alert [{check.status.value.upper()}] {check.name}: {check.message}")

                # Trigger recovery if critical and recovery action available
                if check.status == HealthStatus.CRITICAL and check.recovery_action:
                    self._trigger_recovery(check.name, check)

    def _trigger_recovery(self, component: str, check: HealthCheck):
        """Trigger recovery action for a component."""
        if component in self.recovery_callbacks:
            try:
                logger.info(f"Triggering recovery for {component}: {check.recovery_action}")

                success = self.recovery_callbacks[component](component, check)

                recovery_action = {
                    'timestamp': time.time(),
                    'component': component,
                    'action': check.recovery_action,
                    'reason': check.message,
                    'success': success
                }

                with self._lock:
                    self.recovery_actions.append(recovery_action)

                if success:
                    logger.info(f"Recovery successful for {component}")
                else:
                    logger.error(f"Recovery failed for {component}")

            except Exception as e:
                logger.error(f"Error in recovery callback for {component}: {e}")

    def _update_system_metrics(self):
        """Update system-level metrics."""
        try:
            # Update process metrics for all components
            current_time = time.time()

            with self._lock:
                for component_id, metrics in self.metrics.items():
                    # Update CPU and memory if available
                    try:
                        # This is a simplified approach - in practice you'd want
                        # per-thread or per-component resource tracking
                        metrics.cpu_usage_percent = self.process.cpu_percent() / len(self.metrics)
                        memory_info = self.process.memory_info()
                        metrics.memory_usage_mb = memory_info.rss / (1024 * 1024) / len(self.metrics)
                    except Exception:
                        pass

        except Exception as e:
            logger.error(f"Error updating system metrics: {e}")


# Global health monitor instance
health_monitor = HealthMonitor()
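
# --- Hedged usage sketch (illustrative, not part of the repository) ---
# Wiring a custom checker and recovery callback into the module-level
# health_monitor above. Note that _trigger_recovery looks up callbacks by the
# HealthCheck's name, so the callback is registered under "camera1_frames".
# The component id "camera1" and all thresholds here are illustrative.


def example_health_wiring() -> None:
    def camera_checker() -> List[HealthCheck]:
        status = health_monitor.get_health_status("camera1")
        if status["status"] == HealthStatus.CRITICAL.value:
            return [HealthCheck(
                name="camera1_frames",
                status=HealthStatus.CRITICAL,
                message=status["message"],
                recovery_action="restart_stream",  # label stored on the check
            )]
        return []

    def restart_camera(component: str, check: HealthCheck) -> bool:
        # A real implementation would tear down and reopen the stream here.
        return True

    health_monitor.register_health_checker(camera_checker)
    health_monitor.register_recovery_callback("camera1_frames", restart_camera)
    health_monitor.report_frame_received("camera1")
    health_monitor.start()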
@@ -1,385 +0,0 @@
"""
Recovery manager for automatic handling of health issues.
Provides circuit breaker patterns, automatic restarts, and graceful degradation.
"""
import time
import logging
import threading
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass
from enum import Enum
from collections import defaultdict, deque

from .health import HealthCheck, HealthStatus, health_monitor


logger = logging.getLogger(__name__)


class RecoveryAction(Enum):
    """Types of recovery actions."""
    RESTART_STREAM = "restart_stream"
    RESTART_THREAD = "restart_thread"
    CLEAR_BUFFER = "clear_buffer"
    RECONNECT = "reconnect"
    THROTTLE = "throttle"
    DISABLE = "disable"


@dataclass
class RecoveryAttempt:
    """Record of a recovery attempt."""
    timestamp: float
    component: str
    action: RecoveryAction
    reason: str
    success: bool
    details: Dict[str, Any] = None


@dataclass
class RecoveryState:
    """Recovery state for a component - simplified without circuit breaker."""
    failure_count: int = 0
    success_count: int = 0
    last_failure_time: Optional[float] = None
    last_success_time: Optional[float] = None


class RecoveryManager:
    """Manages automatic recovery actions for health issues."""

    def __init__(self):
        self.recovery_handlers: Dict[str, Callable[[str, HealthCheck], bool]] = {}
        self.recovery_states: Dict[str, RecoveryState] = {}
        self.recovery_history: deque = deque(maxlen=1000)
        self._lock = threading.RLock()

        # Configuration - simplified without circuit breaker
        self.recovery_cooldown = 30      # 30 seconds between recovery attempts
        self.max_attempts_per_hour = 20  # Still limit to prevent spam, but much higher

        # Track recovery attempts per component
        self.recovery_attempts: Dict[str, deque] = defaultdict(lambda: deque(maxlen=50))

        # Register with health monitor
        health_monitor.register_recovery_callback("stream", self._handle_stream_recovery)
        health_monitor.register_recovery_callback("thread", self._handle_thread_recovery)
        health_monitor.register_recovery_callback("buffer", self._handle_buffer_recovery)

    def register_recovery_handler(self, action: RecoveryAction, handler: Callable[[str, Dict[str, Any]], bool]):
        """
        Register a recovery handler for a specific action.

        Args:
            action: Type of recovery action
            handler: Function that performs the recovery
        """
        self.recovery_handlers[action.value] = handler
        logger.info(f"Registered recovery handler for {action.value}")

    def can_attempt_recovery(self, component: str) -> bool:
        """
        Check if recovery can be attempted for a component.

        Args:
            component: Component identifier

        Returns:
            True if recovery can be attempted (always allow with minimal throttling)
        """
        with self._lock:
            current_time = time.time()

            # Check recovery attempt rate limiting (much more permissive)
            recent_attempts = [
                attempt for attempt in self.recovery_attempts[component]
                if current_time - attempt <= 3600  # Last hour
            ]

            # Only block if truly excessive attempts
            if len(recent_attempts) >= self.max_attempts_per_hour:
                logger.warning(f"Recovery rate limit exceeded for {component} "
                               f"({len(recent_attempts)} attempts in last hour)")
                return False

            # Check cooldown period (shorter cooldown)
            if recent_attempts:
                last_attempt = max(recent_attempts)
                if current_time - last_attempt < self.recovery_cooldown:
                    logger.debug(f"Recovery cooldown active for {component} "
                                 f"(last attempt {current_time - last_attempt:.1f}s ago)")
                    return False

            return True

    def attempt_recovery(self, component: str, action: RecoveryAction, reason: str,
                         details: Optional[Dict[str, Any]] = None) -> bool:
        """
        Attempt recovery for a component.

        Args:
            component: Component identifier
            action: Recovery action to perform
            reason: Reason for recovery
            details: Additional details

        Returns:
            True if recovery was successful
        """
        if not self.can_attempt_recovery(component):
            return False

        current_time = time.time()

        logger.info(f"Attempting recovery for {component}: {action.value} ({reason})")

        try:
            # Record recovery attempt
            with self._lock:
                self.recovery_attempts[component].append(current_time)

            # Perform recovery action
            success = self._execute_recovery_action(component, action, details or {})

            # Record recovery result
            attempt = RecoveryAttempt(
                timestamp=current_time,
                component=component,
                action=action,
                reason=reason,
                success=success,
                details=details
            )

            with self._lock:
                self.recovery_history.append(attempt)

            # Update recovery state
            self._update_recovery_state(component, success)

            if success:
                logger.info(f"Recovery successful for {component}: {action.value}")
            else:
                logger.error(f"Recovery failed for {component}: {action.value}")

            return success

        except Exception as e:
            logger.error(f"Error during recovery for {component}: {e}")
            self._update_recovery_state(component, False)
            return False

    def _execute_recovery_action(self, component: str, action: RecoveryAction,
                                 details: Dict[str, Any]) -> bool:
        """Execute a specific recovery action."""
        handler_key = action.value

        if handler_key not in self.recovery_handlers:
            logger.error(f"No recovery handler registered for action: {handler_key}")
            return False

        try:
            handler = self.recovery_handlers[handler_key]
            return handler(component, details)

        except Exception as e:
            logger.error(f"Error executing recovery action {handler_key} for {component}: {e}")
            return False

    def _update_recovery_state(self, component: str, success: bool):
        """Update recovery state based on recovery result."""
        current_time = time.time()

        with self._lock:
            if component not in self.recovery_states:
                self.recovery_states[component] = RecoveryState()

            state = self.recovery_states[component]

            if success:
                state.success_count += 1
                state.last_success_time = current_time
                # Reset failure count on success
                state.failure_count = max(0, state.failure_count - 1)
                logger.debug(f"Recovery success for {component} (total successes: {state.success_count})")
            else:
                state.failure_count += 1
                state.last_failure_time = current_time
                logger.debug(f"Recovery failure for {component} (total failures: {state.failure_count})")

    def _handle_stream_recovery(self, component: str, health_check: HealthCheck) -> bool:
        """Handle recovery for stream-related issues."""
        if "frames" in health_check.name:
            # Frame-related issue - restart stream
            return self.attempt_recovery(
                component,
                RecoveryAction.RESTART_STREAM,
                health_check.message,
                health_check.details
            )
        elif "connection" in health_check.name:
            # Connection issue - reconnect
            return self.attempt_recovery(
                component,
                RecoveryAction.RECONNECT,
                health_check.message,
                health_check.details
            )
        elif "errors" in health_check.name:
|
||||
# High error rate - throttle or restart
|
||||
return self.attempt_recovery(
|
||||
component,
|
||||
RecoveryAction.THROTTLE,
|
||||
health_check.message,
|
||||
health_check.details
|
||||
)
|
||||
else:
|
||||
# Generic stream issue - restart
|
||||
return self.attempt_recovery(
|
||||
component,
|
||||
RecoveryAction.RESTART_STREAM,
|
||||
health_check.message,
|
||||
health_check.details
|
||||
)
|
||||
|
||||
def _handle_thread_recovery(self, component: str, health_check: HealthCheck) -> bool:
|
||||
"""Handle recovery for thread-related issues."""
|
||||
if "deadlock" in health_check.name:
|
||||
# Deadlock detected - restart thread
|
||||
return self.attempt_recovery(
|
||||
component,
|
||||
RecoveryAction.RESTART_THREAD,
|
||||
health_check.message,
|
||||
health_check.details
|
||||
)
|
||||
elif "responsive" in health_check.name:
|
||||
# Thread unresponsive - restart
|
||||
return self.attempt_recovery(
|
||||
component,
|
||||
RecoveryAction.RESTART_THREAD,
|
||||
health_check.message,
|
||||
health_check.details
|
||||
)
|
||||
else:
|
||||
# Generic thread issue - restart
|
||||
return self.attempt_recovery(
|
||||
component,
|
||||
RecoveryAction.RESTART_THREAD,
|
||||
health_check.message,
|
||||
health_check.details
|
||||
)
|
||||
|
||||
def _handle_buffer_recovery(self, component: str, health_check: HealthCheck) -> bool:
|
||||
"""Handle recovery for buffer-related issues."""
|
||||
# Buffer issues - clear buffer
|
||||
return self.attempt_recovery(
|
||||
component,
|
||||
RecoveryAction.CLEAR_BUFFER,
|
||||
health_check.message,
|
||||
health_check.details
|
||||
)
|
||||
|
||||
def get_recovery_stats(self) -> Dict[str, Any]:
|
||||
"""Get recovery statistics."""
|
||||
current_time = time.time()
|
||||
|
||||
with self._lock:
|
||||
# Calculate stats from history
|
||||
recent_recoveries = [
|
||||
attempt for attempt in self.recovery_history
|
||||
if current_time - attempt.timestamp <= 3600 # Last hour
|
||||
]
|
||||
|
||||
stats_by_component = defaultdict(lambda: {
|
||||
'attempts': 0,
|
||||
'successes': 0,
|
||||
'failures': 0,
|
||||
'last_attempt': None,
|
||||
'last_success': None
|
||||
})
|
||||
|
||||
for attempt in recent_recoveries:
|
||||
stats = stats_by_component[attempt.component]
|
||||
stats['attempts'] += 1
|
||||
|
||||
if attempt.success:
|
||||
stats['successes'] += 1
|
||||
if not stats['last_success'] or attempt.timestamp > stats['last_success']:
|
||||
stats['last_success'] = attempt.timestamp
|
||||
else:
|
||||
stats['failures'] += 1
|
||||
|
||||
if not stats['last_attempt'] or attempt.timestamp > stats['last_attempt']:
|
||||
stats['last_attempt'] = attempt.timestamp
|
||||
|
||||
return {
|
||||
'total_recoveries_last_hour': len(recent_recoveries),
|
||||
'recovery_by_component': dict(stats_by_component),
|
||||
'recovery_states': {
|
||||
component: {
|
||||
'failure_count': state.failure_count,
|
||||
'success_count': state.success_count,
|
||||
'last_failure_time': state.last_failure_time,
|
||||
'last_success_time': state.last_success_time
|
||||
}
|
||||
for component, state in self.recovery_states.items()
|
||||
},
|
||||
'recent_history': [
|
||||
{
|
||||
'timestamp': attempt.timestamp,
|
||||
'component': attempt.component,
|
||||
'action': attempt.action.value,
|
||||
'reason': attempt.reason,
|
||||
'success': attempt.success
|
||||
}
|
||||
for attempt in list(self.recovery_history)[-10:] # Last 10 attempts
|
||||
]
|
||||
}
|
||||
|
||||
def force_recovery(self, component: str, action: RecoveryAction, reason: str = "manual") -> bool:
|
||||
"""
|
||||
Force recovery for a component, bypassing rate limiting.
|
||||
|
||||
Args:
|
||||
component: Component identifier
|
||||
action: Recovery action to perform
|
||||
reason: Reason for forced recovery
|
||||
|
||||
Returns:
|
||||
True if recovery was successful
|
||||
"""
|
||||
logger.info(f"Forcing recovery for {component}: {action.value} ({reason})")
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
try:
|
||||
# Execute recovery action directly
|
||||
success = self._execute_recovery_action(component, action, {})
|
||||
|
||||
# Record forced recovery
|
||||
attempt = RecoveryAttempt(
|
||||
timestamp=current_time,
|
||||
component=component,
|
||||
action=action,
|
||||
reason=f"forced: {reason}",
|
||||
success=success,
|
||||
details={'forced': True}
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self.recovery_history.append(attempt)
|
||||
self.recovery_attempts[component].append(current_time)
|
||||
|
||||
# Update recovery state
|
||||
self._update_recovery_state(component, success)
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during forced recovery for {component}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# Global recovery manager instance
|
||||
recovery_manager = RecoveryManager()
|
||||
|
|
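# Illustrative usage sketch (not from the original module). The handler body and
# the "camera_01" component id are invented, and it assumes RecoveryAction defines
# a RESTART_STREAM member as used above.

def _restart_stream_handler(component: str, details: Dict[str, Any]) -> bool:
    # Tear down and re-open the capture for `component`; return True on success.
    return True

recovery_manager.register_recovery_handler(RecoveryAction.RESTART_STREAM, _restart_stream_handler)
recovery_manager.attempt_recovery("camera_01", RecoveryAction.RESTART_STREAM,
                                  reason="no frames for 300s")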
@@ -1,351 +0,0 @@
"""
|
||||
Stream-specific health monitoring for video streams.
|
||||
Tracks frame production, connection health, and stream-specific metrics.
|
||||
"""
|
||||
import time
|
||||
import logging
|
||||
import threading
|
||||
import requests
|
||||
from typing import Dict, Optional, List, Any
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .health import HealthCheck, HealthStatus, health_monitor
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamMetrics:
|
||||
"""Metrics for an individual stream."""
|
||||
camera_id: str
|
||||
stream_type: str # 'rtsp', 'http_snapshot'
|
||||
start_time: float
|
||||
last_frame_time: Optional[float] = None
|
||||
frame_count: int = 0
|
||||
error_count: int = 0
|
||||
reconnect_count: int = 0
|
||||
bytes_received: int = 0
|
||||
frames_per_second: float = 0.0
|
||||
connection_attempts: int = 0
|
||||
last_connection_test: Optional[float] = None
|
||||
connection_healthy: bool = True
|
||||
last_error: Optional[str] = None
|
||||
last_error_time: Optional[float] = None
|
||||
|
||||
|
||||
class StreamHealthTracker:
|
||||
"""Tracks health for individual video streams."""
|
||||
|
||||
def __init__(self):
|
||||
self.streams: Dict[str, StreamMetrics] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# Configuration
|
||||
self.connection_test_interval = 300 # Test connection every 5 minutes
|
||||
self.frame_timeout_warning = 120 # Warn if no frames for 2 minutes
|
||||
self.frame_timeout_critical = 300 # Critical if no frames for 5 minutes
|
||||
self.error_rate_threshold = 0.1 # 10% error rate threshold
|
||||
|
||||
# Register with health monitor
|
||||
health_monitor.register_health_checker(self._perform_health_checks)
|
||||
|
||||
def register_stream(self, camera_id: str, stream_type: str, source_url: Optional[str] = None):
|
||||
"""Register a new stream for monitoring."""
|
||||
with self._lock:
|
||||
if camera_id not in self.streams:
|
||||
self.streams[camera_id] = StreamMetrics(
|
||||
camera_id=camera_id,
|
||||
stream_type=stream_type,
|
||||
start_time=time.time()
|
||||
)
|
||||
logger.info(f"Registered stream for monitoring: {camera_id} ({stream_type})")
|
||||
|
||||
# Update health monitor metrics
|
||||
health_monitor.update_metrics(
|
||||
camera_id,
|
||||
thread_alive=True,
|
||||
connection_healthy=True
|
||||
)
|
||||
|
||||
def unregister_stream(self, camera_id: str):
|
||||
"""Unregister a stream from monitoring."""
|
||||
with self._lock:
|
||||
if camera_id in self.streams:
|
||||
del self.streams[camera_id]
|
||||
logger.info(f"Unregistered stream from monitoring: {camera_id}")
|
||||
|
||||
def report_frame_received(self, camera_id: str, frame_size_bytes: int = 0):
|
||||
"""Report that a frame was received."""
|
||||
current_time = time.time()
|
||||
|
||||
with self._lock:
|
||||
if camera_id not in self.streams:
|
||||
logger.warning(f"Frame received for unregistered stream: {camera_id}")
|
||||
return
|
||||
|
||||
stream = self.streams[camera_id]
|
||||
|
||||
# Update frame metrics
|
||||
if stream.last_frame_time:
|
||||
interval = current_time - stream.last_frame_time
|
||||
# Calculate FPS as moving average
|
||||
if stream.frames_per_second == 0:
|
||||
stream.frames_per_second = 1.0 / interval if interval > 0 else 0
|
||||
else:
|
||||
new_fps = 1.0 / interval if interval > 0 else 0
|
||||
stream.frames_per_second = (stream.frames_per_second * 0.9) + (new_fps * 0.1)
|
||||
|
||||
stream.last_frame_time = current_time
|
||||
stream.frame_count += 1
|
||||
stream.bytes_received += frame_size_bytes
|
||||
|
||||
# Report to health monitor
|
||||
health_monitor.report_frame_received(camera_id)
|
||||
health_monitor.update_metrics(
|
||||
camera_id,
|
||||
frame_count=stream.frame_count,
|
||||
avg_frame_interval=1.0 / stream.frames_per_second if stream.frames_per_second > 0 else 0,
|
||||
last_frame_time=current_time
|
||||
)
|
||||
|
||||
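    # Worked example of the moving average above: with the 0.9/0.1 smoothing,
    # a stream averaging 6.0 fps that sees one 0.5 s gap (instantaneous rate
    # 2.0 fps) updates to 6.0 * 0.9 + 2.0 * 0.1 = 5.6 fps, so a single slow
    # frame nudges the estimate instead of collapsing it.
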
    def report_error(self, camera_id: str, error_message: str):
        """Report an error for a stream."""
        current_time = time.time()

        with self._lock:
            if camera_id not in self.streams:
                logger.warning(f"Error reported for unregistered stream: {camera_id}")
                return

            stream = self.streams[camera_id]
            stream.error_count += 1
            stream.last_error = error_message
            stream.last_error_time = current_time

            # Report to health monitor
            health_monitor.report_error(camera_id, "stream_error")
            health_monitor.update_metrics(
                camera_id,
                error_count=stream.error_count
            )

            logger.debug(f"Error reported for stream {camera_id}: {error_message}")

    def report_reconnect(self, camera_id: str, reason: str = "unknown"):
        """Report that a stream reconnected."""
        current_time = time.time()

        with self._lock:
            if camera_id not in self.streams:
                logger.warning(f"Reconnect reported for unregistered stream: {camera_id}")
                return

            stream = self.streams[camera_id]
            stream.reconnect_count += 1

            # Report to health monitor
            health_monitor.report_restart(camera_id)
            health_monitor.update_metrics(
                camera_id,
                restart_count=stream.reconnect_count
            )

            logger.info(f"Reconnect reported for stream {camera_id}: {reason}")

    def report_connection_attempt(self, camera_id: str, success: bool):
        """Report a connection attempt."""
        with self._lock:
            if camera_id not in self.streams:
                return

            stream = self.streams[camera_id]
            stream.connection_attempts += 1
            stream.connection_healthy = success

            # Report to health monitor
            health_monitor.update_metrics(
                camera_id,
                connection_healthy=success
            )

    def test_http_connection(self, camera_id: str, url: str) -> bool:
        """Test HTTP connection health for snapshot streams."""
        try:
            # Quick HEAD request to test connectivity
            response = requests.head(url, timeout=5, verify=False)
            success = response.status_code in [200, 404]  # 404 might be normal for some cameras

            self.report_connection_attempt(camera_id, success)

            if success:
                logger.debug(f"Connection test passed for {camera_id}")
            else:
                logger.warning(f"Connection test failed for {camera_id}: HTTP {response.status_code}")

            return success

        except Exception as e:
            logger.warning(f"Connection test failed for {camera_id}: {e}")
            self.report_connection_attempt(camera_id, False)
            return False

    def get_stream_metrics(self, camera_id: str) -> Optional[Dict[str, Any]]:
        """Get metrics for a specific stream."""
        with self._lock:
            if camera_id not in self.streams:
                return None

            stream = self.streams[camera_id]
            current_time = time.time()

            # Calculate derived metrics
            uptime = current_time - stream.start_time
            frame_age = current_time - stream.last_frame_time if stream.last_frame_time else None
            error_rate = stream.error_count / max(1, stream.frame_count)

            return {
                'camera_id': camera_id,
                'stream_type': stream.stream_type,
                'uptime_seconds': uptime,
                'frame_count': stream.frame_count,
                'frames_per_second': stream.frames_per_second,
                'bytes_received': stream.bytes_received,
                'error_count': stream.error_count,
                'error_rate': error_rate,
                'reconnect_count': stream.reconnect_count,
                'connection_attempts': stream.connection_attempts,
                'connection_healthy': stream.connection_healthy,
                'last_frame_age_seconds': frame_age,
                'last_error': stream.last_error,
                'last_error_time': stream.last_error_time
            }

    def get_all_metrics(self) -> Dict[str, Dict[str, Any]]:
        """Get metrics for all streams."""
        with self._lock:
            return {
                camera_id: self.get_stream_metrics(camera_id)
                for camera_id in self.streams.keys()
            }

    def _perform_health_checks(self) -> List[HealthCheck]:
        """Perform health checks for all streams."""
        checks = []
        current_time = time.time()

        with self._lock:
            for camera_id, stream in self.streams.items():
                checks.extend(self._check_stream_health(camera_id, stream, current_time))

        return checks

    def _check_stream_health(self, camera_id: str, stream: StreamMetrics, current_time: float) -> List[HealthCheck]:
        """Perform health checks for a single stream."""
        checks = []

        # Check frame freshness
        if stream.last_frame_time:
            frame_age = current_time - stream.last_frame_time

            if frame_age > self.frame_timeout_critical:
                checks.append(HealthCheck(
                    name=f"stream_{camera_id}_frames",
                    status=HealthStatus.CRITICAL,
                    message=f"No frames for {frame_age:.1f}s (critical threshold: {self.frame_timeout_critical}s)",
                    details={
                        'frame_age': frame_age,
                        'threshold': self.frame_timeout_critical,
                        'last_frame_time': stream.last_frame_time
                    },
                    recovery_action="restart_stream"
                ))
            elif frame_age > self.frame_timeout_warning:
                checks.append(HealthCheck(
                    name=f"stream_{camera_id}_frames",
                    status=HealthStatus.WARNING,
                    message=f"Frames aging: {frame_age:.1f}s (warning threshold: {self.frame_timeout_warning}s)",
                    details={
                        'frame_age': frame_age,
                        'threshold': self.frame_timeout_warning,
                        'last_frame_time': stream.last_frame_time
                    }
                ))
        else:
            # No frames received yet
            startup_time = current_time - stream.start_time
            if startup_time > 60:  # Allow 1 minute for initial connection
                checks.append(HealthCheck(
                    name=f"stream_{camera_id}_startup",
                    status=HealthStatus.CRITICAL,
                    message=f"No frames received since startup {startup_time:.1f}s ago",
                    details={
                        'startup_time': startup_time,
                        'start_time': stream.start_time
                    },
                    recovery_action="restart_stream"
                ))

        # Check error rate
        if stream.frame_count > 10:  # Need sufficient samples
            error_rate = stream.error_count / stream.frame_count
            if error_rate > self.error_rate_threshold:
                checks.append(HealthCheck(
                    name=f"stream_{camera_id}_errors",
                    status=HealthStatus.WARNING,
                    message=f"High error rate: {error_rate:.1%} ({stream.error_count}/{stream.frame_count})",
                    details={
                        'error_rate': error_rate,
                        'error_count': stream.error_count,
                        'frame_count': stream.frame_count,
                        'last_error': stream.last_error
                    }
                ))

        # Check connection health
        if not stream.connection_healthy:
            checks.append(HealthCheck(
                name=f"stream_{camera_id}_connection",
                status=HealthStatus.WARNING,
                message="Connection unhealthy (last test failed)",
                details={
                    'connection_attempts': stream.connection_attempts,
                    'last_connection_test': stream.last_connection_test
                }
            ))

        # Check excessive reconnects
        uptime_hours = (current_time - stream.start_time) / 3600
        if uptime_hours > 1 and stream.reconnect_count > 5:  # More than 5 reconnects per hour
            reconnect_rate = stream.reconnect_count / uptime_hours
            checks.append(HealthCheck(
                name=f"stream_{camera_id}_stability",
                status=HealthStatus.WARNING,
                message=f"Frequent reconnects: {reconnect_rate:.1f}/hour ({stream.reconnect_count} total)",
                details={
                    'reconnect_rate': reconnect_rate,
                    'reconnect_count': stream.reconnect_count,
                    'uptime_hours': uptime_hours
                }
            ))

        # Check frame rate health
        if stream.last_frame_time and stream.frames_per_second > 0:
            expected_fps = 6.0  # Expected FPS for streams
            if stream.frames_per_second < expected_fps * 0.5:  # Less than 50% of expected
                checks.append(HealthCheck(
                    name=f"stream_{camera_id}_framerate",
                    status=HealthStatus.WARNING,
                    message=f"Low frame rate: {stream.frames_per_second:.1f} fps (expected: ~{expected_fps} fps)",
                    details={
                        'current_fps': stream.frames_per_second,
                        'expected_fps': expected_fps
                    }
                ))

        return checks


# Global stream health tracker instance
stream_health_tracker = StreamHealthTracker()
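# Illustrative usage sketch (not from the original module). The camera id,
# frame size, and error text below are invented for illustration.

stream_health_tracker.register_stream("camera_01", "rtsp")
stream_health_tracker.report_frame_received("camera_01", frame_size_bytes=145_000)
stream_health_tracker.report_error("camera_01", "RTSP read timeout")
print(stream_health_tracker.get_stream_metrics("camera_01"))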
@@ -1,381 +0,0 @@
"""
|
||||
Thread health monitoring for detecting unresponsive and deadlocked threads.
|
||||
Provides thread liveness detection and responsiveness testing.
|
||||
"""
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
import signal
|
||||
import traceback
|
||||
from typing import Dict, List, Optional, Any, Callable
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
|
||||
from .health import HealthCheck, HealthStatus, health_monitor
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThreadInfo:
|
||||
"""Information about a monitored thread."""
|
||||
thread_id: int
|
||||
thread_name: str
|
||||
start_time: float
|
||||
last_heartbeat: float
|
||||
heartbeat_count: int = 0
|
||||
is_responsive: bool = True
|
||||
last_activity: Optional[str] = None
|
||||
stack_traces: List[str] = None
|
||||
|
||||
|
||||
class ThreadHealthMonitor:
|
||||
"""Monitors thread health and responsiveness."""
|
||||
|
||||
def __init__(self):
|
||||
self.monitored_threads: Dict[int, ThreadInfo] = {}
|
||||
self.heartbeat_callbacks: Dict[int, Callable[[], bool]] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# Configuration
|
||||
self.heartbeat_timeout = 60.0 # 1 minute without heartbeat = unresponsive
|
||||
self.responsiveness_test_interval = 30.0 # Test responsiveness every 30 seconds
|
||||
self.stack_trace_count = 5 # Keep last 5 stack traces for analysis
|
||||
|
||||
# Register with health monitor
|
||||
health_monitor.register_health_checker(self._perform_health_checks)
|
||||
|
||||
# Enable periodic responsiveness testing
|
||||
self.test_thread = threading.Thread(target=self._responsiveness_test_loop, daemon=True)
|
||||
self.test_thread.start()
|
||||
|
||||
def register_thread(self, thread: threading.Thread, heartbeat_callback: Optional[Callable[[], bool]] = None):
|
||||
"""
|
||||
Register a thread for monitoring.
|
||||
|
||||
Args:
|
||||
thread: Thread to monitor
|
||||
heartbeat_callback: Optional callback to test thread responsiveness
|
||||
"""
|
||||
with self._lock:
|
||||
thread_info = ThreadInfo(
|
||||
thread_id=thread.ident,
|
||||
thread_name=thread.name,
|
||||
start_time=time.time(),
|
||||
last_heartbeat=time.time()
|
||||
)
|
||||
|
||||
self.monitored_threads[thread.ident] = thread_info
|
||||
|
||||
if heartbeat_callback:
|
||||
self.heartbeat_callbacks[thread.ident] = heartbeat_callback
|
||||
|
||||
logger.info(f"Registered thread for monitoring: {thread.name} (ID: {thread.ident})")
|
||||
|
||||
def unregister_thread(self, thread_id: int):
|
||||
"""Unregister a thread from monitoring."""
|
||||
with self._lock:
|
||||
if thread_id in self.monitored_threads:
|
||||
thread_name = self.monitored_threads[thread_id].thread_name
|
||||
del self.monitored_threads[thread_id]
|
||||
|
||||
if thread_id in self.heartbeat_callbacks:
|
||||
del self.heartbeat_callbacks[thread_id]
|
||||
|
||||
logger.info(f"Unregistered thread from monitoring: {thread_name} (ID: {thread_id})")
|
||||
|
||||
def heartbeat(self, thread_id: Optional[int] = None, activity: Optional[str] = None):
|
||||
"""
|
||||
Report thread heartbeat.
|
||||
|
||||
Args:
|
||||
thread_id: Thread ID (uses current thread if None)
|
||||
activity: Description of current activity
|
||||
"""
|
||||
if thread_id is None:
|
||||
thread_id = threading.current_thread().ident
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
with self._lock:
|
||||
if thread_id in self.monitored_threads:
|
||||
thread_info = self.monitored_threads[thread_id]
|
||||
thread_info.last_heartbeat = current_time
|
||||
thread_info.heartbeat_count += 1
|
||||
thread_info.is_responsive = True
|
||||
|
||||
if activity:
|
||||
thread_info.last_activity = activity
|
||||
|
||||
# Report to health monitor
|
||||
health_monitor.update_metrics(
|
||||
f"thread_{thread_info.thread_name}",
|
||||
thread_alive=True,
|
||||
last_frame_time=current_time
|
||||
)
|
||||
|
||||
def get_thread_info(self, thread_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""Get information about a monitored thread."""
|
||||
with self._lock:
|
||||
if thread_id not in self.monitored_threads:
|
||||
return None
|
||||
|
||||
thread_info = self.monitored_threads[thread_id]
|
||||
current_time = time.time()
|
||||
|
||||
return {
|
||||
'thread_id': thread_id,
|
||||
'thread_name': thread_info.thread_name,
|
||||
'uptime_seconds': current_time - thread_info.start_time,
|
||||
'last_heartbeat_age': current_time - thread_info.last_heartbeat,
|
||||
'heartbeat_count': thread_info.heartbeat_count,
|
||||
'is_responsive': thread_info.is_responsive,
|
||||
'last_activity': thread_info.last_activity,
|
||||
'stack_traces': thread_info.stack_traces or []
|
||||
}
|
||||
|
||||
def get_all_thread_info(self) -> Dict[int, Dict[str, Any]]:
|
||||
"""Get information about all monitored threads."""
|
||||
with self._lock:
|
||||
return {
|
||||
thread_id: self.get_thread_info(thread_id)
|
||||
for thread_id in self.monitored_threads.keys()
|
||||
}
|
||||
|
||||
def test_thread_responsiveness(self, thread_id: int) -> bool:
|
||||
"""
|
||||
Test if a thread is responsive by calling its heartbeat callback.
|
||||
|
||||
Args:
|
||||
thread_id: ID of thread to test
|
||||
|
||||
Returns:
|
||||
True if thread responds within timeout
|
||||
"""
|
||||
if thread_id not in self.heartbeat_callbacks:
|
||||
return True # Can't test if no callback provided
|
||||
|
||||
try:
|
||||
# Call the heartbeat callback with a timeout
|
||||
callback = self.heartbeat_callbacks[thread_id]
|
||||
|
||||
# This is a simple approach - in practice you might want to use
|
||||
# threading.Timer or asyncio for more sophisticated timeout handling
|
||||
start_time = time.time()
|
||||
result = callback()
|
||||
response_time = time.time() - start_time
|
||||
|
||||
with self._lock:
|
||||
if thread_id in self.monitored_threads:
|
||||
self.monitored_threads[thread_id].is_responsive = result
|
||||
|
||||
if response_time > 5.0: # Slow response
|
||||
logger.warning(f"Thread {thread_id} slow response: {response_time:.1f}s")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing thread {thread_id} responsiveness: {e}")
|
||||
with self._lock:
|
||||
if thread_id in self.monitored_threads:
|
||||
self.monitored_threads[thread_id].is_responsive = False
|
||||
return False
|
||||
|
||||
def capture_stack_trace(self, thread_id: int) -> Optional[str]:
|
||||
"""
|
||||
Capture stack trace for a thread.
|
||||
|
||||
Args:
|
||||
thread_id: ID of thread to capture
|
||||
|
||||
Returns:
|
||||
Stack trace string or None if not available
|
||||
"""
|
||||
try:
|
||||
# Get all frames for all threads
|
||||
frames = dict(threading._current_frames())
|
||||
|
||||
if thread_id not in frames:
|
||||
return None
|
||||
|
||||
# Format stack trace
|
||||
frame = frames[thread_id]
|
||||
stack_trace = ''.join(traceback.format_stack(frame))
|
||||
|
||||
# Store in thread info
|
||||
with self._lock:
|
||||
if thread_id in self.monitored_threads:
|
||||
thread_info = self.monitored_threads[thread_id]
|
||||
if thread_info.stack_traces is None:
|
||||
thread_info.stack_traces = []
|
||||
|
||||
thread_info.stack_traces.append(f"{time.time()}: {stack_trace}")
|
||||
|
||||
# Keep only last N stack traces
|
||||
if len(thread_info.stack_traces) > self.stack_trace_count:
|
||||
thread_info.stack_traces = thread_info.stack_traces[-self.stack_trace_count:]
|
||||
|
||||
return stack_trace
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error capturing stack trace for thread {thread_id}: {e}")
|
||||
return None
|
||||
|
||||
def detect_deadlocks(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Attempt to detect potential deadlocks by analyzing thread states.
|
||||
|
||||
Returns:
|
||||
List of potential deadlock scenarios
|
||||
"""
|
||||
deadlocks = []
|
||||
current_time = time.time()
|
||||
|
||||
with self._lock:
|
||||
# Look for threads that haven't had heartbeats for a long time
|
||||
# and are supposedly alive
|
||||
for thread_id, thread_info in self.monitored_threads.items():
|
||||
heartbeat_age = current_time - thread_info.last_heartbeat
|
||||
|
||||
if heartbeat_age > self.heartbeat_timeout * 2: # Double the timeout
|
||||
# Check if thread still exists
|
||||
thread_exists = any(
|
||||
t.ident == thread_id and t.is_alive()
|
||||
for t in threading.enumerate()
|
||||
)
|
||||
|
||||
if thread_exists:
|
||||
# Thread exists but not responding - potential deadlock
|
||||
stack_trace = self.capture_stack_trace(thread_id)
|
||||
|
||||
deadlock_info = {
|
||||
'thread_id': thread_id,
|
||||
'thread_name': thread_info.thread_name,
|
||||
'heartbeat_age': heartbeat_age,
|
||||
'last_activity': thread_info.last_activity,
|
||||
'stack_trace': stack_trace,
|
||||
'detection_time': current_time
|
||||
}
|
||||
|
||||
deadlocks.append(deadlock_info)
|
||||
logger.warning(f"Potential deadlock detected in thread {thread_info.thread_name}")
|
||||
|
||||
return deadlocks
|
||||
|
||||
def _responsiveness_test_loop(self):
|
||||
"""Background loop to test thread responsiveness."""
|
||||
logger.info("Thread responsiveness testing started")
|
||||
|
||||
while True:
|
||||
try:
|
||||
time.sleep(self.responsiveness_test_interval)
|
||||
|
||||
with self._lock:
|
||||
thread_ids = list(self.monitored_threads.keys())
|
||||
|
||||
for thread_id in thread_ids:
|
||||
try:
|
||||
self.test_thread_responsiveness(thread_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing thread {thread_id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in responsiveness test loop: {e}")
|
||||
time.sleep(10.0) # Fallback sleep
|
||||
|
||||
def _perform_health_checks(self) -> List[HealthCheck]:
|
||||
"""Perform health checks for all monitored threads."""
|
||||
checks = []
|
||||
current_time = time.time()
|
||||
|
||||
with self._lock:
|
||||
for thread_id, thread_info in self.monitored_threads.items():
|
||||
checks.extend(self._check_thread_health(thread_id, thread_info, current_time))
|
||||
|
||||
# Check for deadlocks
|
||||
deadlocks = self.detect_deadlocks()
|
||||
for deadlock in deadlocks:
|
||||
checks.append(HealthCheck(
|
||||
name=f"deadlock_detection_{deadlock['thread_id']}",
|
||||
status=HealthStatus.CRITICAL,
|
||||
message=f"Potential deadlock in thread {deadlock['thread_name']} "
|
||||
f"(unresponsive for {deadlock['heartbeat_age']:.1f}s)",
|
||||
details=deadlock,
|
||||
recovery_action="restart_thread"
|
||||
))
|
||||
|
||||
return checks
|
||||
|
||||
def _check_thread_health(self, thread_id: int, thread_info: ThreadInfo, current_time: float) -> List[HealthCheck]:
|
||||
"""Perform health checks for a single thread."""
|
||||
checks = []
|
||||
|
||||
# Check if thread still exists
|
||||
thread_exists = any(
|
||||
t.ident == thread_id and t.is_alive()
|
||||
for t in threading.enumerate()
|
||||
)
|
||||
|
||||
if not thread_exists:
|
||||
checks.append(HealthCheck(
|
||||
name=f"thread_{thread_info.thread_name}_alive",
|
||||
status=HealthStatus.CRITICAL,
|
||||
message=f"Thread {thread_info.thread_name} is no longer alive",
|
||||
details={
|
||||
'thread_id': thread_id,
|
||||
'uptime': current_time - thread_info.start_time,
|
||||
'last_heartbeat': thread_info.last_heartbeat
|
||||
},
|
||||
recovery_action="restart_thread"
|
||||
))
|
||||
return checks
|
||||
|
||||
# Check heartbeat freshness
|
||||
heartbeat_age = current_time - thread_info.last_heartbeat
|
||||
|
||||
if heartbeat_age > self.heartbeat_timeout:
|
||||
checks.append(HealthCheck(
|
||||
name=f"thread_{thread_info.thread_name}_responsive",
|
||||
status=HealthStatus.CRITICAL,
|
||||
message=f"Thread {thread_info.thread_name} unresponsive for {heartbeat_age:.1f}s",
|
||||
details={
|
||||
'thread_id': thread_id,
|
||||
'heartbeat_age': heartbeat_age,
|
||||
'heartbeat_count': thread_info.heartbeat_count,
|
||||
'last_activity': thread_info.last_activity,
|
||||
'is_responsive': thread_info.is_responsive
|
||||
},
|
||||
recovery_action="restart_thread"
|
||||
))
|
||||
elif heartbeat_age > self.heartbeat_timeout * 0.5: # Warning at 50% of timeout
|
||||
checks.append(HealthCheck(
|
||||
name=f"thread_{thread_info.thread_name}_responsive",
|
||||
status=HealthStatus.WARNING,
|
||||
message=f"Thread {thread_info.thread_name} slow heartbeat: {heartbeat_age:.1f}s",
|
||||
details={
|
||||
'thread_id': thread_id,
|
||||
'heartbeat_age': heartbeat_age,
|
||||
'heartbeat_count': thread_info.heartbeat_count,
|
||||
'last_activity': thread_info.last_activity,
|
||||
'is_responsive': thread_info.is_responsive
|
||||
}
|
||||
))
|
||||
|
||||
# Check responsiveness test results
|
||||
if not thread_info.is_responsive:
|
||||
checks.append(HealthCheck(
|
||||
name=f"thread_{thread_info.thread_name}_callback",
|
||||
status=HealthStatus.WARNING,
|
||||
message=f"Thread {thread_info.thread_name} failed responsiveness test",
|
||||
details={
|
||||
'thread_id': thread_id,
|
||||
'last_activity': thread_info.last_activity
|
||||
}
|
||||
))
|
||||
|
||||
return checks
|
||||
|
||||
|
||||
# Global thread health monitor instance
|
||||
thread_health_monitor = ThreadHealthMonitor()
|
||||
|
|
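# Illustrative usage sketch (not from the original module). The worker function
# and thread name are invented. Note that register_thread() reads thread.ident,
# which is only populated after start().

def _worker():
    while True:
        thread_health_monitor.heartbeat(activity="processing frame")
        time.sleep(1.0)

t = threading.Thread(target=_worker, name="frame_worker", daemon=True)
t.start()
thread_health_monitor.register_thread(t)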
@@ -1,13 +0,0 @@
"""
|
||||
Storage module for the Python Detector Worker.
|
||||
|
||||
This module provides Redis operations for data persistence
|
||||
and caching in the detection pipeline.
|
||||
|
||||
Note: PostgreSQL operations have been disabled as database functionality
|
||||
has been moved to microservices architecture.
|
||||
"""
|
||||
from .redis import RedisManager
|
||||
# from .database import DatabaseManager # Disabled - moved to microservices
|
||||
|
||||
__all__ = ['RedisManager'] # Removed 'DatabaseManager'
|
||||
|
|
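# Illustrative import sketch (not from the original module). The package path
# core.storage is an assumption, and the config values are placeholders that
# match the RedisManager defaults shown later in this diff.
#
#     from core.storage import RedisManager
#     redis_manager = RedisManager({'host': 'localhost', 'port': 6379, 'db': 0})
#     ok = await redis_manager.initialize()  # inside an asyncio event loop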
@@ -1,369 +0,0 @@
"""
|
||||
Database Operations Module - DISABLED
|
||||
|
||||
NOTE: This module has been disabled as PostgreSQL database operations have been
|
||||
moved to microservices architecture. All database connections, reads, and writes
|
||||
are now handled by separate backend services.
|
||||
|
||||
The detection worker now communicates results via:
|
||||
- WebSocket imageDetection messages (primary data flow to backend)
|
||||
- Redis image storage and pub/sub (temporary storage)
|
||||
|
||||
Original functionality: PostgreSQL operations for the detection pipeline.
|
||||
Status: Commented out - DO NOT ENABLE without updating architecture
|
||||
"""
|
||||
|
||||
# All PostgreSQL functionality below has been commented out
|
||||
# import psycopg2
|
||||
# import psycopg2.extras
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
# import uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# DatabaseManager class is disabled - all methods commented out
|
||||
# class DatabaseManager:
|
||||
# """
|
||||
# Manages PostgreSQL connections and operations for the detection pipeline.
|
||||
# Handles database operations and schema management.
|
||||
# """
|
||||
#
|
||||
# def __init__(self, config: Dict[str, Any]):
|
||||
# """
|
||||
# Initialize database manager with configuration.
|
||||
#
|
||||
# Args:
|
||||
# config: Database configuration dictionary
|
||||
# """
|
||||
# self.config = config
|
||||
# self.connection: Optional[psycopg2.extensions.connection] = None
|
||||
#
|
||||
# def connect(self) -> bool:
|
||||
# """
|
||||
# Connect to PostgreSQL database.
|
||||
#
|
||||
# Returns:
|
||||
# True if successful, False otherwise
|
||||
# """
|
||||
# try:
|
||||
# self.connection = psycopg2.connect(
|
||||
# host=self.config['host'],
|
||||
# port=self.config['port'],
|
||||
# database=self.config['database'],
|
||||
# user=self.config['username'],
|
||||
# password=self.config['password']
|
||||
# )
|
||||
# logger.info("PostgreSQL connection established successfully")
|
||||
# return True
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to connect to PostgreSQL: {e}")
|
||||
# return False
|
||||
#
|
||||
# def disconnect(self):
|
||||
# """Disconnect from PostgreSQL database."""
|
||||
# if self.connection:
|
||||
# self.connection.close()
|
||||
# self.connection = None
|
||||
# logger.info("PostgreSQL connection closed")
|
||||
#
|
||||
# def is_connected(self) -> bool:
|
||||
# """
|
||||
# Check if database connection is active.
|
||||
#
|
||||
# Returns:
|
||||
# True if connected, False otherwise
|
||||
# """
|
||||
# try:
|
||||
# if self.connection and not self.connection.closed:
|
||||
# cur = self.connection.cursor()
|
||||
# cur.execute("SELECT 1")
|
||||
# cur.fetchone()
|
||||
# cur.close()
|
||||
# return True
|
||||
# except:
|
||||
# pass
|
||||
# return False
|
||||
#
|
||||
# def update_car_info(self, session_id: str, brand: str, model: str, body_type: str) -> bool:
|
||||
# """
|
||||
# Update car information in the database.
|
||||
#
|
||||
# Args:
|
||||
# session_id: Session identifier
|
||||
# brand: Car brand
|
||||
# model: Car model
|
||||
# body_type: Car body type
|
||||
#
|
||||
# Returns:
|
||||
# True if successful, False otherwise
|
||||
# """
|
||||
# if not self.is_connected():
|
||||
# if not self.connect():
|
||||
# return False
|
||||
#
|
||||
# try:
|
||||
# cur = self.connection.cursor()
|
||||
# query = """
|
||||
# INSERT INTO car_frontal_info (session_id, car_brand, car_model, car_body_type, updated_at)
|
||||
# VALUES (%s, %s, %s, %s, NOW())
|
||||
# ON CONFLICT (session_id)
|
||||
# DO UPDATE SET
|
||||
# car_brand = EXCLUDED.car_brand,
|
||||
# car_model = EXCLUDED.car_model,
|
||||
# car_body_type = EXCLUDED.car_body_type,
|
||||
# updated_at = NOW()
|
||||
# """
|
||||
# cur.execute(query, (session_id, brand, model, body_type))
|
||||
# self.connection.commit()
|
||||
# cur.close()
|
||||
# logger.info(f"Updated car info for session {session_id}: {brand} {model} ({body_type})")
|
||||
# return True
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to update car info: {e}")
|
||||
# if self.connection:
|
||||
# self.connection.rollback()
|
||||
# return False
|
||||
#
|
||||
# def execute_update(self, table: str, key_field: str, key_value: str, fields: Dict[str, str]) -> bool:
|
||||
# """
|
||||
# Execute a dynamic update query on the database.
|
||||
#
|
||||
# Args:
|
||||
# table: Table name
|
||||
# key_field: Primary key field name
|
||||
# key_value: Primary key value
|
||||
# fields: Dictionary of fields to update
|
||||
#
|
||||
# Returns:
|
||||
# True if successful, False otherwise
|
||||
# """
|
||||
# if not self.is_connected():
|
||||
# if not self.connect():
|
||||
# return False
|
||||
#
|
||||
# try:
|
||||
# cur = self.connection.cursor()
|
||||
#
|
||||
# # Build the UPDATE query dynamically
|
||||
# set_clauses = []
|
||||
# values = []
|
||||
#
|
||||
# for field, value in fields.items():
|
||||
# if value == "NOW()":
|
||||
# set_clauses.append(f"{field} = NOW()")
|
||||
# else:
|
||||
# set_clauses.append(f"{field} = %s")
|
||||
# values.append(value)
|
||||
#
|
||||
# # Add schema prefix if table doesn't already have it
|
||||
# full_table_name = table if '.' in table else f"gas_station_1.{table}"
|
||||
#
|
||||
# query = f"""
|
||||
# INSERT INTO {full_table_name} ({key_field}, {', '.join(fields.keys())})
|
||||
# VALUES (%s, {', '.join(['%s'] * len(fields))})
|
||||
# ON CONFLICT ({key_field})
|
||||
# DO UPDATE SET {', '.join(set_clauses)}
|
||||
# """
|
||||
#
|
||||
# # Add key_value to the beginning of values list
|
||||
# all_values = [key_value] + list(fields.values()) + values
|
||||
#
|
||||
# cur.execute(query, all_values)
|
||||
# self.connection.commit()
|
||||
# cur.close()
|
||||
# logger.info(f"Updated {table} for {key_field}={key_value}")
|
||||
# return True
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to execute update on {table}: {e}")
|
||||
# if self.connection:
|
||||
# self.connection.rollback()
|
||||
# return False
|
||||
#
|
||||
# def create_car_frontal_info_table(self) -> bool:
|
||||
# """
|
||||
# Create the car_frontal_info table in gas_station_1 schema if it doesn't exist.
|
||||
#
|
||||
# Returns:
|
||||
# True if successful, False otherwise
|
||||
# """
|
||||
# if not self.is_connected():
|
||||
# if not self.connect():
|
||||
# return False
|
||||
#
|
||||
# try:
|
||||
# # Since the database already exists, just verify connection
|
||||
# cur = self.connection.cursor()
|
||||
#
|
||||
# # Simple verification that the table exists
|
||||
# cur.execute("""
|
||||
# SELECT EXISTS (
|
||||
# SELECT FROM information_schema.tables
|
||||
# WHERE table_schema = 'gas_station_1'
|
||||
# AND table_name = 'car_frontal_info'
|
||||
# )
|
||||
# """)
|
||||
#
|
||||
# table_exists = cur.fetchone()[0]
|
||||
# cur.close()
|
||||
#
|
||||
# if table_exists:
|
||||
# logger.info("Verified car_frontal_info table exists")
|
||||
# return True
|
||||
# else:
|
||||
# logger.error("car_frontal_info table does not exist in the database")
|
||||
# return False
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to create car_frontal_info table: {e}")
|
||||
# if self.connection:
|
||||
# self.connection.rollback()
|
||||
# return False
|
||||
#
|
||||
# def insert_initial_detection(self, display_id: str, captured_timestamp: str, session_id: str = None) -> str:
|
||||
# """
|
||||
# Insert initial detection record and return the session_id.
|
||||
#
|
||||
# Args:
|
||||
# display_id: Display identifier
|
||||
# captured_timestamp: Timestamp of the detection
|
||||
# session_id: Optional session ID, generates one if not provided
|
||||
#
|
||||
# Returns:
|
||||
# Session ID string or None on error
|
||||
# """
|
||||
# if not self.is_connected():
|
||||
# if not self.connect():
|
||||
# return None
|
||||
#
|
||||
# # Generate session_id if not provided
|
||||
# if not session_id:
|
||||
# session_id = str(uuid.uuid4())
|
||||
#
|
||||
# try:
|
||||
# # Ensure table exists
|
||||
# if not self.create_car_frontal_info_table():
|
||||
# logger.error("Failed to create/verify table before insertion")
|
||||
# return None
|
||||
#
|
||||
# cur = self.connection.cursor()
|
||||
# insert_query = """
|
||||
# INSERT INTO gas_station_1.car_frontal_info
|
||||
# (display_id, captured_timestamp, session_id, license_character, license_type, car_brand, car_model, car_body_type)
|
||||
# VALUES (%s, %s, %s, NULL, 'No model available', NULL, NULL, NULL)
|
||||
# ON CONFLICT (session_id) DO NOTHING
|
||||
# """
|
||||
#
|
||||
# cur.execute(insert_query, (display_id, captured_timestamp, session_id))
|
||||
# self.connection.commit()
|
||||
# cur.close()
|
||||
# logger.info(f"Inserted initial detection record with session_id: {session_id}")
|
||||
# return session_id
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to insert initial detection record: {e}")
|
||||
# if self.connection:
|
||||
# self.connection.rollback()
|
||||
# return None
|
||||
#
|
||||
# def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
# """
|
||||
# Get session information from the database.
|
||||
#
|
||||
# Args:
|
||||
# session_id: Session identifier
|
||||
#
|
||||
# Returns:
|
||||
# Dictionary with session data or None if not found
|
||||
# """
|
||||
# if not self.is_connected():
|
||||
# if not self.connect():
|
||||
# return None
|
||||
#
|
||||
# try:
|
||||
# cur = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
# query = "SELECT * FROM gas_station_1.car_frontal_info WHERE session_id = %s"
|
||||
# cur.execute(query, (session_id,))
|
||||
# result = cur.fetchone()
|
||||
# cur.close()
|
||||
#
|
||||
# if result:
|
||||
# return dict(result)
|
||||
# else:
|
||||
# logger.debug(f"No session info found for session_id: {session_id}")
|
||||
# return None
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to get session info: {e}")
|
||||
# return None
|
||||
#
|
||||
# def delete_session(self, session_id: str) -> bool:
|
||||
# """
|
||||
# Delete session record from the database.
|
||||
#
|
||||
# Args:
|
||||
# session_id: Session identifier
|
||||
#
|
||||
# Returns:
|
||||
# True if successful, False otherwise
|
||||
# """
|
||||
# if not self.is_connected():
|
||||
# if not self.connect():
|
||||
# return False
|
||||
#
|
||||
# try:
|
||||
# cur = self.connection.cursor()
|
||||
# query = "DELETE FROM gas_station_1.car_frontal_info WHERE session_id = %s"
|
||||
# cur.execute(query, (session_id,))
|
||||
# rows_affected = cur.rowcount
|
||||
# self.connection.commit()
|
||||
# cur.close()
|
||||
#
|
||||
# if rows_affected > 0:
|
||||
# logger.info(f"Deleted session record: {session_id}")
|
||||
# return True
|
||||
# else:
|
||||
# logger.warning(f"No session record found to delete: {session_id}")
|
||||
# return False
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to delete session: {e}")
|
||||
# if self.connection:
|
||||
# self.connection.rollback()
|
||||
# return False
|
||||
#
|
||||
# def get_statistics(self) -> Dict[str, Any]:
|
||||
# """
|
||||
# Get database statistics.
|
||||
#
|
||||
# Returns:
|
||||
# Dictionary with database statistics
|
||||
# """
|
||||
# stats = {
|
||||
# 'connected': self.is_connected(),
|
||||
# 'host': self.config.get('host', 'unknown'),
|
||||
# 'port': self.config.get('port', 'unknown'),
|
||||
# 'database': self.config.get('database', 'unknown')
|
||||
# }
|
||||
#
|
||||
# if self.is_connected():
|
||||
# try:
|
||||
# cur = self.connection.cursor()
|
||||
#
|
||||
# # Get table record count
|
||||
# cur.execute("SELECT COUNT(*) FROM gas_station_1.car_frontal_info")
|
||||
# stats['total_records'] = cur.fetchone()[0]
|
||||
#
|
||||
# # Get recent records count (last hour)
|
||||
# cur.execute("""
|
||||
# SELECT COUNT(*) FROM gas_station_1.car_frontal_info
|
||||
# WHERE created_at > NOW() - INTERVAL '1 hour'
|
||||
# """)
|
||||
# stats['recent_records'] = cur.fetchone()[0]
|
||||
#
|
||||
# cur.close()
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to get database statistics: {e}")
|
||||
# stats['error'] = str(e)
|
||||
#
|
||||
# return stats
|
||||
|
|
@@ -1,300 +0,0 @@
"""
|
||||
License Plate Manager Module.
|
||||
Handles Redis subscription to license plate results from LPR service.
|
||||
"""
|
||||
import logging
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Optional, Any, Callable
|
||||
import redis.asyncio as redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LicensePlateManager:
|
||||
"""
|
||||
Manages license plate result subscription from Redis channel.
|
||||
Subscribes to 'license_results' channel for license plate data from LPR service.
|
||||
"""
|
||||
|
||||
def __init__(self, redis_config: Dict[str, Any]):
|
||||
"""
|
||||
Initialize license plate manager with Redis configuration.
|
||||
|
||||
Args:
|
||||
redis_config: Redis configuration dictionary
|
||||
"""
|
||||
self.config = redis_config
|
||||
self.redis_client: Optional[redis.Redis] = None
|
||||
self.pubsub = None
|
||||
self.subscription_task = None
|
||||
self.callback = None
|
||||
|
||||
# Connection parameters
|
||||
self.host = redis_config.get('host', 'localhost')
|
||||
self.port = redis_config.get('port', 6379)
|
||||
self.password = redis_config.get('password')
|
||||
self.db = redis_config.get('db', 0)
|
||||
|
||||
# License plate data cache - store recent results by session_id
|
||||
self.license_plate_cache: Dict[str, Dict[str, Any]] = {}
|
||||
self.cache_ttl = 300 # 5 minutes TTL for cached results
|
||||
|
||||
logger.info(f"LicensePlateManager initialized for {self.host}:{self.port}")
|
||||
|
||||
async def initialize(self, callback: Optional[Callable] = None) -> bool:
|
||||
"""
|
||||
Initialize Redis connection and start subscription to license_results channel.
|
||||
|
||||
Args:
|
||||
callback: Optional callback function for processing license plate results
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Create Redis connection
|
||||
self.redis_client = redis.Redis(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
password=self.password,
|
||||
db=self.db,
|
||||
decode_responses=True
|
||||
)
|
||||
|
||||
# Test connection
|
||||
await self.redis_client.ping()
|
||||
logger.info(f"Connected to Redis for license plate subscription")
|
||||
|
||||
# Set callback
|
||||
self.callback = callback
|
||||
|
||||
# Start subscription
|
||||
await self._start_subscription()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize license plate manager: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def _start_subscription(self):
|
||||
"""Start Redis subscription to license_results channel."""
|
||||
try:
|
||||
if not self.redis_client:
|
||||
logger.error("Redis client not initialized")
|
||||
return
|
||||
|
||||
# Create pubsub and subscribe
|
||||
self.pubsub = self.redis_client.pubsub()
|
||||
await self.pubsub.subscribe('license_results')
|
||||
|
||||
logger.info("Subscribed to Redis channel: license_results")
|
||||
|
||||
# Start listening task
|
||||
self.subscription_task = asyncio.create_task(self._listen_for_messages())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting license plate subscription: {e}", exc_info=True)
|
||||
|
||||
async def _listen_for_messages(self):
|
||||
"""Listen for messages on the license_results channel."""
|
||||
listen_generator = None
|
||||
try:
|
||||
if not self.pubsub:
|
||||
return
|
||||
|
||||
listen_generator = self.pubsub.listen()
|
||||
async for message in listen_generator:
|
||||
if message['type'] == 'message':
|
||||
try:
|
||||
# Log the raw message from Redis channel
|
||||
logger.info(f"[LICENSE PLATE RAW] Received from 'license_results' channel: {message['data']}")
|
||||
|
||||
# Parse the license plate result message
|
||||
data = json.loads(message['data'])
|
||||
logger.info(f"[LICENSE PLATE PARSED] Parsed JSON data: {data}")
|
||||
await self._process_license_plate_result(data)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"[LICENSE PLATE ERROR] Invalid JSON in license plate message: {e}")
|
||||
logger.error(f"[LICENSE PLATE ERROR] Raw message was: {message['data']}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing license plate message: {e}", exc_info=True)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("License plate subscription task cancelled")
|
||||
# Don't try to close generator here - let it be handled by the context
|
||||
# The async generator will be properly closed by the cancellation mechanism
|
||||
raise # Re-raise to maintain proper cancellation semantics
|
||||
except Exception as e:
|
||||
logger.error(f"Error in license plate message listener: {e}", exc_info=True)
|
||||
# Only attempt cleanup if it's not a cancellation
|
||||
finally:
|
||||
# Safe cleanup of async generator
|
||||
if listen_generator is not None:
|
||||
try:
|
||||
# Check if we can safely close without conflicting with ongoing operations
|
||||
if hasattr(listen_generator, 'aclose') and not asyncio.current_task().cancelled():
|
||||
await listen_generator.aclose()
|
||||
except (RuntimeError, AttributeError) as e:
|
||||
# Generator is already closing or in invalid state - safe to ignore
|
||||
logger.debug(f"Generator cleanup skipped (safe): {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Generator cleanup error (non-critical): {e}")
|
||||
|
||||
async def _process_license_plate_result(self, data: Dict[str, Any]):
|
||||
"""
|
||||
Process incoming license plate result from LPR service.
|
||||
|
||||
Expected message format (from actual LPR service):
|
||||
{
|
||||
"session_id": "511",
|
||||
"license_character": "ข3184"
|
||||
}
|
||||
or
|
||||
{
|
||||
"session_id": "508",
|
||||
"display_id": "test3",
|
||||
"license_plate_text": "ABC-123",
|
||||
"confidence": 0.95,
|
||||
"timestamp": "2025-09-24T21:10:00Z"
|
||||
}
|
||||
|
||||
Args:
|
||||
data: License plate result data
|
||||
"""
|
||||
try:
|
||||
session_id = data.get('session_id')
|
||||
if not session_id:
|
||||
logger.warning("License plate result missing session_id")
|
||||
return
|
||||
|
||||
# Handle different message formats
|
||||
# Format 1: {"session_id": "511", "license_character": "ข3184"}
|
||||
# Format 2: {"session_id": "508", "license_plate_text": "ABC-123", "confidence": 0.95, ...}
|
||||
license_plate_text = data.get('license_plate_text') or data.get('license_character')
|
||||
confidence = data.get('confidence', 1.0) # Default confidence for LPR service results
|
||||
display_id = data.get('display_id', '')
|
||||
timestamp = data.get('timestamp', '')
|
||||
|
||||
logger.info(f"[LICENSE PLATE] Received result for session {session_id}: "
|
||||
f"text='{license_plate_text}', confidence={confidence:.3f}")
|
||||
|
||||
# Store in cache
|
||||
self.license_plate_cache[session_id] = {
|
||||
'license_plate_text': license_plate_text,
|
||||
'confidence': confidence,
|
||||
'display_id': display_id,
|
||||
'timestamp': timestamp,
|
||||
'received_at': asyncio.get_event_loop().time()
|
||||
}
|
||||
|
||||
# Call callback if provided
|
||||
if self.callback:
|
||||
await self.callback(session_id, {
|
||||
'license_plate_text': license_plate_text,
|
||||
'confidence': confidence,
|
||||
'display_id': display_id,
|
||||
'timestamp': timestamp
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing license plate result: {e}", exc_info=True)
|
||||
|
||||
def get_license_plate_result(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get cached license plate result for a session.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
License plate result dictionary or None if not found
|
||||
"""
|
||||
if session_id not in self.license_plate_cache:
|
||||
return None
|
||||
|
||||
result = self.license_plate_cache[session_id]
|
||||
|
||||
# Check TTL
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
if current_time - result.get('received_at', 0) > self.cache_ttl:
|
||||
# Expired, remove from cache
|
||||
del self.license_plate_cache[session_id]
|
||||
return None
|
||||
|
||||
return {
|
||||
'license_plate_text': result.get('license_plate_text'),
|
||||
'confidence': result.get('confidence'),
|
||||
'display_id': result.get('display_id'),
|
||||
'timestamp': result.get('timestamp')
|
||||
}
|
||||
|
||||
def cleanup_expired_results(self):
|
||||
"""Remove expired license plate results from cache."""
|
||||
try:
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
expired_sessions = []
|
||||
|
||||
for session_id, result in self.license_plate_cache.items():
|
||||
if current_time - result.get('received_at', 0) > self.cache_ttl:
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
for session_id in expired_sessions:
|
||||
del self.license_plate_cache[session_id]
|
||||
logger.debug(f"Removed expired license plate result for session {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up expired license plate results: {e}", exc_info=True)
|
||||
|
||||
async def close(self):
|
||||
"""Close Redis connection and cleanup resources."""
|
||||
try:
|
||||
# Cancel subscription task first
|
||||
if self.subscription_task and not self.subscription_task.done():
|
||||
self.subscription_task.cancel()
|
||||
try:
|
||||
await self.subscription_task
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("License plate subscription task cancelled successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error waiting for subscription task cancellation: {e}")
|
||||
|
||||
# Close pubsub connection properly
|
||||
if self.pubsub:
|
||||
try:
|
||||
# First unsubscribe from channels
|
||||
await self.pubsub.unsubscribe('license_results')
|
||||
# Then close the pubsub connection
|
||||
await self.pubsub.aclose()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing pubsub connection: {e}")
|
||||
finally:
|
||||
self.pubsub = None
|
||||
|
||||
# Close Redis connection
|
||||
if self.redis_client:
|
||||
try:
|
||||
await self.redis_client.aclose()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing Redis connection: {e}")
|
||||
finally:
|
||||
self.redis_client = None
|
||||
|
||||
# Clear cache
|
||||
self.license_plate_cache.clear()
|
||||
|
||||
logger.info("License plate manager closed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing license plate manager: {e}", exc_info=True)
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get license plate manager statistics."""
|
||||
return {
|
||||
'cached_results': len(self.license_plate_cache),
|
||||
'connected': self.redis_client is not None,
|
||||
'subscribed': self.pubsub is not None,
|
||||
'host': self.host,
|
||||
'port': self.port
|
||||
}
|
||||
|
|
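A minimal lookup sketch for the cache semantics above (hypothetical wiring: `manager` is assumed to be an already-initialized instance of this license plate manager inside a running event loop):

```python
# Hypothetical usage sketch -- `manager` is an initialized instance of the
# license plate manager above; the session id is illustrative.
cached = manager.get_license_plate_result("511")
if cached is None:
    # Either no result was received for this session, or the entry aged
    # past cache_ttl and was evicted on read.
    print("no fresh result for session 511")
else:
    print(cached['license_plate_text'], cached['confidence'])
```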
@@ -1,478 +0,0 @@
"""
Redis Operations Module.
Handles Redis connections, image storage, and pub/sub messaging.
"""
import logging
import json
import time
from typing import Optional, Dict, Any, Union
import asyncio
import cv2
import numpy as np
import redis.asyncio as redis
from redis.exceptions import ConnectionError, TimeoutError

logger = logging.getLogger(__name__)


class RedisManager:
    """
    Manages Redis connections and operations for the detection pipeline.
    Handles image storage with region cropping and pub/sub messaging.
    """

    def __init__(self, redis_config: Dict[str, Any]):
        """
        Initialize Redis manager with configuration.

        Args:
            redis_config: Redis configuration dictionary
        """
        self.config = redis_config
        self.redis_client: Optional[redis.Redis] = None

        # Connection parameters
        self.host = redis_config.get('host', 'localhost')
        self.port = redis_config.get('port', 6379)
        self.password = redis_config.get('password')
        self.db = redis_config.get('db', 0)
        self.decode_responses = redis_config.get('decode_responses', True)

        # Connection pool settings
        self.max_connections = redis_config.get('max_connections', 10)
        self.socket_timeout = redis_config.get('socket_timeout', 5)
        self.socket_connect_timeout = redis_config.get('socket_connect_timeout', 5)
        self.health_check_interval = redis_config.get('health_check_interval', 30)

        # Statistics
        self.stats = {
            'images_stored': 0,
            'messages_published': 0,
            'connection_errors': 0,
            'operations_successful': 0,
            'operations_failed': 0
        }

        logger.info(f"RedisManager initialized for {self.host}:{self.port}")

    async def initialize(self) -> bool:
        """
        Initialize Redis connection and test connectivity.

        Returns:
            True if successful, False otherwise
        """
        try:
            # Validate configuration
            if not self._validate_config():
                return False

            # Create Redis connection
            self.redis_client = redis.Redis(
                host=self.host,
                port=self.port,
                password=self.password,
                db=self.db,
                decode_responses=self.decode_responses,
                max_connections=self.max_connections,
                socket_timeout=self.socket_timeout,
                socket_connect_timeout=self.socket_connect_timeout,
                health_check_interval=self.health_check_interval
            )

            # Test connection
            await self.redis_client.ping()
            logger.info(f"Successfully connected to Redis at {self.host}:{self.port}")
            return True

        except ConnectionError as e:
            logger.error(f"Failed to connect to Redis: {e}")
            self.stats['connection_errors'] += 1
            return False
        except Exception as e:
            logger.error(f"Error initializing Redis connection: {e}", exc_info=True)
            self.stats['connection_errors'] += 1
            return False

    def _validate_config(self) -> bool:
        """
        Validate Redis configuration parameters.

        Returns:
            True if valid, False otherwise
        """
        required_fields = ['host', 'port']
        for field in required_fields:
            if field not in self.config:
                logger.error(f"Missing required Redis config field: {field}")
                return False

        if not isinstance(self.port, int) or self.port <= 0:
            logger.error(f"Invalid Redis port: {self.port}")
            return False

        return True

    async def is_connected(self) -> bool:
        """
        Check if Redis connection is active.

        Returns:
            True if connected, False otherwise
        """
        try:
            if self.redis_client:
                await self.redis_client.ping()
                return True
        except Exception:
            pass
        return False

    async def save_image(self,
                         key: str,
                         image: np.ndarray,
                         expire_seconds: Optional[int] = None,
                         image_format: str = 'jpeg',
                         quality: int = 90) -> bool:
        """
        Save image to Redis with optional expiration.

        Args:
            key: Redis key for the image
            image: Image array to save
            expire_seconds: Optional expiration time in seconds
            image_format: Image format ('jpeg' or 'png')
            quality: JPEG quality (1-100)

        Returns:
            True if successful, False otherwise
        """
        try:
            if not self.redis_client:
                logger.error("Redis client not initialized")
                self.stats['operations_failed'] += 1
                return False

            # Encode image
            encoded_image = self._encode_image(image, image_format, quality)
            if encoded_image is None:
                logger.error("Failed to encode image")
                self.stats['operations_failed'] += 1
                return False

            # Save to Redis
            if expire_seconds:
                await self.redis_client.setex(key, expire_seconds, encoded_image)
                logger.debug(f"Saved image to Redis with key: {key} (expires in {expire_seconds}s)")
            else:
                await self.redis_client.set(key, encoded_image)
                logger.debug(f"Saved image to Redis with key: {key}")

            self.stats['images_stored'] += 1
            self.stats['operations_successful'] += 1
            return True

        except Exception as e:
            logger.error(f"Error saving image to Redis: {e}", exc_info=True)
            self.stats['operations_failed'] += 1
            return False

    async def get_image(self, key: str) -> Optional[np.ndarray]:
        """
        Retrieve image from Redis.

        Args:
            key: Redis key for the image

        Returns:
            Image array or None if not found
        """
        try:
            if not self.redis_client:
                logger.error("Redis client not initialized")
                self.stats['operations_failed'] += 1
                return None

            # Get image data from Redis
            image_data = await self.redis_client.get(key)
            if image_data is None:
                logger.debug(f"Image not found for key: {key}")
                return None

            # Decode image
            image_array = np.frombuffer(image_data, np.uint8)
            image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)

            if image is not None:
                logger.debug(f"Retrieved image from Redis with key: {key}")
                self.stats['operations_successful'] += 1
                return image
            else:
                logger.error(f"Failed to decode image for key: {key}")
                self.stats['operations_failed'] += 1
                return None

        except Exception as e:
            logger.error(f"Error retrieving image from Redis: {e}", exc_info=True)
            self.stats['operations_failed'] += 1
            return None

    async def delete_image(self, key: str) -> bool:
        """
        Delete image from Redis.

        Args:
            key: Redis key for the image

        Returns:
            True if successful, False otherwise
        """
        try:
            if not self.redis_client:
                logger.error("Redis client not initialized")
                self.stats['operations_failed'] += 1
                return False

            result = await self.redis_client.delete(key)
            if result > 0:
                logger.debug(f"Deleted image from Redis with key: {key}")
                self.stats['operations_successful'] += 1
                return True
            else:
                logger.debug(f"Image not found for deletion: {key}")
                return False

        except Exception as e:
            logger.error(f"Error deleting image from Redis: {e}", exc_info=True)
            self.stats['operations_failed'] += 1
            return False

    async def publish_message(self, channel: str, message: Union[str, Dict]) -> int:
        """
        Publish message to Redis channel.

        Args:
            channel: Redis channel name
            message: Message to publish (string or dict)

        Returns:
            Number of subscribers that received the message, -1 on error
        """
        try:
            if not self.redis_client:
                logger.error("Redis client not initialized")
                self.stats['operations_failed'] += 1
                return -1

            # Convert dict to JSON string if needed
            if isinstance(message, dict):
                message_str = json.dumps(message)
            else:
                message_str = str(message)

            # Test connection before publishing
            await self.redis_client.ping()

            # Publish message
            result = await self.redis_client.publish(channel, message_str)

            logger.info(f"Published message to Redis channel '{channel}': {message_str}")
            logger.info(f"Redis publish result (subscribers count): {result}")

            if result == 0:
                logger.warning(f"No subscribers listening to channel '{channel}'")
            else:
                logger.info(f"Message delivered to {result} subscriber(s)")

            self.stats['messages_published'] += 1
            self.stats['operations_successful'] += 1
            return result

        except Exception as e:
            logger.error(f"Error publishing message to Redis: {e}", exc_info=True)
            self.stats['operations_failed'] += 1
            return -1

    async def subscribe_to_channel(self, channel: str, callback=None):
        """
        Subscribe to Redis channel (for future use).

        Args:
            channel: Redis channel name
            callback: Optional callback function for messages
        """
        try:
            if not self.redis_client:
                logger.error("Redis client not initialized")
                return

            pubsub = self.redis_client.pubsub()
            await pubsub.subscribe(channel)

            logger.info(f"Subscribed to Redis channel: {channel}")

            if callback:
                async for message in pubsub.listen():
                    if message['type'] == 'message':
                        try:
                            await callback(message['data'])
                        except Exception as e:
                            logger.error(f"Error in message callback: {e}")

        except Exception as e:
            logger.error(f"Error subscribing to Redis channel: {e}", exc_info=True)

    async def set_key(self, key: str, value: Union[str, bytes], expire_seconds: Optional[int] = None) -> bool:
        """
        Set a key-value pair in Redis.

        Args:
            key: Redis key
            value: Value to store
            expire_seconds: Optional expiration time in seconds

        Returns:
            True if successful, False otherwise
        """
        try:
            if not self.redis_client:
                logger.error("Redis client not initialized")
                self.stats['operations_failed'] += 1
                return False

            if expire_seconds:
                await self.redis_client.setex(key, expire_seconds, value)
            else:
                await self.redis_client.set(key, value)

            logger.debug(f"Set Redis key: {key}")
            self.stats['operations_successful'] += 1
            return True

        except Exception as e:
            logger.error(f"Error setting Redis key: {e}", exc_info=True)
            self.stats['operations_failed'] += 1
            return False

    async def get_key(self, key: str) -> Optional[Union[str, bytes]]:
        """
        Get value for a Redis key.

        Args:
            key: Redis key

        Returns:
            Value or None if not found
        """
        try:
            if not self.redis_client:
                logger.error("Redis client not initialized")
                self.stats['operations_failed'] += 1
                return None

            value = await self.redis_client.get(key)
            if value is not None:
                logger.debug(f"Retrieved Redis key: {key}")
                self.stats['operations_successful'] += 1

            return value

        except Exception as e:
            logger.error(f"Error getting Redis key: {e}", exc_info=True)
            self.stats['operations_failed'] += 1
            return None

    async def delete_key(self, key: str) -> bool:
        """
        Delete a Redis key.

        Args:
            key: Redis key

        Returns:
            True if successful, False otherwise
        """
        try:
            if not self.redis_client:
                logger.error("Redis client not initialized")
                self.stats['operations_failed'] += 1
                return False

            result = await self.redis_client.delete(key)
            if result > 0:
                logger.debug(f"Deleted Redis key: {key}")
                self.stats['operations_successful'] += 1
                return True
            else:
                logger.debug(f"Redis key not found: {key}")
                return False

        except Exception as e:
            logger.error(f"Error deleting Redis key: {e}", exc_info=True)
            self.stats['operations_failed'] += 1
            return False

    def _encode_image(self, image: np.ndarray, image_format: str, quality: int) -> Optional[bytes]:
        """
        Encode image to bytes for Redis storage.

        Args:
            image: Image array
            image_format: Image format ('jpeg' or 'png')
            quality: JPEG quality (1-100)

        Returns:
            Encoded image bytes or None on error
        """
        try:
            format_lower = image_format.lower()

            if format_lower == 'jpeg' or format_lower == 'jpg':
                encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
                success, buffer = cv2.imencode('.jpg', image, encode_params)
            elif format_lower == 'png':
                success, buffer = cv2.imencode('.png', image)
            else:
                logger.warning(f"Unknown image format '{image_format}', using JPEG")
                encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
                success, buffer = cv2.imencode('.jpg', image, encode_params)

            if success:
                return buffer.tobytes()
            else:
                logger.error(f"Failed to encode image as {image_format}")
                return None

        except Exception as e:
            logger.error(f"Error encoding image: {e}", exc_info=True)
            return None

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get Redis manager statistics.

        Returns:
            Dictionary with statistics
        """
        return {
            **self.stats,
            'connected': self.redis_client is not None,
            'host': self.host,
            'port': self.port,
            'db': self.db
        }

    def cleanup(self):
        """Cleanup Redis connection."""
        if self.redis_client:
            # Note: redis.asyncio doesn't have a synchronous close method
            # The connection will be closed when the event loop shuts down
            self.redis_client = None
            logger.info("Redis connection cleaned up")

    async def aclose(self):
        """Async cleanup for Redis connection."""
        if self.redis_client:
            await self.redis_client.aclose()
            self.redis_client = None
            logger.info("Redis connection closed")
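A minimal wiring sketch for the RedisManager above. Assumptions are flagged in comments: a reachable local Redis, and `decode_responses=False`, which binary image round-trips through `get_image` would likely require (a decoded `str` response cannot be fed to `np.frombuffer`/`cv2.imdecode`):

```python
import asyncio
import numpy as np

async def main():
    # Assumed config: local Redis; decode_responses=False so GETs return bytes.
    manager = RedisManager({'host': 'localhost', 'port': 6379, 'decode_responses': False})
    if not await manager.initialize():
        return

    frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # placeholder frame
    await manager.save_image('inference:cam-001:latest', frame, expire_seconds=600)

    # publish_message returns the subscriber count; 0 means no one is listening.
    await manager.publish_message('license_results', {'session_id': '511'})
    await manager.aclose()

asyncio.run(main())
```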
@@ -1,26 +0,0 @@
"""
Streaming system for RTSP and HTTP camera feeds.
Provides modular frame readers, buffers, and stream management.
"""
from .readers import HTTPSnapshotReader, FFmpegRTSPReader
from .buffers import FrameBuffer, CacheBuffer, shared_frame_buffer, shared_cache_buffer
from .manager import StreamManager, StreamConfig, SubscriptionInfo, shared_stream_manager, initialize_stream_manager

__all__ = [
    # Readers
    'HTTPSnapshotReader',
    'FFmpegRTSPReader',

    # Buffers
    'FrameBuffer',
    'CacheBuffer',
    'shared_frame_buffer',
    'shared_cache_buffer',

    # Manager
    'StreamManager',
    'StreamConfig',
    'SubscriptionInfo',
    'shared_stream_manager',
    'initialize_stream_manager'
]
@@ -1,295 +0,0 @@
"""
Frame buffering and caching system optimized for different stream formats.
Supports 1280x720 RTSP streams and 2560x1440 HTTP snapshots.
"""
import threading
import time
import cv2
import logging
import numpy as np
from typing import Optional, Dict, Any, Tuple
from collections import defaultdict


logger = logging.getLogger(__name__)


class FrameBuffer:
    """Thread-safe frame buffer for all camera streams."""

    def __init__(self, max_age_seconds: int = 5):
        self.max_age_seconds = max_age_seconds
        self._frames: Dict[str, Dict[str, Any]] = {}
        self._lock = threading.RLock()

    def put_frame(self, camera_id: str, frame: np.ndarray):
        """Store a frame for the given camera ID."""
        with self._lock:
            # Validate frame
            if not self._validate_frame(frame):
                logger.warning(f"Frame validation failed for camera {camera_id}")
                return

            self._frames[camera_id] = {
                'frame': frame.copy(),
                'timestamp': time.time(),
                'shape': frame.shape,
                'dtype': str(frame.dtype),
                'size_mb': frame.nbytes / (1024 * 1024)
            }

    def get_frame(self, camera_id: str) -> Optional[np.ndarray]:
        """Get the latest frame for the given camera ID."""
        with self._lock:
            if camera_id not in self._frames:
                return None

            frame_data = self._frames[camera_id]

            # Return frame regardless of age - frames persist until replaced
            return frame_data['frame'].copy()

    def get_frame_info(self, camera_id: str) -> Optional[Dict[str, Any]]:
        """Get frame metadata without copying the frame data."""
        with self._lock:
            if camera_id not in self._frames:
                return None

            frame_data = self._frames[camera_id]
            age = time.time() - frame_data['timestamp']

            # Return frame info regardless of age - frames persist until replaced
            return {
                'timestamp': frame_data['timestamp'],
                'age': age,
                'shape': frame_data['shape'],
                'dtype': frame_data['dtype'],
                'size_mb': frame_data.get('size_mb', 0)
            }

    def has_frame(self, camera_id: str) -> bool:
        """Check if a valid frame exists for the camera."""
        return self.get_frame_info(camera_id) is not None

    def clear_camera(self, camera_id: str):
        """Remove all frames for a specific camera."""
        with self._lock:
            if camera_id in self._frames:
                del self._frames[camera_id]
                logger.debug(f"Cleared frames for camera {camera_id}")

    def clear_all(self):
        """Clear all stored frames."""
        with self._lock:
            count = len(self._frames)
            self._frames.clear()
            logger.debug(f"Cleared all frames ({count} cameras)")

    def get_camera_list(self) -> list:
        """Get list of cameras with frames - all frames persist until replaced."""
        with self._lock:
            # Return all cameras that have frames - no age-based filtering
            return list(self._frames.keys())

    def get_stats(self) -> Dict[str, Any]:
        """Get buffer statistics."""
        with self._lock:
            current_time = time.time()
            stats = {
                'total_cameras': len(self._frames),
                'recent_cameras': 0,
                'stale_cameras': 0,
                'total_memory_mb': 0,
                'cameras': {}
            }

            for camera_id, frame_data in self._frames.items():
                age = current_time - frame_data['timestamp']
                size_mb = frame_data.get('size_mb', 0)

                # All frames are valid/available, but categorize by freshness for monitoring
                if age <= self.max_age_seconds:
                    stats['recent_cameras'] += 1
                else:
                    stats['stale_cameras'] += 1

                stats['total_memory_mb'] += size_mb

                stats['cameras'][camera_id] = {
                    'age': age,
                    'recent': age <= self.max_age_seconds,  # Recent but all frames available
                    'shape': frame_data['shape'],
                    'dtype': frame_data['dtype'],
                    'size_mb': size_mb
                }

            return stats

    def _validate_frame(self, frame: np.ndarray) -> bool:
        """Validate frame - basic validation for any stream type."""
        if frame is None or frame.size == 0:
            return False

        h, w = frame.shape[:2]
        size_mb = frame.nbytes / (1024 * 1024)

        # Basic size validation - reject extremely large frames regardless of type
        max_size_mb = 50  # Generous limit for any frame type
        if size_mb > max_size_mb:
            logger.warning(f"Frame too large: {size_mb:.2f}MB (max {max_size_mb}MB) for {w}x{h}")
            return False

        # Basic dimension validation
        if w < 100 or h < 100:
            logger.warning(f"Frame too small: {w}x{h}")
            return False

        return True


class CacheBuffer:
    """Enhanced frame cache with support for cropping."""

    def __init__(self, max_age_seconds: int = 10):
        self.frame_buffer = FrameBuffer(max_age_seconds)
        self._crop_cache: Dict[str, Dict[str, Any]] = {}
        self._cache_lock = threading.RLock()
        self.jpeg_quality = 95  # High quality for all frames

    def put_frame(self, camera_id: str, frame: np.ndarray):
        """Store a frame and clear any associated crop cache."""
        self.frame_buffer.put_frame(camera_id, frame)

        # Clear crop cache for this camera since we have a new frame
        with self._cache_lock:
            keys_to_remove = [key for key in self._crop_cache.keys() if key.startswith(f"{camera_id}_")]
            for key in keys_to_remove:
                del self._crop_cache[key]

    def get_frame(self, camera_id: str, crop_coords: Optional[Tuple[int, int, int, int]] = None) -> Optional[np.ndarray]:
        """Get frame with optional cropping."""
        if crop_coords is None:
            return self.frame_buffer.get_frame(camera_id)

        # Check crop cache first
        crop_key = f"{camera_id}_{crop_coords}"
        with self._cache_lock:
            if crop_key in self._crop_cache:
                cache_entry = self._crop_cache[crop_key]
                age = time.time() - cache_entry['timestamp']
                if age <= self.frame_buffer.max_age_seconds:
                    return cache_entry['cropped_frame'].copy()
                else:
                    del self._crop_cache[crop_key]

        # Get original frame and crop it
        original_frame = self.frame_buffer.get_frame(camera_id)
        if original_frame is None:
            return None

        try:
            x1, y1, x2, y2 = crop_coords

            # Ensure coordinates are within frame bounds
            h, w = original_frame.shape[:2]
            x1 = max(0, min(x1, w))
            y1 = max(0, min(y1, h))
            x2 = max(x1, min(x2, w))
            y2 = max(y1, min(y2, h))

            cropped_frame = original_frame[y1:y2, x1:x2]

            # Cache the cropped frame
            with self._cache_lock:
                # Limit cache size to prevent memory issues
                if len(self._crop_cache) > 100:
                    # Remove oldest entries
                    oldest_keys = sorted(self._crop_cache.keys(),
                                         key=lambda k: self._crop_cache[k]['timestamp'])[:50]
                    for key in oldest_keys:
                        del self._crop_cache[key]

                self._crop_cache[crop_key] = {
                    'cropped_frame': cropped_frame.copy(),
                    'timestamp': time.time(),
                    'crop_coords': (x1, y1, x2, y2)
                }

            return cropped_frame

        except Exception as e:
            logger.error(f"Error cropping frame for camera {camera_id}: {e}")
            return original_frame

    def get_frame_as_jpeg(self, camera_id: str, crop_coords: Optional[Tuple[int, int, int, int]] = None,
                          quality: Optional[int] = None) -> Optional[bytes]:
        """Get frame as JPEG bytes."""
        frame = self.get_frame(camera_id, crop_coords)
        if frame is None:
            return None

        try:
            # Use specified quality or default
            if quality is None:
                quality = self.jpeg_quality

            # Encode as JPEG with specified quality
            encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
            success, encoded_img = cv2.imencode('.jpg', frame, encode_params)

            if success:
                jpeg_bytes = encoded_img.tobytes()
                logger.debug(f"Encoded JPEG for camera {camera_id}: quality={quality}, size={len(jpeg_bytes)} bytes")
                return jpeg_bytes

            return None

        except Exception as e:
            logger.error(f"Error encoding frame as JPEG for camera {camera_id}: {e}")
            return None

    def has_frame(self, camera_id: str) -> bool:
        """Check if a valid frame exists for the camera."""
        return self.frame_buffer.has_frame(camera_id)

    def clear_camera(self, camera_id: str):
        """Remove all frames and cache for a specific camera."""
        self.frame_buffer.clear_camera(camera_id)
        with self._cache_lock:
            # Clear crop cache entries for this camera
            keys_to_remove = [key for key in self._crop_cache.keys() if key.startswith(f"{camera_id}_")]
            for key in keys_to_remove:
                del self._crop_cache[key]

    def clear_all(self):
        """Clear all stored frames and cache."""
        self.frame_buffer.clear_all()
        with self._cache_lock:
            self._crop_cache.clear()

    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive buffer and cache statistics."""
        buffer_stats = self.frame_buffer.get_stats()

        with self._cache_lock:
            cache_stats = {
                'crop_cache_entries': len(self._crop_cache),
                'crop_cache_cameras': len(set(key.split('_')[0] for key in self._crop_cache.keys() if '_' in key)),
                'crop_cache_memory_mb': sum(
                    entry['cropped_frame'].nbytes / (1024 * 1024)
                    for entry in self._crop_cache.values()
                )
            }

        return {
            'buffer': buffer_stats,
            'cache': cache_stats,
            'total_memory_mb': buffer_stats.get('total_memory_mb', 0) + cache_stats.get('crop_cache_memory_mb', 0)
        }


# Global shared instances for application use
shared_frame_buffer = FrameBuffer(max_age_seconds=5)
shared_cache_buffer = CacheBuffer(max_age_seconds=10)
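A short sketch of the crop path above (camera id and coordinates are illustrative; note that out-of-bounds crops are clamped to the frame, not rejected):

```python
import numpy as np

frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # placeholder 720p frame
shared_cache_buffer.put_frame('cam-001', frame)

full = shared_cache_buffer.get_frame('cam-001')                        # whole frame
crop = shared_cache_buffer.get_frame('cam-001', (100, 100, 500, 400))  # cached per (camera, coords)
jpeg = shared_cache_buffer.get_frame_as_jpeg('cam-001', quality=85)    # encoded bytes or None
```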
@@ -1,722 +0,0 @@
"""
Stream coordination and lifecycle management.
Optimized for 1280x720@6fps RTSP and 2560x1440 HTTP snapshots.
"""
import logging
import threading
import time
import queue
import asyncio
from typing import Dict, Set, Optional, List, Any
from dataclasses import dataclass
from collections import defaultdict

from .readers import HTTPSnapshotReader, FFmpegRTSPReader
from .buffers import shared_cache_buffer
from ..tracking.integration import TrackingPipelineIntegration


logger = logging.getLogger(__name__)


@dataclass
class StreamConfig:
    """Configuration for a stream."""
    camera_id: str
    rtsp_url: Optional[str] = None
    snapshot_url: Optional[str] = None
    snapshot_interval: int = 5000  # milliseconds
    max_retries: int = 3


@dataclass
class SubscriptionInfo:
    """Information about a subscription."""
    subscription_id: str
    camera_id: str
    stream_config: StreamConfig
    created_at: float
    crop_coords: Optional[tuple] = None
    model_id: Optional[str] = None
    model_url: Optional[str] = None
    tracking_integration: Optional[TrackingPipelineIntegration] = None


class StreamManager:
    """Manages multiple camera streams with shared optimization."""

    def __init__(self, max_streams: int = 10):
        self.max_streams = max_streams
        self._streams: Dict[str, Any] = {}  # camera_id -> reader instance
        self._subscriptions: Dict[str, SubscriptionInfo] = {}  # subscription_id -> info
        self._camera_subscribers: Dict[str, Set[str]] = defaultdict(set)  # camera_id -> set of subscription_ids
        self._lock = threading.RLock()

        # Fair tracking queue system - per camera queues
        self._tracking_queues: Dict[str, queue.Queue] = {}  # camera_id -> queue
        self._tracking_workers = []
        self._stop_workers = threading.Event()
        self._dropped_frame_counts: Dict[str, int] = {}  # per-camera drop counts

        # Round-robin scheduling state
        self._camera_list = []  # Ordered list of active cameras
        self._camera_round_robin_index = 0
        self._round_robin_lock = threading.Lock()

        # Start worker threads for tracking processing
        num_workers = min(4, max_streams // 2 + 1)  # Scale with streams
        for i in range(num_workers):
            worker = threading.Thread(
                target=self._tracking_worker_loop,
                name=f"TrackingWorker-{i}",
                daemon=True
            )
            worker.start()
            self._tracking_workers.append(worker)

        logger.info(f"Started {num_workers} tracking worker threads")

    def _ensure_camera_queue(self, camera_id: str):
        """Ensure a tracking queue exists for the camera."""
        if camera_id not in self._tracking_queues:
            self._tracking_queues[camera_id] = queue.Queue(maxsize=10)  # 10 frames per camera
            self._dropped_frame_counts[camera_id] = 0

            with self._round_robin_lock:
                if camera_id not in self._camera_list:
                    self._camera_list.append(camera_id)
            logger.info(f"Created tracking queue for camera {camera_id}")
        else:
            logger.debug(f"Camera {camera_id} already has tracking queue")

    def _remove_camera_queue(self, camera_id: str):
        """Remove tracking queue for a camera that's no longer active."""
        if camera_id in self._tracking_queues:
            # Clear any remaining items
            while not self._tracking_queues[camera_id].empty():
                try:
                    self._tracking_queues[camera_id].get_nowait()
                except queue.Empty:
                    break

            del self._tracking_queues[camera_id]
            del self._dropped_frame_counts[camera_id]

            with self._round_robin_lock:
                if camera_id in self._camera_list:
                    self._camera_list.remove(camera_id)
                    # Reset index if needed
                    if self._camera_round_robin_index >= len(self._camera_list):
                        self._camera_round_robin_index = 0

            logger.info(f"Removed tracking queue for camera {camera_id}")

    def add_subscription(self, subscription_id: str, stream_config: StreamConfig,
                         crop_coords: Optional[tuple] = None,
                         model_id: Optional[str] = None,
                         model_url: Optional[str] = None,
                         tracking_integration: Optional[TrackingPipelineIntegration] = None) -> bool:
        """Add a new subscription. Returns True if successful."""
        with self._lock:
            if subscription_id in self._subscriptions:
                logger.warning(f"Subscription {subscription_id} already exists")
                return False

            camera_id = stream_config.camera_id

            # Create subscription info
            subscription_info = SubscriptionInfo(
                subscription_id=subscription_id,
                camera_id=camera_id,
                stream_config=stream_config,
                created_at=time.time(),
                crop_coords=crop_coords,
                model_id=model_id,
                model_url=model_url,
                tracking_integration=tracking_integration
            )

            # Pass subscription info to tracking integration for snapshot access
            if tracking_integration:
                tracking_integration.set_subscription_info(subscription_info)

            self._subscriptions[subscription_id] = subscription_info
            self._camera_subscribers[camera_id].add(subscription_id)

            # Start stream if not already running
            if camera_id not in self._streams:
                if len(self._streams) >= self.max_streams:
                    logger.error(f"Maximum streams ({self.max_streams}) reached, cannot add {camera_id}")
                    self._remove_subscription_internal(subscription_id)
                    return False

                success = self._start_stream(camera_id, stream_config)
                if not success:
                    self._remove_subscription_internal(subscription_id)
                    return False
            else:
                # Stream already exists, but ensure queue exists too
                logger.info(f"Stream already exists for {camera_id}, ensuring queue exists")
                self._ensure_camera_queue(camera_id)

            logger.info(f"Added subscription {subscription_id} for camera {camera_id} "
                        f"({len(self._camera_subscribers[camera_id])} total subscribers)")
            return True

    def remove_subscription(self, subscription_id: str) -> bool:
        """Remove a subscription. Returns True if found and removed."""
        with self._lock:
            return self._remove_subscription_internal(subscription_id)

    def _remove_subscription_internal(self, subscription_id: str) -> bool:
        """Internal method to remove subscription (assumes lock is held)."""
        if subscription_id not in self._subscriptions:
            logger.warning(f"Subscription {subscription_id} not found")
            return False

        subscription_info = self._subscriptions[subscription_id]
        camera_id = subscription_info.camera_id

        # Remove from tracking
        del self._subscriptions[subscription_id]
        self._camera_subscribers[camera_id].discard(subscription_id)

        # Stop stream if no more subscribers
        if not self._camera_subscribers[camera_id]:
            self._stop_stream(camera_id)
            del self._camera_subscribers[camera_id]

        logger.info(f"Removed subscription {subscription_id} for camera {camera_id} "
                    f"({len(self._camera_subscribers[camera_id])} remaining subscribers)")
        return True

    def _start_stream(self, camera_id: str, stream_config: StreamConfig) -> bool:
        """Start a stream for the given camera."""
        try:
            if stream_config.rtsp_url:
                # RTSP stream using FFmpeg subprocess with CUDA acceleration
                logger.info(f"\033[94m[RTSP] Starting {camera_id}\033[0m")
                reader = FFmpegRTSPReader(
                    camera_id=camera_id,
                    rtsp_url=stream_config.rtsp_url,
                    max_retries=stream_config.max_retries
                )
                reader.set_frame_callback(self._frame_callback)
                reader.start()
                self._streams[camera_id] = reader
                self._ensure_camera_queue(camera_id)  # Create tracking queue
                logger.info(f"\033[92m[RTSP] {camera_id} connected\033[0m")

            elif stream_config.snapshot_url:
                # HTTP snapshot stream
                logger.info(f"\033[95m[HTTP] Starting {camera_id}\033[0m")
                reader = HTTPSnapshotReader(
                    camera_id=camera_id,
                    snapshot_url=stream_config.snapshot_url,
                    interval_ms=stream_config.snapshot_interval,
                    max_retries=stream_config.max_retries
                )
                reader.set_frame_callback(self._frame_callback)
                reader.start()
                self._streams[camera_id] = reader
                self._ensure_camera_queue(camera_id)  # Create tracking queue
                logger.info(f"\033[92m[HTTP] {camera_id} connected\033[0m")

            else:
                logger.error(f"No valid URL provided for camera {camera_id}")
                return False

            return True

        except Exception as e:
            logger.error(f"Error starting stream for camera {camera_id}: {e}")
            return False

    def _stop_stream(self, camera_id: str):
        """Stop a stream for the given camera."""
        if camera_id in self._streams:
            try:
                self._streams[camera_id].stop()
                del self._streams[camera_id]
                self._remove_camera_queue(camera_id)  # Remove tracking queue
                # DON'T clear frames - they should persist until replaced
                # shared_cache_buffer.clear_camera(camera_id)  # REMOVED - frames should persist
                logger.info(f"Stopped stream for camera {camera_id} (frames preserved in buffer)")
            except Exception as e:
                logger.error(f"Error stopping stream for camera {camera_id}: {e}")

    def _frame_callback(self, camera_id: str, frame):
        """Callback for when a new frame is available."""
        try:
            # Store frame in shared buffer
            shared_cache_buffer.put_frame(camera_id, frame)
            # Quieter frame callback logging - only log occasionally
            if hasattr(self, '_frame_log_count'):
                self._frame_log_count += 1
            else:
                self._frame_log_count = 1

            # Log every 100 frames to avoid spam
            if self._frame_log_count % 100 == 0:
                available_cameras = shared_cache_buffer.frame_buffer.get_camera_list()
                logger.info(f"\033[96m[BUFFER] {len(available_cameras)} active cameras: {', '.join(available_cameras)}\033[0m")

            # Queue for tracking processing (non-blocking) - route to camera-specific queue
            if camera_id in self._tracking_queues:
                try:
                    self._tracking_queues[camera_id].put_nowait({
                        'frame': frame,
                        'timestamp': time.time()
                    })
                except queue.Full:
                    # Drop frame if camera queue is full (maintain real-time)
                    self._dropped_frame_counts[camera_id] += 1

                    if self._dropped_frame_counts[camera_id] % 50 == 0:
                        logger.warning(f"Dropped {self._dropped_frame_counts[camera_id]} frames for camera {camera_id} due to full queue")

        except Exception as e:
            logger.error(f"Error in frame callback for camera {camera_id}: {e}")

    def _process_tracking_for_camera(self, camera_id: str, frame):
        """Process tracking for all subscriptions of a camera."""
        try:
            with self._lock:
                for subscription_id in self._camera_subscribers[camera_id]:
                    subscription_info = self._subscriptions[subscription_id]

                    # Skip if no tracking integration
                    if not subscription_info.tracking_integration:
                        continue

                    # Extract display_id from subscription_id
                    display_id = subscription_id.split(';')[0] if ';' in subscription_id else subscription_id

                    # Process frame through tracking asynchronously
                    # Note: This is synchronous for now, can be made async in future
                    try:
                        # Create a simple asyncio event loop for this frame
                        import asyncio
                        loop = asyncio.new_event_loop()
                        asyncio.set_event_loop(loop)
                        try:
                            result = loop.run_until_complete(
                                subscription_info.tracking_integration.process_frame(
                                    frame, display_id, subscription_id
                                )
                            )
                            # Log tracking results
                            if result:
                                tracked_count = len(result.get('tracked_vehicles', []))
                                validated_vehicle = result.get('validated_vehicle')
                                pipeline_result = result.get('pipeline_result')

                                if tracked_count > 0:
                                    logger.info(f"[Tracking] {camera_id}: {tracked_count} vehicles tracked")

                                if validated_vehicle:
                                    logger.info(f"[Tracking] {camera_id}: Vehicle {validated_vehicle['track_id']} "
                                                f"validated as {validated_vehicle['state']} "
                                                f"(confidence: {validated_vehicle['confidence']:.2f})")

                                if pipeline_result:
                                    logger.info(f"[Pipeline] {camera_id}: {pipeline_result.get('status', 'unknown')} - "
                                                f"{pipeline_result.get('message', 'no message')}")
                        finally:
                            loop.close()
                    except Exception as track_e:
                        logger.error(f"Error in tracking for {subscription_id}: {track_e}")

        except Exception as e:
            logger.error(f"Error processing tracking for camera {camera_id}: {e}")

    def _tracking_worker_loop(self):
        """Worker thread loop for round-robin processing of camera queues."""
        logger.info(f"Tracking worker {threading.current_thread().name} started")

        consecutive_empty = 0
        max_consecutive_empty = 10  # Sleep if all cameras empty this many times

        while not self._stop_workers.is_set():
            try:
                # Get next camera in round-robin fashion
                camera_id, item = self._get_next_camera_item()

                if camera_id is None:
                    # No cameras have items, sleep briefly
                    consecutive_empty += 1
                    if consecutive_empty >= max_consecutive_empty:
                        time.sleep(0.1)  # Sleep 100ms if nothing to process
                        consecutive_empty = 0
                    continue

                consecutive_empty = 0  # Reset counter when we find work

                frame = item['frame']
                timestamp = item['timestamp']

                # Check if frame is too old (drop if > 1 second old)
                age = time.time() - timestamp
                if age > 1.0:
                    logger.debug(f"Dropping old frame for {camera_id} (age: {age:.2f}s)")
                    continue

                # Process tracking for this camera's frame
                self._process_tracking_for_camera_sync(camera_id, frame)

            except Exception as e:
                logger.error(f"Error in tracking worker: {e}", exc_info=True)

        logger.info(f"Tracking worker {threading.current_thread().name} stopped")

    def _get_next_camera_item(self):
        """Get next item from camera queues using round-robin scheduling."""
        with self._round_robin_lock:
            # Get current list of cameras from actual tracking queues (central state)
            camera_list = list(self._tracking_queues.keys())

            if not camera_list:
                return None, None

            attempts = 0
            max_attempts = len(camera_list)

            while attempts < max_attempts:
                # Get current camera using round-robin index
                if self._camera_round_robin_index >= len(camera_list):
                    self._camera_round_robin_index = 0

                camera_id = camera_list[self._camera_round_robin_index]

                # Move to next camera for next call
                self._camera_round_robin_index = (self._camera_round_robin_index + 1) % len(camera_list)

                # Try to get item from this camera's queue
                try:
                    item = self._tracking_queues[camera_id].get_nowait()
                    return camera_id, item
                except queue.Empty:
                    pass  # Try next camera

                attempts += 1

            return None, None  # All cameras empty

    def _process_tracking_for_camera_sync(self, camera_id: str, frame):
        """Synchronous version of tracking processing for worker threads."""
        try:
            with self._lock:
                subscription_ids = list(self._camera_subscribers.get(camera_id, []))

            for subscription_id in subscription_ids:
                subscription_info = self._subscriptions.get(subscription_id)

                if not subscription_info:
                    logger.warning(f"No subscription info found for {subscription_id}")
                    continue

                if not subscription_info.tracking_integration:
                    logger.debug(f"No tracking integration for {subscription_id} (camera {camera_id}), skipping inference")
                    continue

                display_id = subscription_id.split(';')[0] if ';' in subscription_id else subscription_id

                try:
                    # Run async tracking in thread's event loop
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                    try:
                        result = loop.run_until_complete(
                            subscription_info.tracking_integration.process_frame(
                                frame, display_id, subscription_id
                            )
                        )

                        # Log tracking results
                        if result:
                            tracked_count = len(result.get('tracked_vehicles', []))
                            validated_vehicle = result.get('validated_vehicle')
                            pipeline_result = result.get('pipeline_result')

                            if tracked_count > 0:
                                logger.info(f"[Tracking] {camera_id}: {tracked_count} vehicles tracked")

                            if validated_vehicle:
                                logger.info(f"[Tracking] {camera_id}: Vehicle {validated_vehicle['track_id']} "
                                            f"validated as {validated_vehicle['state']} "
                                            f"(confidence: {validated_vehicle['confidence']:.2f})")

                            if pipeline_result:
                                logger.info(f"[Pipeline] {camera_id}: {pipeline_result.get('status', 'unknown')} - "
                                            f"{pipeline_result.get('message', 'no message')}")
                    finally:
                        loop.close()

                except Exception as track_e:
                    logger.error(f"Error in tracking for {subscription_id}: {track_e}")

        except Exception as e:
            logger.error(f"Error processing tracking for camera {camera_id}: {e}")

    def get_frame(self, camera_id: str, crop_coords: Optional[tuple] = None):
        """Get the latest frame for a camera with optional cropping."""
        return shared_cache_buffer.get_frame(camera_id, crop_coords)

    def get_frame_as_jpeg(self, camera_id: str, crop_coords: Optional[tuple] = None,
                          quality: int = 100) -> Optional[bytes]:
        """Get frame as JPEG bytes for HTTP responses with highest quality by default."""
        return shared_cache_buffer.get_frame_as_jpeg(camera_id, crop_coords, quality)

    def has_frame(self, camera_id: str) -> bool:
        """Check if a frame is available for the camera."""
        return shared_cache_buffer.has_frame(camera_id)

    def get_subscription_info(self, subscription_id: str) -> Optional[SubscriptionInfo]:
        """Get information about a subscription."""
        with self._lock:
            return self._subscriptions.get(subscription_id)

    def get_camera_subscribers(self, camera_id: str) -> Set[str]:
        """Get all subscription IDs for a camera."""
        with self._lock:
            return self._camera_subscribers[camera_id].copy()

    def get_active_cameras(self) -> List[str]:
        """Get list of cameras with active streams."""
        with self._lock:
            return list(self._streams.keys())

    def get_all_subscriptions(self) -> List[SubscriptionInfo]:
        """Get all active subscriptions."""
        with self._lock:
            return list(self._subscriptions.values())

    def reconcile_subscriptions(self, target_subscriptions: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Reconcile current subscriptions with target list.
        Returns summary of changes made.
        """
        with self._lock:
            current_subscription_ids = set(self._subscriptions.keys())
            target_subscription_ids = {sub['subscriptionIdentifier'] for sub in target_subscriptions}

            # Find subscriptions to remove and add
            to_remove = current_subscription_ids - target_subscription_ids
            to_add = target_subscription_ids - current_subscription_ids

            # Remove old subscriptions
            removed_count = 0
            for subscription_id in to_remove:
                if self._remove_subscription_internal(subscription_id):
                    removed_count += 1

            # Add new subscriptions
            added_count = 0
            failed_count = 0
            for target_sub in target_subscriptions:
                subscription_id = target_sub['subscriptionIdentifier']
                if subscription_id in to_add:
                    success = self._add_subscription_from_payload(subscription_id, target_sub)
                    if success:
                        added_count += 1
                    else:
                        failed_count += 1

            result = {
                'removed': removed_count,
                'added': added_count,
                'failed': failed_count,
                'total_active': len(self._subscriptions),
                'active_streams': len(self._streams)
            }

            logger.info(f"Subscription reconciliation: {result}")
            return result

    def _add_subscription_from_payload(self, subscription_id: str, payload: Dict[str, Any]) -> bool:
        """Add subscription from WebSocket payload format."""
        try:
            # Extract camera ID from subscription identifier
            # Format: "display-001;cam-001" -> camera_id = "cam-001"
            camera_id = subscription_id.split(';')[-1]

            # Extract crop coordinates if present
            crop_coords = None
            if all(key in payload for key in ['cropX1', 'cropY1', 'cropX2', 'cropY2']):
                crop_coords = (
                    payload['cropX1'],
                    payload['cropY1'],
                    payload['cropX2'],
                    payload['cropY2']
                )

            # Create stream configuration
            stream_config = StreamConfig(
                camera_id=camera_id,
                rtsp_url=payload.get('rtspUrl'),
                snapshot_url=payload.get('snapshotUrl'),
                snapshot_interval=payload.get('snapshotInterval', 5000),
                max_retries=3,
            )

            return self.add_subscription(
                subscription_id,
                stream_config,
                crop_coords,
                model_id=payload.get('modelId'),
                model_url=payload.get('modelUrl')
            )

        except Exception as e:
            logger.error(f"Error adding subscription from payload {subscription_id}: {e}")
            return False

    def stop_all(self):
        """Stop all streams and clear all subscriptions."""
        # Signal workers to stop
        self._stop_workers.set()

        # Clear all camera queues
        for camera_id, camera_queue in list(self._tracking_queues.items()):
            while not camera_queue.empty():
                try:
                    camera_queue.get_nowait()
                except queue.Empty:
                    break

        # Wait for workers to finish
        for worker in self._tracking_workers:
            worker.join(timeout=2.0)

        # Clear queue management structures
        self._tracking_queues.clear()
        self._dropped_frame_counts.clear()
        with self._round_robin_lock:
            self._camera_list.clear()
            self._camera_round_robin_index = 0

        logger.info("Stopped all tracking worker threads")

        with self._lock:
            # Stop all streams
            for camera_id in list(self._streams.keys()):
                self._stop_stream(camera_id)

            # Clear all tracking
            self._subscriptions.clear()
            self._camera_subscribers.clear()
            shared_cache_buffer.clear_all()

            logger.info("Stopped all streams and cleared all subscriptions")

    def set_session_id(self, display_id: str, session_id: str):
        """Set session ID for tracking integration."""
        # Ensure session_id is always a string for consistent type handling
        session_id = str(session_id) if session_id is not None else None
        with self._lock:
            for subscription_info in self._subscriptions.values():
                # Check if this subscription matches the display_id
                subscription_display_id = subscription_info.subscription_id.split(';')[0]
                if subscription_display_id == display_id and subscription_info.tracking_integration:
                    # Pass the full subscription_id (displayId;cameraId) to the tracking integration
                    subscription_info.tracking_integration.set_session_id(
                        display_id,
                        session_id,
                        subscription_id=subscription_info.subscription_id
                    )
                    logger.debug(f"Set session {session_id} for display {display_id} with subscription {subscription_info.subscription_id}")

    def clear_session_id(self, session_id: str):
        """Clear session ID from the specific tracking integration handling this session."""
        with self._lock:
            # Find the subscription that's handling this session
            session_subscription = None
            for subscription_info in self._subscriptions.values():
                if subscription_info.tracking_integration:
                    # Check if this integration is handling the given session_id
                    integration = subscription_info.tracking_integration
                    if session_id in integration.session_vehicles:
                        session_subscription = subscription_info
                        break

            if session_subscription and session_subscription.tracking_integration:
                session_subscription.tracking_integration.clear_session_id(session_id)
                logger.debug(f"Cleared session {session_id} from subscription {session_subscription.subscription_id}")
            else:
                logger.warning(f"No tracking integration found for session {session_id}, broadcasting to all subscriptions")
                # Fallback: broadcast to all (original behavior)
                for subscription_info in self._subscriptions.values():
                    if subscription_info.tracking_integration:
                        subscription_info.tracking_integration.clear_session_id(session_id)

    def set_progression_stage(self, session_id: str, stage: str):
        """Set progression stage for the specific tracking integration handling this session."""
        with self._lock:
            # Find the subscription that's handling this session
            session_subscription = None
            for subscription_info in self._subscriptions.values():
                if subscription_info.tracking_integration:
                    # Check if this integration is handling the given session_id
                    # We need to check the integration's active sessions
                    integration = subscription_info.tracking_integration
                    if session_id in integration.session_vehicles:
                        session_subscription = subscription_info
                        break

            if session_subscription and session_subscription.tracking_integration:
                session_subscription.tracking_integration.set_progression_stage(session_id, stage)
                logger.debug(f"Set progression stage for session {session_id}: {stage} on subscription {session_subscription.subscription_id}")
            else:
                logger.warning(f"No tracking integration found for session {session_id}, broadcasting to all subscriptions")
                # Fallback: broadcast to all (original behavior)
                for subscription_info in self._subscriptions.values():
                    if subscription_info.tracking_integration:
                        subscription_info.tracking_integration.set_progression_stage(session_id, stage)

    def get_tracking_stats(self) -> Dict[str, Any]:
        """Get tracking statistics from all subscriptions."""
        stats = {}
        with self._lock:
            for subscription_id, subscription_info in self._subscriptions.items():
                if subscription_info.tracking_integration:
                    stats[subscription_id] = subscription_info.tracking_integration.get_statistics()
        return stats

    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive streaming statistics."""
        with self._lock:
            buffer_stats = shared_cache_buffer.get_stats()
            tracking_stats = self.get_tracking_stats()

            return {
                'active_subscriptions': len(self._subscriptions),
                'active_streams': len(self._streams),
                'cameras_with_subscribers': len(self._camera_subscribers),
                'max_streams': self.max_streams,
                'subscriptions_by_camera': {
                    camera_id: len(subscribers)
                    for camera_id, subscribers in self._camera_subscribers.items()
                },
                'buffer_stats': buffer_stats,
                'tracking_stats': tracking_stats,
                'memory_usage_mb': buffer_stats.get('total_memory_mb', 0)
            }


# Global shared instance for application use
# Default initialization, will be updated with config value in app.py
shared_stream_manager = StreamManager(max_streams=20)

def initialize_stream_manager(max_streams: int = 10):
    """Re-initialize the global stream manager with config value."""
    global shared_stream_manager
    # Release old manager if exists
    if shared_stream_manager:
        try:
            # Stop all existing streams gracefully
            shared_stream_manager.stop_all()
        except Exception as e:
            logger.warning(f"Error stopping previous stream manager: {e}")
    shared_stream_manager = StreamManager(max_streams=max_streams)
    return shared_stream_manager
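An illustrative reconciliation payload for `reconcile_subscriptions` above (all values are placeholders; the identifier format `displayId;cameraId` is the one `_add_subscription_from_payload` parses):

```python
target = [{
    'subscriptionIdentifier': 'display-001;cam-001',  # parsed as camera_id 'cam-001'
    'rtspUrl': 'rtsp://example.invalid/stream',        # placeholder URL
    'cropX1': 0, 'cropY1': 0, 'cropX2': 1280, 'cropY2': 720,
    'modelId': 'model-42',                             # illustrative model fields
    'modelUrl': 'http://example.invalid/model.mpta',
}]
summary = shared_stream_manager.reconcile_subscriptions(target)
# e.g. {'removed': 0, 'added': 1, 'failed': 0, 'total_active': 1, 'active_streams': 1}
```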
@@ -1,18 +0,0 @@
"""
Stream readers for RTSP and HTTP camera feeds.
"""
from .base import VideoReader
from .ffmpeg_rtsp import FFmpegRTSPReader
from .http_snapshot import HTTPSnapshotReader
from .utils import log_success, log_warning, log_error, log_info, Colors

__all__ = [
    'VideoReader',
    'FFmpegRTSPReader',
    'HTTPSnapshotReader',
    'log_success',
    'log_warning',
    'log_error',
    'log_info',
    'Colors'
]
@@ -1,65 +0,0 @@
"""
Abstract base class for video stream readers.
"""
from abc import ABC, abstractmethod
from typing import Optional, Callable
import numpy as np


class VideoReader(ABC):
    """Abstract base class for video stream readers."""

    def __init__(self, camera_id: str, source_url: str, max_retries: int = 3):
        """
        Initialize the video reader.

        Args:
            camera_id: Unique identifier for the camera
            source_url: URL or path to the video source
            max_retries: Maximum number of retry attempts
        """
        self.camera_id = camera_id
        self.source_url = source_url
        self.max_retries = max_retries
        self.frame_callback: Optional[Callable[[str, np.ndarray], None]] = None

    @abstractmethod
    def start(self) -> None:
        """Start the video reader."""
        pass

    @abstractmethod
    def stop(self) -> None:
        """Stop the video reader."""
        pass

    @abstractmethod
    def set_frame_callback(self, callback: Callable[[str, np.ndarray], None]) -> None:
        """
        Set callback function to handle captured frames.

        Args:
            callback: Function that takes (camera_id, frame) as arguments
        """
        pass

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Check if the reader is currently running."""
        pass

    @property
    @abstractmethod
    def reader_type(self) -> str:
        """Get the type of reader (e.g., 'rtsp', 'http_snapshot')."""
        pass

    def __enter__(self):
        """Context manager entry."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.stop()
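
Because VideoReader implements __enter__/__exit__, every concrete reader doubles as a context manager. A minimal conforming subclass, purely hypothetical, to show the contract:

import numpy as np

class DummyReader(VideoReader):  # hypothetical reader for illustration
    def __init__(self, camera_id: str):
        super().__init__(camera_id, source_url="dummy://", max_retries=0)
        self._running = False

    def start(self) -> None:
        self._running = True
        if self.frame_callback:
            # Emit a single black test frame to the registered callback
            self.frame_callback(self.camera_id, np.zeros((720, 1280, 3), np.uint8))

    def stop(self) -> None:
        self._running = False

    def set_frame_callback(self, callback) -> None:
        self.frame_callback = callback

    @property
    def is_running(self) -> bool:
        return self._running

    @property
    def reader_type(self) -> str:
        return "dummy"

with DummyReader("cam-01") as reader:  # __enter__ calls start()
    assert reader.is_running
# __exit__ has called stop() here
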
@@ -1,436 +0,0 @@
"""
FFmpeg RTSP stream reader using subprocess piping frames directly to buffer.
Enhanced with comprehensive health monitoring and automatic recovery.
"""
import cv2
import time
import threading
import numpy as np
import subprocess
import struct
from typing import Optional, Callable, Dict, Any

from .base import VideoReader
from .utils import log_success, log_warning, log_error, log_info
from ...monitoring.stream_health import stream_health_tracker
from ...monitoring.thread_health import thread_health_monitor
from ...monitoring.recovery import recovery_manager, RecoveryAction


class FFmpegRTSPReader(VideoReader):
    """RTSP stream reader using subprocess FFmpeg piping frames directly to buffer."""

    def __init__(self, camera_id: str, rtsp_url: str, max_retries: int = 3):
        super().__init__(camera_id, rtsp_url, max_retries)
        self.rtsp_url = rtsp_url
        self.process = None
        self.stop_event = threading.Event()
        self.thread = None
        self.stderr_thread = None

        # Expected stream specs (for reference, actual dimensions read from BMP header)
        self.width = 1280
        self.height = 720

        # Watchdog timers for stream reliability
        self.process_start_time = None
        self.last_frame_time = None
        self.is_restart = False  # Track if this is a restart (shorter timeout)
        self.first_start_timeout = 30.0  # 30s timeout on first start
        self.restart_timeout = 15.0  # 15s timeout after restart

        # Health monitoring setup
        self.last_heartbeat = time.time()
        self.consecutive_errors = 0
        self.ffmpeg_restart_count = 0

        # Register recovery handlers
        recovery_manager.register_recovery_handler(
            RecoveryAction.RESTART_STREAM,
            self._handle_restart_recovery
        )
        recovery_manager.register_recovery_handler(
            RecoveryAction.RECONNECT,
            self._handle_reconnect_recovery
        )

    @property
    def is_running(self) -> bool:
        """Check if the reader is currently running."""
        return self.thread is not None and self.thread.is_alive()

    @property
    def reader_type(self) -> str:
        """Get the type of reader."""
        return "rtsp_ffmpeg"

    def set_frame_callback(self, callback: Callable[[str, np.ndarray], None]):
        """Set callback function to handle captured frames."""
        self.frame_callback = callback

    def start(self):
        """Start the FFmpeg subprocess reader."""
        if self.thread and self.thread.is_alive():
            log_warning(self.camera_id, "FFmpeg reader already running")
            return

        self.stop_event.clear()
        self.thread = threading.Thread(target=self._read_frames, daemon=True)
        self.thread.start()

        # Register with health monitoring
        stream_health_tracker.register_stream(self.camera_id, "rtsp_ffmpeg", self.rtsp_url)
        thread_health_monitor.register_thread(self.thread, self._heartbeat_callback)

        log_success(self.camera_id, "Stream started with health monitoring")

    def stop(self):
        """Stop the FFmpeg subprocess reader."""
        self.stop_event.set()

        # Unregister from health monitoring
        if self.thread:
            thread_health_monitor.unregister_thread(self.thread.ident)

        if self.process:
            self.process.terminate()
            try:
                self.process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                self.process.kill()

        if self.thread:
            self.thread.join(timeout=5.0)
        if self.stderr_thread:
            self.stderr_thread.join(timeout=2.0)

        stream_health_tracker.unregister_stream(self.camera_id)

        log_info(self.camera_id, "Stream stopped")

    def _start_ffmpeg_process(self):
        """Start FFmpeg subprocess outputting BMP frames to stdout pipe."""
        cmd = [
            'ffmpeg',
            # DO NOT REMOVE
            '-hwaccel', 'cuda',
            '-hwaccel_device', '0',
            # Real-time input flags
            '-fflags', 'nobuffer+genpts',
            '-flags', 'low_delay',
            '-max_delay', '0',  # No reordering delay
            # RTSP configuration
            '-rtsp_transport', 'tcp',
            '-i', self.rtsp_url,
            # Output configuration (keeping BMP)
            '-f', 'image2pipe',  # Output images to pipe
            '-vcodec', 'bmp',  # BMP format with header containing dimensions
            '-vsync', 'passthrough',  # Pass frames as-is
            # Use native stream resolution and framerate
            '-an',  # No audio
            '-'  # Output to stdout
        ]

        try:
            # Start FFmpeg with stdout pipe to read frames directly
            self.process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,  # Capture stdout for frame data
                stderr=subprocess.PIPE,  # Capture stderr for error logging
                bufsize=0  # Unbuffered for real-time processing
            )

            # Start stderr reading thread
            if self.stderr_thread and self.stderr_thread.is_alive():
                # Stop previous stderr thread
                try:
                    self.stderr_thread.join(timeout=1.0)
                except Exception:
                    pass

            self.stderr_thread = threading.Thread(target=self._read_stderr, daemon=True)
            self.stderr_thread.start()

            # Set process start time for watchdog
            self.process_start_time = time.time()
            self.last_frame_time = None  # Reset frame time

            # After successful restart, next timeout will be back to 30s
            if self.is_restart:
                log_info(self.camera_id, f"FFmpeg restarted successfully, next timeout: {self.first_start_timeout}s")
                self.is_restart = False

            return True
        except Exception as e:
            log_error(self.camera_id, f"FFmpeg startup failed: {e}")
            return False

    def _read_bmp_frame(self, pipe):
        """Read BMP frame from pipe - BMP header contains dimensions."""
        try:
            # Read BMP header (14 bytes file header + 40 bytes info header = 54 bytes minimum)
            header_data = b''
            bytes_to_read = 54

            while len(header_data) < bytes_to_read:
                chunk = pipe.read(bytes_to_read - len(header_data))
                if not chunk:
                    return None  # Silent end of stream
                header_data += chunk

            # Parse BMP header
            if header_data[:2] != b'BM':
                return None  # Invalid format, skip frame silently

            # Extract file size from header (bytes 2-5)
            file_size = struct.unpack('<L', header_data[2:6])[0]

            # Extract width and height from info header (bytes 18-21 and 22-25)
            width = struct.unpack('<L', header_data[18:22])[0]
            height = struct.unpack('<L', header_data[22:26])[0]

            # Read remaining file data
            remaining_size = file_size - 54
            remaining_data = b''

            while len(remaining_data) < remaining_size:
                chunk = pipe.read(remaining_size - len(remaining_data))
                if not chunk:
                    return None  # Stream ended silently
                remaining_data += chunk

            # Complete BMP data
            bmp_data = header_data + remaining_data

            # Use OpenCV to decode BMP directly from memory
            frame_array = np.frombuffer(bmp_data, dtype=np.uint8)
            frame = cv2.imdecode(frame_array, cv2.IMREAD_COLOR)

            if frame is None:
                return None  # Decode failed silently

            return frame

        except Exception:
            return None  # Error reading frame silently

    def _read_stderr(self):
        """Read and log FFmpeg stderr output in background thread."""
        if not self.process or not self.process.stderr:
            return

        try:
            while self.process and self.process.poll() is None:
                try:
                    line = self.process.stderr.readline()
                    if line:
                        error_msg = line.decode('utf-8', errors='ignore').strip()
                        if error_msg and not self.stop_event.is_set():
                            # Filter out common noise but log actual errors
                            if any(keyword in error_msg.lower() for keyword in ['error', 'failed', 'cannot', 'invalid']):
                                log_error(self.camera_id, f"FFmpeg: {error_msg}")
                            elif 'warning' in error_msg.lower():
                                log_warning(self.camera_id, f"FFmpeg: {error_msg}")
                except Exception:
                    break
        except Exception:
            pass

    def _check_watchdog_timeout(self) -> bool:
        """Check if watchdog timeout has been exceeded."""
        if not self.process_start_time:
            return False

        current_time = time.time()
        time_since_start = current_time - self.process_start_time

        # Determine timeout based on whether this is a restart
        timeout = self.restart_timeout if self.is_restart else self.first_start_timeout

        # If no frames received yet, check against process start time
        if not self.last_frame_time:
            if time_since_start > timeout:
                log_warning(self.camera_id, f"Watchdog timeout: No frames for {time_since_start:.1f}s (limit: {timeout}s)")
                return True
        else:
            # Check time since last frame
            time_since_frame = current_time - self.last_frame_time
            if time_since_frame > timeout:
                log_warning(self.camera_id, f"Watchdog timeout: No frames for {time_since_frame:.1f}s (limit: {timeout}s)")
                return True

        return False

    def _restart_ffmpeg_process(self):
        """Restart FFmpeg process due to watchdog timeout."""
        log_warning(self.camera_id, "Watchdog triggered FFmpeg restart")

        # Terminate current process
        if self.process:
            try:
                self.process.terminate()
                self.process.wait(timeout=3)
            except subprocess.TimeoutExpired:
                self.process.kill()
            except Exception:
                pass
            self.process = None

        # Mark as restart for shorter timeout
        self.is_restart = True

        # Small delay before restart
        time.sleep(1.0)

    def _read_frames(self):
        """Read frames directly from FFmpeg stdout pipe."""
        frame_count = 0
        last_log_time = time.time()

        while not self.stop_event.is_set():
            try:
                # Send heartbeat for thread health monitoring
                self._send_heartbeat("reading_frames")

                # Check watchdog timeout if process is running
                if self.process and self.process.poll() is None:
                    if self._check_watchdog_timeout():
                        self._restart_ffmpeg_process()
                        continue

                # Start FFmpeg if not running
                if not self.process or self.process.poll() is not None:
                    if self.process and self.process.poll() is not None:
                        log_warning(self.camera_id, "Stream disconnected, reconnecting...")
                        stream_health_tracker.report_error(
                            self.camera_id,
                            "FFmpeg process disconnected"
                        )

                    if not self._start_ffmpeg_process():
                        self.consecutive_errors += 1
                        stream_health_tracker.report_error(
                            self.camera_id,
                            "Failed to start FFmpeg process"
                        )
                        time.sleep(5.0)
                        continue

                # Read frames directly from FFmpeg stdout
                try:
                    if self.process and self.process.stdout:
                        # Read BMP frame data
                        frame = self._read_bmp_frame(self.process.stdout)
                        if frame is None:
                            continue

                        # Update watchdog - we got a frame
                        self.last_frame_time = time.time()

                        # Reset error counter on successful frame
                        self.consecutive_errors = 0

                        # Report successful frame to health monitoring
                        frame_size = frame.nbytes
                        stream_health_tracker.report_frame_received(self.camera_id, frame_size)

                        # Call frame callback
                        if self.frame_callback:
                            try:
                                self.frame_callback(self.camera_id, frame)
                            except Exception as e:
                                stream_health_tracker.report_error(
                                    self.camera_id,
                                    f"Frame callback error: {e}"
                                )

                        frame_count += 1

                        # Log progress every 60 seconds (quieter)
                        current_time = time.time()
                        if current_time - last_log_time >= 60:
                            log_success(self.camera_id, f"{frame_count} frames captured ({frame.shape[1]}x{frame.shape[0]})")
                            last_log_time = current_time

                except Exception as e:
                    # Process might have died, let it restart on next iteration
                    stream_health_tracker.report_error(
                        self.camera_id,
                        f"Frame reading error: {e}"
                    )
                    if self.process:
                        self.process.terminate()
                        self.process = None
                    time.sleep(1.0)

            except Exception as e:
                stream_health_tracker.report_error(
                    self.camera_id,
                    f"Main loop error: {e}"
                )
                time.sleep(1.0)

        # Cleanup
        if self.process:
            self.process.terminate()

    # Health monitoring methods
    def _send_heartbeat(self, activity: str = "running"):
        """Send heartbeat to thread health monitor."""
        self.last_heartbeat = time.time()
        thread_health_monitor.heartbeat(activity=activity)

    def _heartbeat_callback(self) -> bool:
        """Heartbeat callback for thread responsiveness testing."""
        try:
            # Check if thread is responsive by checking recent heartbeat
            current_time = time.time()
            age = current_time - self.last_heartbeat

            # Thread is responsive if heartbeat is recent
            return age < 30.0  # 30 second responsiveness threshold

        except Exception:
            return False

    def _handle_restart_recovery(self, component: str, details: Dict[str, Any]) -> bool:
        """Handle restart recovery action."""
        try:
            log_info(self.camera_id, "Restarting FFmpeg RTSP reader for health recovery")

            # Stop current instance
            self.stop()

            # Small delay
            time.sleep(2.0)

            # Restart
            self.start()

            # Report successful restart
            stream_health_tracker.report_reconnect(self.camera_id, "health_recovery_restart")
            self.ffmpeg_restart_count += 1

            return True

        except Exception as e:
            log_error(self.camera_id, f"Failed to restart FFmpeg RTSP reader: {e}")
            return False

    def _handle_reconnect_recovery(self, component: str, details: Dict[str, Any]) -> bool:
        """Handle reconnect recovery action."""
        try:
            log_info(self.camera_id, "Reconnecting FFmpeg RTSP reader for health recovery")

            # Force restart FFmpeg process
            self._restart_ffmpeg_process()

            # Reset error counters
            self.consecutive_errors = 0
            stream_health_tracker.report_reconnect(self.camera_id, "health_recovery_reconnect")

            return True

        except Exception as e:
            log_error(self.camera_id, f"Failed to reconnect FFmpeg RTSP reader: {e}")
            return False
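
A usage sketch for the reader above (the URL is a placeholder; in the worker, frames are routed into the shared frame buffer rather than printed):

import time

def on_frame(camera_id, frame):
    print(f"{camera_id}: {frame.shape[1]}x{frame.shape[0]}")

reader = FFmpegRTSPReader("cam-01", "rtsp://10.0.0.5:554/stream1")
reader.set_frame_callback(on_frame)
reader.start()   # spawns the daemon read thread and registers health monitoring
time.sleep(30)   # frames arrive via on_frame in the background
reader.stop()    # terminates FFmpeg, joins threads, unregisters monitoring
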
@@ -1,378 +0,0 @@
"""
HTTP snapshot reader optimized for 2560x1440 (2K) high quality images.
Enhanced with comprehensive health monitoring and automatic recovery.
"""
import cv2
import logging
import time
import threading
import requests
import numpy as np
from typing import Optional, Callable, Dict, Any

from .base import VideoReader
from .utils import log_success, log_warning, log_error, log_info
from ...monitoring.stream_health import stream_health_tracker
from ...monitoring.thread_health import thread_health_monitor
from ...monitoring.recovery import recovery_manager, RecoveryAction

logger = logging.getLogger(__name__)


class HTTPSnapshotReader(VideoReader):
    """HTTP snapshot reader optimized for 2560x1440 (2K) high quality images."""

    def __init__(self, camera_id: str, snapshot_url: str, interval_ms: int = 5000, max_retries: int = 3):
        super().__init__(camera_id, snapshot_url, max_retries)
        self.snapshot_url = snapshot_url
        self.interval_ms = interval_ms
        self.stop_event = threading.Event()
        self.thread = None

        # Expected snapshot specifications
        self.expected_width = 2560
        self.expected_height = 1440
        self.max_file_size = 10 * 1024 * 1024  # 10MB max for 2K image

        # Health monitoring setup
        self.last_heartbeat = time.time()
        self.consecutive_errors = 0
        self.connection_test_interval = 300  # Test connection every 5 minutes
        self.last_connection_test = None

        # Register recovery handlers
        recovery_manager.register_recovery_handler(
            RecoveryAction.RESTART_STREAM,
            self._handle_restart_recovery
        )
        recovery_manager.register_recovery_handler(
            RecoveryAction.RECONNECT,
            self._handle_reconnect_recovery
        )

    @property
    def is_running(self) -> bool:
        """Check if the reader is currently running."""
        return self.thread is not None and self.thread.is_alive()

    @property
    def reader_type(self) -> str:
        """Get the type of reader."""
        return "http_snapshot"

    def set_frame_callback(self, callback: Callable[[str, np.ndarray], None]):
        """Set callback function to handle captured frames."""
        self.frame_callback = callback

    def start(self):
        """Start the snapshot reader thread."""
        if self.thread and self.thread.is_alive():
            logger.warning(f"Snapshot reader for {self.camera_id} already running")
            return

        self.stop_event.clear()
        self.thread = threading.Thread(target=self._read_snapshots, daemon=True)
        self.thread.start()

        # Register with health monitoring
        stream_health_tracker.register_stream(self.camera_id, "http_snapshot", self.snapshot_url)
        thread_health_monitor.register_thread(self.thread, self._heartbeat_callback)

        logger.info(f"Started snapshot reader for camera {self.camera_id} with health monitoring")

    def stop(self):
        """Stop the snapshot reader thread."""
        self.stop_event.set()

        # Unregister from health monitoring
        if self.thread:
            thread_health_monitor.unregister_thread(self.thread.ident)
            self.thread.join(timeout=5.0)

        stream_health_tracker.unregister_stream(self.camera_id)

        logger.info(f"Stopped snapshot reader for camera {self.camera_id}")

    def _read_snapshots(self):
        """Main snapshot reading loop for high quality 2K images."""
        retries = 0
        frame_count = 0
        last_log_time = time.time()
        last_connection_test = time.time()
        interval_seconds = self.interval_ms / 1000.0

        logger.info(f"Snapshot interval for camera {self.camera_id}: {interval_seconds}s")

        while not self.stop_event.is_set():
            try:
                # Send heartbeat for thread health monitoring
                self._send_heartbeat("fetching_snapshot")

                start_time = time.time()
                frame = self._fetch_snapshot()

                if frame is None:
                    retries += 1
                    self.consecutive_errors += 1

                    # Report error to health monitoring
                    stream_health_tracker.report_error(
                        self.camera_id,
                        f"Failed to fetch snapshot (retry {retries}/{self.max_retries})"
                    )

                    logger.warning(f"Failed to fetch snapshot for camera {self.camera_id}, retry {retries}/{self.max_retries}")

                    if self.max_retries != -1 and retries > self.max_retries:
                        logger.error(f"Max retries reached for snapshot camera {self.camera_id}")
                        break

                    time.sleep(min(2.0, interval_seconds))
                    continue

                # Accept any valid image dimensions - don't force specific resolution
                if frame.shape[1] <= 0 or frame.shape[0] <= 0:
                    logger.warning(f"Camera {self.camera_id}: Invalid frame dimensions {frame.shape[1]}x{frame.shape[0]}")
                    stream_health_tracker.report_error(
                        self.camera_id,
                        f"Invalid frame dimensions: {frame.shape[1]}x{frame.shape[0]}"
                    )
                    continue

                # Reset retry counter on successful fetch
                retries = 0
                self.consecutive_errors = 0
                frame_count += 1

                # Report successful frame to health monitoring
                frame_size = frame.nbytes
                stream_health_tracker.report_frame_received(self.camera_id, frame_size)

                # Call frame callback
                if self.frame_callback:
                    try:
                        self.frame_callback(self.camera_id, frame)
                    except Exception as e:
                        logger.error(f"Camera {self.camera_id}: Frame callback error: {e}")
                        stream_health_tracker.report_error(self.camera_id, f"Frame callback error: {e}")

                # Periodic connection health test
                current_time = time.time()
                if current_time - last_connection_test >= self.connection_test_interval:
                    self._test_connection_health()
                    last_connection_test = current_time

                # Log progress every 30 seconds
                if current_time - last_log_time >= 30:
                    logger.info(f"Camera {self.camera_id}: {frame_count} snapshots processed")
                    last_log_time = current_time

                # Wait for next interval
                elapsed = time.time() - start_time
                sleep_time = max(0, interval_seconds - elapsed)
                if sleep_time > 0:
                    self.stop_event.wait(sleep_time)

            except Exception as e:
                logger.error(f"Error in snapshot loop for camera {self.camera_id}: {e}")
                stream_health_tracker.report_error(self.camera_id, f"Snapshot loop error: {e}")
                retries += 1
                if self.max_retries != -1 and retries > self.max_retries:
                    break
                time.sleep(min(2.0, interval_seconds))

        logger.info(f"Snapshot reader thread ended for camera {self.camera_id}")

    def _fetch_snapshot(self) -> Optional[np.ndarray]:
        """Fetch a single high quality snapshot from HTTP URL."""
        try:
            # Parse URL for authentication
            from urllib.parse import urlparse
            parsed_url = urlparse(self.snapshot_url)

            headers = {
                'User-Agent': 'Python-Detector-Worker/1.0',
                'Accept': 'image/jpeg, image/png, image/*'
            }
            auth = None

            if parsed_url.username and parsed_url.password:
                from requests.auth import HTTPBasicAuth, HTTPDigestAuth
                auth = HTTPBasicAuth(parsed_url.username, parsed_url.password)

                # Reconstruct URL without credentials
                clean_url = f"{parsed_url.scheme}://{parsed_url.hostname}"
                if parsed_url.port:
                    clean_url += f":{parsed_url.port}"
                clean_url += parsed_url.path
                if parsed_url.query:
                    clean_url += f"?{parsed_url.query}"

                # Try Basic Auth first
                response = requests.get(clean_url, auth=auth, timeout=15, headers=headers,
                                        stream=True, verify=False)

                # If Basic Auth fails, try Digest Auth
                if response.status_code == 401:
                    auth = HTTPDigestAuth(parsed_url.username, parsed_url.password)
                    response = requests.get(clean_url, auth=auth, timeout=15, headers=headers,
                                            stream=True, verify=False)
            else:
                response = requests.get(self.snapshot_url, timeout=15, headers=headers,
                                        stream=True, verify=False)

            if response.status_code == 200:
                # Check content size
                content_length = int(response.headers.get('content-length', 0))
                if content_length > self.max_file_size:
                    logger.warning(f"Snapshot too large for camera {self.camera_id}: {content_length} bytes")
                    return None

                # Read content
                content = response.content

                # Convert to numpy array
                image_array = np.frombuffer(content, np.uint8)

                # Decode as high quality image
                frame = cv2.imdecode(image_array, cv2.IMREAD_COLOR)

                if frame is None:
                    logger.error(f"Failed to decode snapshot for camera {self.camera_id}")
                    return None

                logger.debug(f"Fetched snapshot for camera {self.camera_id}: {frame.shape[1]}x{frame.shape[0]}")
                return frame
            else:
                logger.warning(f"HTTP {response.status_code} from {self.camera_id}")
                return None

        except requests.RequestException as e:
            logger.error(f"Request error fetching snapshot for {self.camera_id}: {e}")
            return None
        except Exception as e:
            logger.error(f"Error decoding snapshot for {self.camera_id}: {e}")
            return None

    def fetch_single_snapshot(self) -> Optional[np.ndarray]:
        """
        Fetch a single high-quality snapshot on demand for pipeline processing.
        This method is for one-time fetch from HTTP URL, not continuous streaming.

        Returns:
            High quality 2K snapshot frame or None if failed
        """
        logger.info(f"[SNAPSHOT] Fetching snapshot for {self.camera_id} from {self.snapshot_url}")

        # Try to fetch snapshot with retries
        for attempt in range(self.max_retries):
            frame = self._fetch_snapshot()

            if frame is not None:
                logger.info(f"[SNAPSHOT] Successfully fetched {frame.shape[1]}x{frame.shape[0]} snapshot for {self.camera_id}")
                return frame

            if attempt < self.max_retries - 1:
                logger.warning(f"[SNAPSHOT] Attempt {attempt + 1}/{self.max_retries} failed for {self.camera_id}, retrying...")
                time.sleep(0.5)

        logger.error(f"[SNAPSHOT] Failed to fetch snapshot for {self.camera_id} after {self.max_retries} attempts")
        return None

    def _resize_maintain_aspect(self, frame: np.ndarray, target_width: int, target_height: int) -> np.ndarray:
        """Resize image while maintaining aspect ratio for high quality."""
        h, w = frame.shape[:2]
        aspect = w / h
        target_aspect = target_width / target_height

        if aspect > target_aspect:
            # Image is wider
            new_width = target_width
            new_height = int(target_width / aspect)
        else:
            # Image is taller
            new_height = target_height
            new_width = int(target_height * aspect)

        # Use INTER_LANCZOS4 for high quality downsampling
        resized = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)

        # Pad to target size if needed
        if new_width < target_width or new_height < target_height:
            top = (target_height - new_height) // 2
            bottom = target_height - new_height - top
            left = (target_width - new_width) // 2
            right = target_width - new_width - left
            resized = cv2.copyMakeBorder(resized, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0])

        return resized

    # Health monitoring methods
    def _send_heartbeat(self, activity: str = "running"):
        """Send heartbeat to thread health monitor."""
        self.last_heartbeat = time.time()
        thread_health_monitor.heartbeat(activity=activity)

    def _heartbeat_callback(self) -> bool:
        """Heartbeat callback for thread responsiveness testing."""
        try:
            # Check if thread is responsive by checking recent heartbeat
            current_time = time.time()
            age = current_time - self.last_heartbeat

            # Thread is responsive if heartbeat is recent
            return age < 30.0  # 30 second responsiveness threshold

        except Exception:
            return False

    def _test_connection_health(self):
        """Test HTTP connection health."""
        try:
            stream_health_tracker.test_http_connection(self.camera_id, self.snapshot_url)
        except Exception as e:
            logger.error(f"Error testing connection health for {self.camera_id}: {e}")

    def _handle_restart_recovery(self, component: str, details: Dict[str, Any]) -> bool:
        """Handle restart recovery action."""
        try:
            logger.info(f"Restarting HTTP snapshot reader for {self.camera_id}")

            # Stop current instance
            self.stop()

            # Small delay
            time.sleep(2.0)

            # Restart
            self.start()

            # Report successful restart
            stream_health_tracker.report_reconnect(self.camera_id, "health_recovery_restart")

            return True

        except Exception as e:
            logger.error(f"Failed to restart HTTP snapshot reader for {self.camera_id}: {e}")
            return False

    def _handle_reconnect_recovery(self, component: str, details: Dict[str, Any]) -> bool:
        """Handle reconnect recovery action."""
        try:
            logger.info(f"Reconnecting HTTP snapshot reader for {self.camera_id}")

            # Test connection first
            success = stream_health_tracker.test_http_connection(self.camera_id, self.snapshot_url)

            if success:
                # Reset error counters
                self.consecutive_errors = 0
                stream_health_tracker.report_reconnect(self.camera_id, "health_recovery_reconnect")
                return True
            else:
                logger.warning(f"Connection test failed during recovery for {self.camera_id}")
                return False

        except Exception as e:
            logger.error(f"Failed to reconnect HTTP snapshot reader for {self.camera_id}: {e}")
            return False
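
An on-demand fetch sketch (URL is a placeholder; fetch_single_snapshot blocks and retries up to max_retries, unlike the continuous start()/stop() loop):

reader = HTTPSnapshotReader("cam-02", "http://10.0.0.6/snapshot.jpg", interval_ms=2000)
frame = reader.fetch_single_snapshot()
if frame is not None:
    print(f"snapshot: {frame.shape[1]}x{frame.shape[0]}")  # typically 2560x1440
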
@@ -1,38 +0,0 @@
"""
Utility functions for stream readers.
"""
import logging
import os

# Keep OpenCV errors visible but allow FFmpeg stderr logging
os.environ["OPENCV_LOG_LEVEL"] = "ERROR"

logger = logging.getLogger(__name__)

# Color codes for pretty logging
class Colors:
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'
    BLUE = '\033[94m'
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    WHITE = '\033[97m'
    BOLD = '\033[1m'
    END = '\033[0m'

def log_success(camera_id: str, message: str):
    """Log success messages in green"""
    logger.info(f"{Colors.GREEN}[{camera_id}] {message}{Colors.END}")

def log_warning(camera_id: str, message: str):
    """Log warnings in yellow"""
    logger.warning(f"{Colors.YELLOW}[{camera_id}] {message}{Colors.END}")

def log_error(camera_id: str, message: str):
    """Log errors in red"""
    logger.error(f"{Colors.RED}[{camera_id}] {message}{Colors.END}")

def log_info(camera_id: str, message: str):
    """Log info in cyan"""
    logger.info(f"{Colors.CYAN}[{camera_id}] {message}{Colors.END}")
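
For example, on a terminal that honors ANSI escape codes:

log_success("cam-01", "Stream started")            # green: [cam-01] Stream started
log_error("cam-01", "FFmpeg: connection refused")  # red:   [cam-01] FFmpeg: connection refused
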
@@ -1,14 +0,0 @@
# Tracking module for vehicle tracking and validation

from .tracker import VehicleTracker, TrackedVehicle
from .validator import StableCarValidator, ValidationResult, VehicleState
from .integration import TrackingPipelineIntegration

__all__ = [
    'VehicleTracker',
    'TrackedVehicle',
    'StableCarValidator',
    'ValidationResult',
    'VehicleState',
    'TrackingPipelineIntegration'
]
@@ -1,408 +0,0 @@
"""
BoT-SORT Multi-Object Tracker with Camera Isolation
Based on BoT-SORT: Robust Associations Multi-Pedestrian Tracking
"""

import logging
import time
import numpy as np
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from scipy.optimize import linear_sum_assignment
from filterpy.kalman import KalmanFilter
import cv2

logger = logging.getLogger(__name__)


@dataclass
class TrackState:
    """Track state enumeration"""
    TENTATIVE = "tentative"  # New track, not confirmed yet
    CONFIRMED = "confirmed"  # Confirmed track
    DELETED = "deleted"  # Track to be deleted


class Track:
    """
    Individual track representation with Kalman filter for motion prediction
    """

    def __init__(self, detection, track_id: int, camera_id: str):
        """
        Initialize a new track

        Args:
            detection: Initial detection (bbox, confidence, class)
            track_id: Unique track identifier within camera
            camera_id: Camera identifier
        """
        self.track_id = track_id
        self.camera_id = camera_id
        self.state = TrackState.TENTATIVE

        # Time tracking
        self.start_time = time.time()
        self.last_update_time = time.time()

        # Appearance and motion
        self.bbox = detection.bbox  # [x1, y1, x2, y2]
        self.confidence = detection.confidence
        self.class_name = detection.class_name

        # Track management
        self.hit_streak = 1
        self.time_since_update = 0
        self.age = 1

        # Kalman filter for motion prediction
        self.kf = self._create_kalman_filter()
        self._update_kalman_filter(detection.bbox)

        # Track history
        self.history = [detection.bbox]
        self.max_history = 10

    def _create_kalman_filter(self) -> KalmanFilter:
        """Create Kalman filter for bbox tracking (x, y, w, h, vx, vy, vw, vh)"""
        kf = KalmanFilter(dim_x=8, dim_z=4)

        # State transition matrix (constant velocity model)
        kf.F = np.array([
            [1, 0, 0, 0, 1, 0, 0, 0],
            [0, 1, 0, 0, 0, 1, 0, 0],
            [0, 0, 1, 0, 0, 0, 1, 0],
            [0, 0, 0, 1, 0, 0, 0, 1],
            [0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 0, 0, 1]
        ])

        # Measurement matrix (observe x, y, w, h)
        kf.H = np.array([
            [1, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0]
        ])

        # Process noise
        kf.Q *= 0.01

        # Measurement noise
        kf.R *= 10

        # Initial covariance
        kf.P *= 100

        return kf

    def _update_kalman_filter(self, bbox: List[float]):
        """Update Kalman filter with new bbox"""
        # Convert [x1, y1, x2, y2] to [cx, cy, w, h]
        x1, y1, x2, y2 = bbox
        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2
        w = x2 - x1
        h = y2 - y1

        # Properly assign to column vector
        self.kf.x[:4, 0] = [cx, cy, w, h]

    def predict(self) -> np.ndarray:
        """Predict next position using Kalman filter"""
        self.kf.predict()

        # Convert back to [x1, y1, x2, y2] format
        cx, cy, w, h = self.kf.x[:4, 0]  # Extract from column vector
        x1 = cx - w/2
        y1 = cy - h/2
        x2 = cx + w/2
        y2 = cy + h/2

        return np.array([x1, y1, x2, y2])

    def update(self, detection):
        """Update track with new detection"""
        self.last_update_time = time.time()
        self.time_since_update = 0
        self.hit_streak += 1
        self.age += 1

        # Update track properties
        self.bbox = detection.bbox
        self.confidence = detection.confidence

        # Update Kalman filter
        x1, y1, x2, y2 = detection.bbox
        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2
        w = x2 - x1
        h = y2 - y1

        self.kf.update([cx, cy, w, h])

        # Update history
        self.history.append(detection.bbox)
        if len(self.history) > self.max_history:
            self.history.pop(0)

        # Update state
        if self.state == TrackState.TENTATIVE and self.hit_streak >= 3:
            self.state = TrackState.CONFIRMED

    def mark_missed(self):
        """Mark track as missed in this frame"""
        self.time_since_update += 1
        self.age += 1

        if self.time_since_update > 5:  # Delete after 5 missed frames
            self.state = TrackState.DELETED

    def is_confirmed(self) -> bool:
        """Check if track is confirmed"""
        return self.state == TrackState.CONFIRMED

    def is_deleted(self) -> bool:
        """Check if track should be deleted"""
        return self.state == TrackState.DELETED

class CameraTracker:
    """
    BoT-SORT tracker for a single camera
    """

    def __init__(self, camera_id: str, max_disappeared: int = 10):
        """
        Initialize camera tracker

        Args:
            camera_id: Unique camera identifier
            max_disappeared: Maximum frames a track can be missed before deletion
        """
        self.camera_id = camera_id
        self.max_disappeared = max_disappeared

        # Track management
        self.tracks: Dict[int, Track] = {}
        self.next_id = 1
        self.frame_count = 0

        logger.info(f"Initialized BoT-SORT tracker for camera {camera_id}")

    def update(self, detections: List) -> List[Track]:
        """
        Update tracker with new detections

        Args:
            detections: List of Detection objects

        Returns:
            List of active confirmed tracks
        """
        self.frame_count += 1

        # Predict all existing tracks
        for track in self.tracks.values():
            track.predict()

        # Associate detections to tracks
        matched_tracks, unmatched_detections, unmatched_tracks = self._associate(detections)

        # Update matched tracks
        for track_id, detection in matched_tracks:
            self.tracks[track_id].update(detection)

        # Mark unmatched tracks as missed
        for track_id in unmatched_tracks:
            self.tracks[track_id].mark_missed()

        # Create new tracks for unmatched detections
        for detection in unmatched_detections:
            track = Track(detection, self.next_id, self.camera_id)
            self.tracks[self.next_id] = track
            self.next_id += 1

        # Remove deleted tracks
        tracks_to_remove = [tid for tid, track in self.tracks.items() if track.is_deleted()]
        for tid in tracks_to_remove:
            del self.tracks[tid]

        # Return confirmed tracks
        confirmed_tracks = [track for track in self.tracks.values() if track.is_confirmed()]

        return confirmed_tracks

    def _associate(self, detections: List) -> Tuple[List[Tuple[int, Any]], List[Any], List[int]]:
        """
        Associate detections to existing tracks using IoU distance

        Returns:
            (matched_tracks, unmatched_detections, unmatched_tracks)
        """
        if not detections or not self.tracks:
            return [], detections, list(self.tracks.keys())

        # Calculate IoU distance matrix
        track_ids = list(self.tracks.keys())
        cost_matrix = np.zeros((len(track_ids), len(detections)))

        for i, track_id in enumerate(track_ids):
            track = self.tracks[track_id]
            predicted_bbox = track.predict()

            for j, detection in enumerate(detections):
                iou = self._calculate_iou(predicted_bbox, detection.bbox)
                cost_matrix[i, j] = 1 - iou  # Convert IoU to distance

        # Solve assignment problem
        row_indices, col_indices = linear_sum_assignment(cost_matrix)

        # Filter matches by IoU threshold
        iou_threshold = 0.3
        matched_tracks = []
        matched_detection_indices = set()
        matched_track_indices = set()

        for row, col in zip(row_indices, col_indices):
            if cost_matrix[row, col] <= (1 - iou_threshold):
                track_id = track_ids[row]
                detection = detections[col]
                matched_tracks.append((track_id, detection))
                matched_detection_indices.add(col)
                matched_track_indices.add(row)

        # Find unmatched detections and tracks
        unmatched_detections = [detections[i] for i in range(len(detections))
                                if i not in matched_detection_indices]
        unmatched_tracks = [track_ids[i] for i in range(len(track_ids))
                            if i not in matched_track_indices]

        return matched_tracks, unmatched_detections, unmatched_tracks

    def _calculate_iou(self, bbox1: np.ndarray, bbox2: List[float]) -> float:
        """Calculate IoU between two bounding boxes"""
        x1_1, y1_1, x2_1, y2_1 = bbox1
        x1_2, y1_2, x2_2, y2_2 = bbox2

        # Calculate intersection area
        x1_i = max(x1_1, x1_2)
        y1_i = max(y1_1, y1_2)
        x2_i = min(x2_1, x2_2)
        y2_i = min(y2_1, y2_2)

        if x2_i <= x1_i or y2_i <= y1_i:
            return 0.0

        intersection = (x2_i - x1_i) * (y2_i - y1_i)

        # Calculate union area
        area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
        union = area1 + area2 - intersection

        return intersection / union if union > 0 else 0.0
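    # Worked example (hypothetical boxes): for bbox1 = [0, 0, 10, 10] and
    # bbox2 = [5, 5, 15, 15], the intersection is 5 * 5 = 25 and the union is
    # 100 + 100 - 25 = 175, so IoU = 25 / 175 ≈ 0.143, below the 0.3
    # threshold used in _associate, so such a pair would be left unmatched.
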
class MultiCameraBoTSORT:
    """
    Multi-camera BoT-SORT tracker with complete camera isolation
    """

    def __init__(self, trigger_classes: List[str], min_confidence: float = 0.6):
        """
        Initialize multi-camera tracker

        Args:
            trigger_classes: List of class names to track
            min_confidence: Minimum detection confidence threshold
        """
        self.trigger_classes = trigger_classes
        self.min_confidence = min_confidence

        # Camera-specific trackers
        self.camera_trackers: Dict[str, CameraTracker] = {}

        logger.info(f"Initialized MultiCameraBoTSORT with classes={trigger_classes}, "
                    f"min_confidence={min_confidence}")

    def get_or_create_tracker(self, camera_id: str) -> CameraTracker:
        """Get or create tracker for specific camera"""
        if camera_id not in self.camera_trackers:
            self.camera_trackers[camera_id] = CameraTracker(camera_id)
            logger.info(f"Created new tracker for camera {camera_id}")

        return self.camera_trackers[camera_id]

    def update(self, camera_id: str, inference_result) -> List[Dict]:
        """
        Update tracker for specific camera with detections

        Args:
            camera_id: Camera identifier
            inference_result: InferenceResult with detections

        Returns:
            List of track information dictionaries
        """
        # Filter detections by confidence and trigger classes
        filtered_detections = []

        if hasattr(inference_result, 'detections') and inference_result.detections:
            for detection in inference_result.detections:
                if (detection.confidence >= self.min_confidence and
                        detection.class_name in self.trigger_classes):
                    filtered_detections.append(detection)

        # Get camera tracker and update
        tracker = self.get_or_create_tracker(camera_id)
        confirmed_tracks = tracker.update(filtered_detections)

        # Convert tracks to output format
        track_results = []
        for track in confirmed_tracks:
            track_results.append({
                'track_id': track.track_id,
                'camera_id': track.camera_id,
                'bbox': track.bbox,
                'confidence': track.confidence,
                'class_name': track.class_name,
                'hit_streak': track.hit_streak,
                'age': track.age
            })

        return track_results

    def get_statistics(self) -> Dict[str, Any]:
        """Get tracking statistics across all cameras"""
        stats = {}
        total_tracks = 0

        for camera_id, tracker in self.camera_trackers.items():
            camera_stats = {
                'active_tracks': len([t for t in tracker.tracks.values() if t.is_confirmed()]),
                'total_tracks': len(tracker.tracks),
                'frame_count': tracker.frame_count
            }
            stats[camera_id] = camera_stats
            total_tracks += camera_stats['active_tracks']

        stats['summary'] = {
            'total_cameras': len(self.camera_trackers),
            'total_active_tracks': total_tracks
        }

        return stats

    def reset_camera(self, camera_id: str):
        """Reset tracking for specific camera"""
        if camera_id in self.camera_trackers:
            del self.camera_trackers[camera_id]
            logger.info(f"Reset tracking for camera {camera_id}")

    def reset_all(self):
        """Reset all camera trackers"""
        self.camera_trackers.clear()
        logger.info("Reset all camera trackers")
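
A driving sketch for the multi-camera tracker; Detection and InferenceResult here are stand-in dataclasses, since the real types live in the inference module:

from dataclasses import dataclass, field
from typing import List

@dataclass
class Detection:  # stand-in for the worker's real detection type
    bbox: List[float]
    confidence: float
    class_name: str

@dataclass
class InferenceResult:  # stand-in exposing a .detections attribute
    detections: List[Detection] = field(default_factory=list)

tracker = MultiCameraBoTSORT(trigger_classes=["frontal"], min_confidence=0.6)
result = InferenceResult([Detection([100, 100, 300, 320], 0.9, "frontal")])

# A track is CONFIRMED only after hit_streak >= 3, so feed a few frames
for _ in range(3):
    tracks = tracker.update("cam-01", result)
print(tracks)  # e.g. [{'track_id': 1, 'camera_id': 'cam-01', ...}]
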
@ -1,883 +0,0 @@
|
|||
"""
|
||||
Tracking-Pipeline Integration Module.
|
||||
Connects the tracking system with the main detection pipeline and manages the flow.
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Optional, Any, List, Tuple
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import asyncio
|
||||
import numpy as np
|
||||
|
||||
from .tracker import VehicleTracker, TrackedVehicle
|
||||
from .validator import StableCarValidator
|
||||
from ..models.inference import YOLOWrapper
|
||||
from ..models.pipeline import PipelineParser
|
||||
from ..detection.pipeline import DetectionPipeline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrackingPipelineIntegration:
|
||||
"""
|
||||
Integrates vehicle tracking with the detection pipeline.
|
||||
Manages tracking state transitions and pipeline execution triggers.
|
||||
"""
|
||||
|
||||
def __init__(self, pipeline_parser: PipelineParser, model_manager: Any, model_id: int, message_sender=None):
|
||||
"""
|
||||
Initialize tracking-pipeline integration.
|
||||
|
||||
Args:
|
||||
pipeline_parser: Pipeline parser with loaded configuration
|
||||
model_manager: Model manager for loading models
|
||||
model_id: The model ID to use for loading models
|
||||
message_sender: Optional callback function for sending WebSocket messages
|
||||
"""
|
||||
self.pipeline_parser = pipeline_parser
|
||||
self.model_manager = model_manager
|
||||
self.model_id = model_id
|
||||
self.message_sender = message_sender
|
||||
|
||||
# Store subscription info for snapshot access
|
||||
self.subscription_info = None
|
||||
|
||||
# Initialize tracking components
|
||||
tracking_config = pipeline_parser.tracking_config.__dict__ if pipeline_parser.tracking_config else {}
|
||||
self.tracker = VehicleTracker(tracking_config)
|
||||
self.validator = StableCarValidator()
|
||||
|
||||
# Tracking model
|
||||
self.tracking_model: Optional[YOLOWrapper] = None
|
||||
self.tracking_model_id = None
|
||||
|
||||
# Detection pipeline (Phase 5)
|
||||
self.detection_pipeline: Optional[DetectionPipeline] = None
|
||||
|
||||
# Session management
|
||||
self.active_sessions: Dict[str, str] = {} # display_id -> session_id
|
||||
self.session_vehicles: Dict[str, int] = {} # session_id -> track_id
|
||||
self.cleared_sessions: Dict[str, float] = {} # session_id -> clear_time
|
||||
self.pending_vehicles: Dict[str, int] = {} # display_id -> track_id (waiting for session ID)
|
||||
self.pending_processing_data: Dict[str, Dict] = {} # display_id -> processing data (waiting for session ID)
|
||||
self.display_to_subscription: Dict[str, str] = {} # display_id -> subscription_id (for fallback)
|
||||
|
||||
# Additional validators for enhanced flow control
|
||||
self.permanently_processed: Dict[str, float] = {} # "camera_id:track_id" -> process_time (never process again)
|
||||
self.progression_stages: Dict[str, str] = {} # session_id -> current_stage
|
||||
self.last_detection_time: Dict[str, float] = {} # display_id -> last_detection_timestamp
|
||||
self.abandonment_timeout = 3.0 # seconds to wait before declaring car abandoned
|
||||
|
||||
# Thread pool for pipeline execution
|
||||
# Increased to 8 workers to handle 8 concurrent cameras without queuing
|
||||
self.executor = ThreadPoolExecutor(max_workers=8)
|
||||
|
||||
# Min bbox filtering configuration
|
||||
# TODO: Make this configurable via pipeline.json in the future
|
||||
self.min_bbox_area_percentage = 3.5 # 3.5% of frame area minimum
|
||||
|
||||
# Statistics
|
||||
self.stats = {
|
||||
'frames_processed': 0,
|
||||
'vehicles_detected': 0,
|
||||
'vehicles_validated': 0,
|
||||
'pipelines_executed': 0,
|
||||
'frontals_filtered_small': 0 # Track filtered detections
|
||||
}
|
||||
|
||||
|
||||
logger.info("TrackingPipelineIntegration initialized")
|
||||
|
||||
async def initialize_tracking_model(self) -> bool:
|
||||
"""
|
||||
Load and initialize the tracking model.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
if not self.pipeline_parser.tracking_config:
|
||||
logger.warning("No tracking configuration found in pipeline")
|
||||
return False
|
||||
|
||||
model_file = self.pipeline_parser.tracking_config.model_file
|
||||
model_id = self.pipeline_parser.tracking_config.model_id
|
||||
|
||||
if not model_file:
|
||||
logger.warning("No tracking model file specified")
|
||||
return False
|
||||
|
||||
# Load tracking model
|
||||
logger.info(f"Loading tracking model: {model_id} ({model_file})")
|
||||
self.tracking_model = self.model_manager.get_yolo_model(self.model_id, model_file)
|
||||
if not self.tracking_model:
|
||||
logger.error(f"Failed to load tracking model {model_file} from model {self.model_id}")
|
||||
return False
|
||||
self.tracking_model_id = model_id
|
||||
|
||||
if self.tracking_model:
|
||||
logger.info(f"Tracking model {model_id} loaded successfully")
|
||||
|
||||
# Initialize detection pipeline (Phase 5)
|
||||
await self._initialize_detection_pipeline()
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to load tracking model {model_id}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing tracking model: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def _initialize_detection_pipeline(self) -> bool:
|
||||
"""
|
||||
Initialize the detection pipeline for main detection processing.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
if not self.pipeline_parser:
|
||||
logger.warning("No pipeline parser available for detection pipeline")
|
||||
return False
|
||||
|
||||
# Create detection pipeline with message sender capability
|
||||
self.detection_pipeline = DetectionPipeline(self.pipeline_parser, self.model_manager, self.model_id, self.message_sender)
|
||||
|
||||
# Initialize detection pipeline
|
||||
if await self.detection_pipeline.initialize():
|
||||
logger.info("Detection pipeline initialized successfully")
|
||||
return True
|
||||
else:
|
||||
logger.error("Failed to initialize detection pipeline")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing detection pipeline: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def process_frame(self,
|
||||
frame: np.ndarray,
|
||||
display_id: str,
|
||||
                                  subscription_id: str,
                                  session_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Process a frame through tracking and potentially the detection pipeline.

        Args:
            frame: Input frame to process
            display_id: Display identifier
            subscription_id: Full subscription identifier
            session_id: Optional session ID from backend

        Returns:
            Dictionary with processing results
        """
        start_time = time.time()
        result = {
            'tracked_vehicles': [],
            'validated_vehicle': None,
            'pipeline_result': None,
            'session_id': session_id,
            'processing_time': 0.0
        }

        try:
            # Update stats
            self.stats['frames_processed'] += 1

            # Run tracking model
            if self.tracking_model:
                # Run detection only (tracking is handled by our own tracker)
                tracking_results = self.tracking_model.track(
                    frame,
                    confidence_threshold=self.tracker.min_confidence,
                    trigger_classes=self.tracker.trigger_classes,
                    persist=True
                )

                # Debug: log raw detection results
                if tracking_results and hasattr(tracking_results, 'detections'):
                    raw_detections = len(tracking_results.detections)
                    if raw_detections > 0:
                        class_names = [detection.class_name for detection in tracking_results.detections]
                        logger.debug(f"Raw detections: {raw_detections}, classes: {class_names}")
                    else:
                        logger.debug("No raw detections found")
                else:
                    logger.debug("No tracking results or detections attribute")

                # Filter out small frontal detections (neighboring pumps / distant cars)
                if tracking_results and hasattr(tracking_results, 'detections'):
                    tracking_results = self._filter_small_frontals(tracking_results, frame)

                # Process tracking results
                tracked_vehicles = self.tracker.process_detections(
                    tracking_results,
                    display_id,
                    frame
                )

                # Update the last detection time for abandonment detection.
                # Update only when vehicles ARE detected, so the timestamp ages once they leave.
                if tracked_vehicles:
                    self.last_detection_time[display_id] = time.time()
                    logger.debug(f"Updated last_detection_time for {display_id}: {len(tracked_vehicles)} vehicles")

                # Check for car abandonment (vehicle left after reaching the car_wait_staff stage)
                await self._check_car_abandonment(display_id, subscription_id)

                result['tracked_vehicles'] = [
                    {
                        'track_id': v.track_id,
                        'bbox': v.bbox,
                        'confidence': v.confidence,
                        'is_stable': v.is_stable,
                        'session_id': v.session_id
                    }
                    for v in tracked_vehicles
                ]

                # Log tracking info periodically
                if self.stats['frames_processed'] % 30 == 0:  # Every 30 frames
                    logger.debug(f"Tracking: {len(tracked_vehicles)} vehicles, "
                                 f"display={display_id}")

                # Get stable vehicles for validation
                stable_vehicles = self.tracker.get_stable_vehicles(display_id)

                # Validate and potentially process stable vehicles
                for vehicle in stable_vehicles:
                    # Skip vehicles already processed through the pipeline
                    if vehicle.processed_pipeline:
                        continue

                    # Same vehicle with the same session (post-fueling), skip
                    if session_id and vehicle.session_id == session_id:
                        continue

                    # Check if this was a recently cleared session
                    session_cleared = False
                    if vehicle.session_id in self.cleared_sessions:
                        clear_time = self.cleared_sessions[vehicle.session_id]
                        if (time.time() - clear_time) < 30:  # 30 second cooldown
                            session_cleared = True

                    # Skip the same car after a session clear, or if permanently processed
                    if self.validator.should_skip_same_car(vehicle, session_cleared, self.permanently_processed):
                        continue

                    # Validate vehicle
                    validation_result = self.validator.validate_vehicle(vehicle, frame.shape)

                    if validation_result.is_valid and validation_result.should_process:
                        logger.info(f"Vehicle {vehicle.track_id} validated for processing: "
                                    f"{validation_result.reason}")

                        result['validated_vehicle'] = {
                            'track_id': vehicle.track_id,
                            'state': validation_result.state.value,
                            'confidence': validation_result.confidence
                        }

                        # Execute the detection pipeline - it sends the real imageDetection
                        # message once a detection is found.

                        # Mark vehicle as pending session ID assignment
                        self.pending_vehicles[display_id] = vehicle.track_id
                        logger.info(f"Vehicle {vehicle.track_id} waiting for session ID from backend")

                        # Execute detection pipeline (placeholder for Phase 5)
                        pipeline_result = await self._execute_pipeline(
                            frame,
                            vehicle,
                            display_id,
                            None,  # No session ID yet
                            subscription_id
                        )

                        result['pipeline_result'] = pipeline_result
                        # No session_id in the result yet - the backend will provide it
                        self.stats['pipelines_executed'] += 1

                        # Only process one vehicle per frame
                        break

                self.stats['vehicles_detected'] = len(tracked_vehicles)
                self.stats['vehicles_validated'] = len(stable_vehicles)

            else:
                logger.warning("No tracking model available")

        except Exception as e:
            logger.error(f"Error in tracking pipeline: {e}", exc_info=True)

        result['processing_time'] = time.time() - start_time
        return result

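    # --- Illustrative only: a minimal sketch of how a caller might drive this
    # --- method from a stream loop. The `integration` / `reader` names and the
    # --- `process_frame` entry point are assumptions, not part of this module.
    #
    #   async def run_stream(integration, reader, display_id, subscription_id):
    #       while True:
    #           frame = await reader.next_frame()  # hypothetical frame source
    #           result = await integration.process_frame(
    #               frame, display_id, subscription_id)
    #           if result['validated_vehicle']:
    #               # Detection phase fired; the backend is expected to answer
    #               # with setSessionId, which triggers the processing phase.
    #               pass
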
    async def _execute_pipeline(self,
                                frame: np.ndarray,
                                vehicle: TrackedVehicle,
                                display_id: str,
                                session_id: str,
                                subscription_id: str) -> Dict[str, Any]:
        """
        Execute the main detection pipeline for a validated vehicle.

        Args:
            frame: Input frame
            vehicle: Validated tracked vehicle
            display_id: Display identifier
            session_id: Session identifier
            subscription_id: Full subscription identifier

        Returns:
            Pipeline execution results
        """
        logger.info(f"Executing detection pipeline for vehicle {vehicle.track_id}, "
                    f"session={session_id}, display={display_id}")

        try:
            # Check if the detection pipeline is available
            if not self.detection_pipeline:
                logger.warning("Detection pipeline not initialized, using fallback")
                return {
                    'status': 'error',
                    'message': 'Detection pipeline not available',
                    'vehicle_id': vehicle.track_id,
                    'session_id': session_id
                }

            # Fetch a high-quality 2K snapshot for the detection phase (not the RTSP frame).
            # This ensures bbox coordinates match the frame used in the processing phase.
            logger.info(f"[DETECTION PHASE] Fetching 2K snapshot for vehicle {vehicle.track_id}")
            snapshot_frame = self._fetch_snapshot()

            if snapshot_frame is None:
                logger.warning("[DETECTION PHASE] Failed to fetch snapshot, falling back to RTSP frame")
                snapshot_frame = frame  # Fall back to RTSP if the snapshot fails
            else:
                logger.info(f"[DETECTION PHASE] Using {snapshot_frame.shape[1]}x{snapshot_frame.shape[0]} snapshot for detection")

            # Execute only the detection phase (first phase).
            # This runs detection and sends the imageDetection message to the backend.
            detection_result = await self.detection_pipeline.execute_detection_phase(
                frame=snapshot_frame,  # Use the 2K snapshot instead of the RTSP frame
                display_id=display_id,
                subscription_id=subscription_id
            )

            # Add vehicle information to the result
            detection_result['vehicle_id'] = vehicle.track_id
            detection_result['vehicle_bbox'] = vehicle.bbox
            detection_result['vehicle_confidence'] = vehicle.confidence
            detection_result['phase'] = 'detection'

            logger.info(f"Detection phase executed for vehicle {vehicle.track_id}: "
                        f"status={detection_result.get('status', 'unknown')}, "
                        f"message_sent={detection_result.get('message_sent', False)}, "
                        f"processing_time={detection_result.get('processing_time', 0):.3f}s")

            # Store the frame and detection results for the processing phase
            if detection_result['message_sent']:
                # Store for later processing once the sessionId is received
                self.pending_processing_data[display_id] = {
                    'frame': snapshot_frame.copy(),  # Store a copy of the 2K snapshot (not the RTSP frame!)
                    'vehicle': vehicle,
                    'subscription_id': subscription_id,
                    'detection_result': detection_result,
                    'timestamp': time.time()
                }
                logger.info(f"Stored processing data ({snapshot_frame.shape[1]}x{snapshot_frame.shape[0]} frame) for {display_id}, waiting for sessionId from backend")

            return detection_result

        except Exception as e:
            logger.error(f"Error executing detection pipeline: {e}", exc_info=True)
            return {
                'status': 'error',
                'message': str(e),
                'vehicle_id': vehicle.track_id,
                'session_id': session_id,
                'processing_time': 0.0
            }

    async def _execute_processing_phase(self,
                                        processing_data: Dict[str, Any],
                                        session_id: str,
                                        display_id: str) -> None:
        """
        Execute the processing phase after receiving the sessionId from the backend.
        This includes branch processing and database operations.

        Args:
            processing_data: Stored processing data from the detection phase
            session_id: Session ID from the backend
            display_id: Display identifier
        """
        try:
            vehicle = processing_data['vehicle']
            subscription_id = processing_data['subscription_id']
            detection_result = processing_data['detection_result']

            logger.info(f"Executing processing phase for session {session_id}, vehicle {vehicle.track_id}")

            # Reuse the snapshot from the detection phase, or fetch a fresh one if detection used the RTSP fallback
            detection_frame = processing_data['frame']
            frame_height = detection_frame.shape[0]

            # Check whether the detection phase used a 2K snapshot (height >= 1000) or the RTSP fallback (height = 720)
            if frame_height >= 1000:
                # Detection used the 2K snapshot - reuse it for consistent coordinates
                logger.info(f"[PROCESSING PHASE] Reusing 2K snapshot from detection phase ({detection_frame.shape[1]}x{detection_frame.shape[0]})")
                frame = detection_frame
            else:
                # Detection used the RTSP fallback - need to fetch a fresh 2K snapshot
                logger.warning(f"[PROCESSING PHASE] Detection used RTSP fallback ({detection_frame.shape[1]}x{detection_frame.shape[0]}), fetching fresh 2K snapshot")
                frame = self._fetch_snapshot()

                if frame is None:
                    logger.error("[PROCESSING PHASE] Failed to fetch snapshot and detection used RTSP - coordinate mismatch will occur!")
                    logger.error("[PROCESSING PHASE] Cannot proceed with mismatched coordinates. Aborting processing phase.")
                    return  # Cannot process safely - bbox coordinates won't match the frame resolution
                else:
                    logger.warning(f"[PROCESSING PHASE] Fetched fresh 2K snapshot ({frame.shape[1]}x{frame.shape[0]}), but coordinates may not match exactly")
                    logger.warning("[PROCESSING PHASE] Re-running detection on the fresh snapshot is recommended but not implemented yet")

            # Extract detected regions from the detection phase result, if available
            detected_regions = detection_result.get('detected_regions', {})
            logger.info(f"[INTEGRATION] Passing detected_regions to processing phase: {list(detected_regions.keys())}")

            # Execute the processing phase with the detection pipeline
            if self.detection_pipeline:
                processing_result = await self.detection_pipeline.execute_processing_phase(
                    frame=frame,
                    display_id=display_id,
                    session_id=session_id,
                    subscription_id=subscription_id,
                    detected_regions=detected_regions
                )

                logger.info(f"Processing phase completed for session {session_id}: "
                            f"status={processing_result.get('status', 'unknown')}, "
                            f"branches={len(processing_result.get('branch_results', {}))}, "
                            f"actions={len(processing_result.get('actions_executed', []))}, "
                            f"processing_time={processing_result.get('processing_time', 0):.3f}s")

                # Update stats
                self.stats['pipelines_executed'] += 1

            else:
                logger.error("Detection pipeline not available for processing phase")

        except Exception as e:
            logger.error(f"Error in processing phase for session {session_id}: {e}", exc_info=True)

    def set_subscription_info(self, subscription_info):
        """
        Set subscription info to access the snapshot URL and other stream details.

        Args:
            subscription_info: SubscriptionInfo object containing the stream config
        """
        self.subscription_info = subscription_info
        logger.debug(f"Set subscription info with snapshot_url: {subscription_info.stream_config.snapshot_url if subscription_info else None}")

    def set_session_id(self, display_id: str, session_id: str, subscription_id: str = None):
        """
        Set the session ID for a display (from the backend).
        This is called when the backend sends setSessionId after receiving imageDetection.

        Args:
            display_id: Display identifier
            session_id: Session identifier
            subscription_id: Subscription identifier (displayId;cameraId) - needed for fallback
        """
        # Ensure session_id is always a string for consistent type handling
        session_id = str(session_id) if session_id is not None else None
        self.active_sessions[display_id] = session_id

        # Store subscription_id for fallback usage
        if subscription_id:
            self.display_to_subscription[display_id] = subscription_id
            logger.info(f"Set session {session_id} for display {display_id} with subscription {subscription_id}")
        else:
            logger.info(f"Set session {session_id} for display {display_id}")

        # Check if we have a pending vehicle for this display
        if display_id in self.pending_vehicles:
            track_id = self.pending_vehicles[display_id]

            # Mark the vehicle as processed with the session ID
            self.tracker.mark_processed(track_id, session_id)
            self.session_vehicles[session_id] = track_id

            # Mark the vehicle as permanently processed (won't be processed again, even after a session clear).
            # Use a composite key to distinguish identical track IDs across different cameras.
            camera_id = display_id  # Using display_id as camera_id for isolation
            permanent_key = f"{camera_id}:{track_id}"
            self.permanently_processed[permanent_key] = time.time()

            # Remove from pending
            del self.pending_vehicles[display_id]

            logger.info(f"Assigned session {session_id} to vehicle {track_id}, marked as permanently processed")
        else:
            logger.warning(f"No pending vehicle found for display {display_id} when setting session {session_id}")

        # Check if we have pending processing data for this display
        if display_id in self.pending_processing_data:
            processing_data = self.pending_processing_data[display_id]

            # Trigger the processing phase asynchronously
            asyncio.create_task(self._execute_processing_phase(
                processing_data=processing_data,
                session_id=session_id,
                display_id=display_id
            ))

            # Remove from pending processing
            del self.pending_processing_data[display_id]

            logger.info(f"Triggered processing phase for session {session_id} on display {display_id}")
        else:
            logger.warning(f"No pending processing data found for display {display_id} when setting session {session_id}")

            # FALLBACK: Execute the pipeline for POS-initiated sessions.
            # Skip if session_id is None (no car present, or the car has left).
            if session_id is not None:
                # Use the stored subscription_id instead of creating a fake one
                stored_subscription_id = self.display_to_subscription.get(display_id)
                if stored_subscription_id:
                    logger.info(f"[FALLBACK] Triggering fallback pipeline for session {session_id} on display {display_id} with subscription {stored_subscription_id}")

                    # Trigger the fallback pipeline asynchronously with the real subscription_id
                    asyncio.create_task(self._execute_fallback_pipeline(
                        display_id=display_id,
                        session_id=session_id,
                        subscription_id=stored_subscription_id
                    ))
                else:
                    logger.error(f"[FALLBACK] No subscription_id stored for display {display_id}, cannot execute fallback pipeline")
            else:
                logger.debug(f"[FALLBACK] Skipping pipeline execution for session_id=None on display {display_id}")

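    # --- Illustrative only: the message ordering implied by the handlers in
    # --- this module. This is a reading of the code above, not a protocol
    # --- definition from the backend.
    #
    #   1. worker:  detection phase fires           -> sends imageDetection
    #   2. backend: answers with setSessionId       -> set_session_id() is called
    #   3. worker:  set_session_id() assigns the pending vehicle and schedules
    #               _execute_processing_phase() via asyncio.create_task
    #   4. backend: later clears the transaction    -> clear_session_id() below
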
    def clear_session_id(self, session_id: str):
        """
        Clear a session ID (post-fueling).

        Args:
            session_id: Session identifier to clear
        """
        # Mark the session as cleared
        self.cleared_sessions[session_id] = time.time()

        # Clear from the tracker
        self.tracker.clear_session(session_id)

        # Remove from active sessions
        display_to_remove = None
        for display_id, sess_id in self.active_sessions.items():
            if sess_id == session_id:
                display_to_remove = display_id
                break

        if display_to_remove:
            del self.active_sessions[display_to_remove]

        if session_id in self.session_vehicles:
            del self.session_vehicles[session_id]

        logger.info(f"Cleared session {session_id}")

        # Clean up old cleared sessions (older than 5 minutes)
        current_time = time.time()
        old_sessions = [
            sid for sid, clear_time in self.cleared_sessions.items()
            if (current_time - clear_time) > 300
        ]
        for sid in old_sessions:
            del self.cleared_sessions[sid]

    def get_session_for_display(self, display_id: str) -> Optional[str]:
        """Get the active session for a display."""
        return self.active_sessions.get(display_id)

    def reset_tracking(self):
        """Reset all tracking state."""
        self.tracker.reset_tracking()
        self.active_sessions.clear()
        self.session_vehicles.clear()
        self.cleared_sessions.clear()
        self.pending_vehicles.clear()
        self.pending_processing_data.clear()
        self.display_to_subscription.clear()
        self.permanently_processed.clear()
        self.progression_stages.clear()
        self.last_detection_time.clear()
        logger.info("Tracking pipeline integration reset")

    def get_statistics(self) -> Dict[str, Any]:
        """Get comprehensive statistics."""
        tracker_stats = self.tracker.get_statistics()
        validator_stats = self.validator.get_statistics()

        return {
            'integration': self.stats,
            'tracker': tracker_stats,
            'validator': validator_stats,
            'active_sessions': len(self.active_sessions),
            'cleared_sessions': len(self.cleared_sessions)
        }

    async def _check_car_abandonment(self, display_id: str, subscription_id: str):
        """
        Check whether a car has abandoned the fueling process (left after reaching the car_wait_staff stage).

        Args:
            display_id: Display identifier
            subscription_id: Subscription identifier
        """
        current_time = time.time()

        # Check all sessions in the car_wait_staff stage
        abandoned_sessions = []
        for session_id, stage in self.progression_stages.items():
            if stage == "car_wait_staff":
                # Check whether we have recent detections for this session's display
                session_display = None
                for disp_id, sess_id in self.active_sessions.items():
                    if sess_id == session_id:
                        session_display = disp_id
                        break

                if session_display:
                    last_detection = self.last_detection_time.get(session_display, 0)
                    time_since_detection = current_time - last_detection

                    logger.info(f"[ABANDON CHECK] Session {session_id} (display: {session_display}): "
                                f"time_since_detection={time_since_detection:.1f}s, "
                                f"timeout={self.abandonment_timeout}s")

                    if time_since_detection > self.abandonment_timeout:
                        logger.warning(f"🚨 Car abandonment detected: session {session_id}, "
                                       f"no detection for {time_since_detection:.1f}s")
                        abandoned_sessions.append(session_id)
                else:
                    logger.debug(f"[ABANDON CHECK] Session {session_id} has no associated display")

        # Send an abandonment detection for each abandoned session
        for session_id in abandoned_sessions:
            await self._send_abandonment_detection(subscription_id, session_id)
            # Remove from progression stages to avoid repeated detection
            if session_id in self.progression_stages:
                del self.progression_stages[session_id]
                logger.info(f"[ABANDON] Removed session {session_id} from progression_stages after notification")

    async def _send_abandonment_detection(self, subscription_id: str, session_id: str):
        """
        Send an imageDetection with a null detection to indicate car abandonment.

        Args:
            subscription_id: Subscription identifier
            session_id: Session ID of the abandoned car
        """
        try:
            # Import here to avoid circular imports
            from ..communication.messages import create_image_detection

            # Create the abandonment detection message with a null detection
            detection_message = create_image_detection(
                subscription_identifier=subscription_id,
                detection_data=None,  # A null detection indicates abandonment
                model_id=self.model_id,
                model_name=self.pipeline_parser.tracking_config.model_id if self.pipeline_parser.tracking_config else "tracking_model"
            )

            # Send to the backend via WebSocket if a sender is available
            if self.message_sender:
                await self.message_sender(detection_message)
                logger.info(f"[CAR ABANDONMENT] Sent null detection for session {session_id}")
            else:
                logger.info(f"[CAR ABANDONMENT] No message sender available, would send: {detection_message}")

        except Exception as e:
            logger.error(f"Error sending abandonment detection: {e}", exc_info=True)

    def set_progression_stage(self, session_id: str, stage: str):
        """
        Set the progression stage for a session (from the backend's setProgessionStage message).

        Args:
            session_id: Session identifier
            stage: Progression stage (e.g., "car_wait_staff")
        """
        self.progression_stages[session_id] = stage
        logger.info(f"Set progression stage for session {session_id}: {stage}")

        # If the car reaches car_wait_staff, start monitoring for abandonment
        if stage == "car_wait_staff":
            logger.info(f"Started monitoring session {session_id} for car abandonment")

    def _fetch_snapshot(self) -> Optional[np.ndarray]:
        """
        Fetch a high-quality snapshot from the camera's snapshot URL.
        Reusable for both the processing phase and the fallback pipeline.

        Returns:
            Snapshot frame, or None if unavailable
        """
        if not (self.subscription_info and self.subscription_info.stream_config.snapshot_url):
            logger.warning("[SNAPSHOT] No subscription info or snapshot URL available")
            return None

        try:
            from ..streaming.readers import HTTPSnapshotReader

            logger.info(f"[SNAPSHOT] Fetching snapshot for {self.subscription_info.camera_id}")
            snapshot_reader = HTTPSnapshotReader(
                camera_id=self.subscription_info.camera_id,
                snapshot_url=self.subscription_info.stream_config.snapshot_url,
                max_retries=3
            )

            frame = snapshot_reader.fetch_single_snapshot()

            if frame is not None:
                logger.info(f"[SNAPSHOT] Successfully fetched {frame.shape[1]}x{frame.shape[0]} snapshot")
                return frame
            else:
                logger.warning("[SNAPSHOT] Failed to fetch snapshot")
                return None

        except Exception as e:
            logger.error(f"[SNAPSHOT] Error fetching snapshot: {e}", exc_info=True)
            return None

    async def _execute_fallback_pipeline(self, display_id: str, session_id: str, subscription_id: str):
        """
        Execute the fallback pipeline when a sessionId is received without a prior detection.
        This handles POS-initiated sessions where the backend starts the transaction before car detection.

        Args:
            display_id: Display identifier
            session_id: Session ID from the backend
            subscription_id: Subscription identifier for pipeline execution
        """
        try:
            logger.info(f"[FALLBACK PIPELINE] Executing for session {session_id}, display {display_id}")

            # Fetch a fresh snapshot from the camera
            frame = self._fetch_snapshot()

            if frame is None:
                logger.error(f"[FALLBACK] Failed to fetch snapshot for session {session_id}, cannot execute pipeline")
                return

            logger.info(f"[FALLBACK] Using snapshot frame {frame.shape[1]}x{frame.shape[0]} for session {session_id}")

            # Check if the detection pipeline is available
            if not self.detection_pipeline:
                logger.error(f"[FALLBACK] Detection pipeline not available for session {session_id}")
                return

            # Execute the detection phase to get detected regions
            detection_result = await self.detection_pipeline.execute_detection_phase(
                frame=frame,
                display_id=display_id,
                subscription_id=subscription_id
            )

            logger.info(f"[FALLBACK] Detection phase completed for session {session_id}: "
                        f"status={detection_result.get('status', 'unknown')}, "
                        f"regions={list(detection_result.get('detected_regions', {}).keys())}")

            # If detection found regions, execute the processing phase
            detected_regions = detection_result.get('detected_regions', {})
            if detected_regions:
                processing_result = await self.detection_pipeline.execute_processing_phase(
                    frame=frame,
                    display_id=display_id,
                    session_id=session_id,
                    subscription_id=subscription_id,
                    detected_regions=detected_regions
                )

                logger.info(f"[FALLBACK] Processing phase completed for session {session_id}: "
                            f"status={processing_result.get('status', 'unknown')}, "
                            f"branches={len(processing_result.get('branch_results', {}))}, "
                            f"actions={len(processing_result.get('actions_executed', []))}")

                # Update statistics
                self.stats['pipelines_executed'] += 1

            else:
                logger.warning(f"[FALLBACK] No detections found in snapshot for session {session_id}")

        except Exception as e:
            logger.error(f"[FALLBACK] Error executing fallback pipeline for session {session_id}: {e}", exc_info=True)

    def _filter_small_frontals(self, tracking_results, frame):
        """
        Filter out frontal detections smaller than the minimum bbox area percentage.
        This prevents processing of cars from neighboring pumps that appear in the camera view.

        Args:
            tracking_results: YOLO tracking results with detections
            frame: Input frame, used to calculate the frame area

        Returns:
            Modified tracking_results with small frontals removed
        """
        if not hasattr(tracking_results, 'detections') or not tracking_results.detections:
            return tracking_results

        # Calculate the frame area and the minimum bbox area threshold
        frame_area = frame.shape[0] * frame.shape[1]  # height * width
        min_bbox_area = frame_area * (self.min_bbox_area_percentage / 100.0)

        # Filter detections
        filtered_detections = []
        filtered_count = 0

        for detection in tracking_results.detections:
            # Calculate the detection bbox area
            bbox = detection.bbox  # bbox is [x1, y1, x2, y2]
            bbox_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])

            if bbox_area >= min_bbox_area:
                # Keep the detection - the bbox is large enough
                filtered_detections.append(detection)
            else:
                # Filter out the small detection
                filtered_count += 1
                area_percentage = (bbox_area / frame_area) * 100
                logger.debug(f"Filtered small frontal: area={bbox_area:.0f}px² ({area_percentage:.1f}% of frame, "
                             f"min required: {self.min_bbox_area_percentage}%)")

        # Update the tracking results with the filtered detections
        tracking_results.detections = filtered_detections

        # Update statistics
        if filtered_count > 0:
            self.stats['frontals_filtered_small'] += filtered_count
            logger.info(f"Filtered {filtered_count} small frontal detections, "
                        f"{len(filtered_detections)} remaining (total filtered: {self.stats['frontals_filtered_small']})")

        return tracking_results

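    # --- Illustrative only: the area-threshold arithmetic used above, with an
    # --- assumed min_bbox_area_percentage of 4.0 (the configured value is not
    # --- shown in this module).
    #
    #   frame_area    = 1080 * 1920              # 2,073,600 px²
    #   min_bbox_area = frame_area * 4.0 / 100   # 82,944 px²
    #   bbox          = [100, 200, 340, 440]     # 240 x 240 -> 57,600 px²
    #   57,600 < 82,944, so this frontal (~2.8% of frame) would be filtered out.
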
    def cleanup(self):
        """Cleanup resources."""
        self.executor.shutdown(wait=False)
        self.reset_tracking()

        # Clean up the detection pipeline
        if self.detection_pipeline:
            self.detection_pipeline.cleanup()

        logger.info("Tracking pipeline integration cleaned up")

@@ -1,293 +0,0 @@
"""
Vehicle Tracking Module - BoT-SORT based tracking with camera isolation.
Implements vehicle identification, persistence, and motion analysis using an external tracker.
"""
import logging
import time
import uuid
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
import numpy as np
from threading import Lock

from .bot_sort_tracker import MultiCameraBoTSORT

logger = logging.getLogger(__name__)


@dataclass
class TrackedVehicle:
    """Represents a tracked vehicle with all of its state information."""
    track_id: int
    camera_id: str
    first_seen: float
    last_seen: float
    session_id: Optional[str] = None
    display_id: Optional[str] = None
    confidence: float = 0.0
    bbox: Tuple[int, int, int, int] = (0, 0, 0, 0)  # x1, y1, x2, y2
    center: Tuple[float, float] = (0.0, 0.0)
    stable_frames: int = 0
    total_frames: int = 0
    is_stable: bool = False
    processed_pipeline: bool = False
    last_position_history: List[Tuple[float, float]] = field(default_factory=list)
    avg_confidence: float = 0.0
    hit_streak: int = 0
    age: int = 0

    def update_position(self, bbox: Tuple[int, int, int, int], confidence: float):
        """Update the vehicle's position and confidence."""
        self.bbox = bbox
        self.center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
        self.last_seen = time.time()
        self.confidence = confidence
        self.total_frames += 1

        # Update the running confidence average
        self.avg_confidence = ((self.avg_confidence * (self.total_frames - 1)) + confidence) / self.total_frames

        # Maintain the position history (last 10 positions)
        self.last_position_history.append(self.center)
        if len(self.last_position_history) > 10:
            self.last_position_history.pop(0)

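    # --- Illustrative only: the incremental running-average update above is
    # --- equivalent to recomputing the mean from scratch:
    #
    #   avg_n = (avg_{n-1} * (n - 1) + c_n) / n
    #
    #   e.g. confidences 0.8, 0.6, 0.7:
    #     n=1: (0.0*0 + 0.8)/1 = 0.800
    #     n=2: (0.8*1 + 0.6)/2 = 0.700
    #     n=3: (0.7*2 + 0.7)/3 = 0.700   # == mean([0.8, 0.6, 0.7])
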
    def calculate_stability(self) -> float:
        """Calculate a stability score based on the position history."""
        if len(self.last_position_history) < 2:
            return 0.0

        # Calculate movement variance
        positions = np.array(self.last_position_history)

        # Calculate the standard deviation of the positions
        std_x = np.std(positions[:, 0])
        std_y = np.std(positions[:, 1])

        # Lower variance means more stable (inverse relationship).
        # Normalize to the 0-1 range (assuming a max reasonable std of 50 pixels per axis).
        stability = max(0, 1 - (std_x + std_y) / 100)
        return stability

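    # --- Illustrative only: a worked example of the stability score above.
    # --- A car parked at a pump jitters by a few pixels:
    #
    #   std_x = 3.0, std_y = 4.0   ->  stability = 1 - (3 + 4) / 100 = 0.93
    #
    # --- A car rolling through the frame spreads its history widely:
    #
    #   std_x = 60.0, std_y = 45.0 ->  stability = max(0, 1 - 105 / 100) = 0.0
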
    def is_expired(self, timeout_seconds: float = 2.0) -> bool:
        """Check whether vehicle tracking has expired."""
        return (time.time() - self.last_seen) > timeout_seconds


class VehicleTracker:
    """
    Main vehicle tracking implementation using BoT-SORT with camera isolation.
    Manages continuous tracking, vehicle identification, and state persistence.
    """

    def __init__(self, tracking_config: Optional[Dict] = None):
        """
        Initialize the vehicle tracker.

        Args:
            tracking_config: Configuration from the tracking section of pipeline.json
        """
        self.config = tracking_config or {}
        self.trigger_classes = self.config.get('trigger_classes', self.config.get('triggerClasses', ['frontal']))
        self.min_confidence = self.config.get('minConfidence', 0.6)

        # BoT-SORT multi-camera tracker
        self.bot_sort = MultiCameraBoTSORT(self.trigger_classes, self.min_confidence)

        # Tracking state - maintains compatibility with existing code
        self.tracked_vehicles: Dict[str, Dict[int, TrackedVehicle]] = {}  # camera_id -> {track_id: vehicle}
        self.lock = Lock()

        # Tracking parameters
        self.stability_threshold = 0.7
        self.min_stable_frames = 5
        self.timeout_seconds = 2.0

        logger.info(f"VehicleTracker initialized with BoT-SORT: trigger_classes={self.trigger_classes}, "
                    f"min_confidence={self.min_confidence}")

    def process_detections(self,
                           results: Any,
                           display_id: str,
                           frame: np.ndarray) -> List[TrackedVehicle]:
        """
        Process detection results using BoT-SORT tracking.

        Args:
            results: Detection results (InferenceResult)
            display_id: Display identifier for this stream
            frame: Current frame being processed

        Returns:
            List of currently tracked vehicles
        """
        current_time = time.time()

        # Extract camera_id from display_id for tracking isolation
        camera_id = display_id  # Using display_id as camera_id for isolation

        with self.lock:
            # Update the BoT-SORT tracker
            track_results = self.bot_sort.update(camera_id, results)

            # Ensure the camera tracking dict exists
            if camera_id not in self.tracked_vehicles:
                self.tracked_vehicles[camera_id] = {}

            # Update tracked vehicles based on the BoT-SORT results
            current_tracks = {}
            active_tracks = []

            for track_result in track_results:
                track_id = track_result['track_id']

                # Create or update the TrackedVehicle
                if track_id in self.tracked_vehicles[camera_id]:
                    # Update the existing vehicle
                    vehicle = self.tracked_vehicles[camera_id][track_id]
                    vehicle.update_position(track_result['bbox'], track_result['confidence'])
                    vehicle.hit_streak = track_result['hit_streak']
                    vehicle.age = track_result['age']

                    # Update stability based on hit_streak
                    if vehicle.hit_streak >= self.min_stable_frames:
                        vehicle.is_stable = True
                        vehicle.stable_frames = vehicle.hit_streak

                    logger.debug(f"Updated track {track_id}: conf={vehicle.confidence:.2f}, "
                                 f"stable={vehicle.is_stable}, hit_streak={vehicle.hit_streak}")
                else:
                    # Create a new vehicle
                    x1, y1, x2, y2 = track_result['bbox']
                    vehicle = TrackedVehicle(
                        track_id=track_id,
                        camera_id=camera_id,
                        first_seen=current_time,
                        last_seen=current_time,
                        display_id=display_id,
                        confidence=track_result['confidence'],
                        bbox=tuple(track_result['bbox']),
                        center=((x1 + x2) / 2, (y1 + y2) / 2),
                        total_frames=1,
                        hit_streak=track_result['hit_streak'],
                        age=track_result['age']
                    )
                    vehicle.last_position_history.append(vehicle.center)
                    logger.info(f"New vehicle tracked: ID={track_id}, camera={camera_id}, display={display_id}")

                current_tracks[track_id] = vehicle
                active_tracks.append(vehicle)

            # Update the camera's tracked vehicles
            self.tracked_vehicles[camera_id] = current_tracks

            return active_tracks

    def get_stable_vehicles(self, display_id: Optional[str] = None) -> List[TrackedVehicle]:
        """
        Get all stable vehicles, optionally filtered by display.

        Args:
            display_id: Optional display ID to filter by

        Returns:
            List of stable tracked vehicles
        """
        with self.lock:
            stable = []
            camera_id = display_id  # Using display_id as camera_id

            if camera_id in self.tracked_vehicles:
                for vehicle in self.tracked_vehicles[camera_id].values():
                    if (vehicle.is_stable and not vehicle.is_expired(self.timeout_seconds) and
                            (display_id is None or vehicle.display_id == display_id)):
                        stable.append(vehicle)

            return stable

    def get_vehicle_by_session(self, session_id: str) -> Optional[TrackedVehicle]:
        """
        Get a tracked vehicle by its session ID.

        Args:
            session_id: Session ID to look up

        Returns:
            The tracked vehicle if found, None otherwise
        """
        with self.lock:
            # Search across all cameras
            for camera_vehicles in self.tracked_vehicles.values():
                for vehicle in camera_vehicles.values():
                    if vehicle.session_id == session_id:
                        return vehicle
            return None

    def mark_processed(self, track_id: int, session_id: str):
        """
        Mark a vehicle as processed through the pipeline.

        Args:
            track_id: Track ID of the vehicle
            session_id: Session ID assigned to this vehicle
        """
        with self.lock:
            # Search across all cameras for the track_id
            for camera_vehicles in self.tracked_vehicles.values():
                if track_id in camera_vehicles:
                    vehicle = camera_vehicles[track_id]
                    vehicle.processed_pipeline = True
                    vehicle.session_id = session_id
                    logger.info(f"Marked vehicle {track_id} as processed with session {session_id}")
                    return

    def clear_session(self, session_id: str):
        """
        Clear the session ID from a tracked vehicle (post-fueling).

        Args:
            session_id: Session ID to clear
        """
        with self.lock:
            # Search across all cameras
            for camera_vehicles in self.tracked_vehicles.values():
                for vehicle in camera_vehicles.values():
                    if vehicle.session_id == session_id:
                        logger.info(f"Clearing session {session_id} from vehicle {vehicle.track_id}")
                        vehicle.session_id = None
                        # Keep processed_pipeline=True to prevent re-processing

    def reset_tracking(self):
        """Reset all tracking state."""
        with self.lock:
            self.tracked_vehicles.clear()
            self.bot_sort.reset_all()
            logger.info("Vehicle tracking state reset")

    def get_statistics(self) -> Dict:
        """Get tracking statistics."""
        with self.lock:
            total = 0
            stable = 0
            processed = 0
            all_confidences = []

            # Aggregate stats across all cameras
            for camera_vehicles in self.tracked_vehicles.values():
                total += len(camera_vehicles)
                for vehicle in camera_vehicles.values():
                    if vehicle.is_stable:
                        stable += 1
                    if vehicle.processed_pipeline:
                        processed += 1
                    all_confidences.append(vehicle.avg_confidence)

            return {
                'total_tracked': total,
                'stable_vehicles': stable,
                'processed_vehicles': processed,
                'avg_confidence': np.mean(all_confidences) if all_confidences else 0.0,
                'bot_sort_stats': self.bot_sort.get_statistics()
            }

@@ -1,407 +0,0 @@
"""
Vehicle Validation Module - Stable car detection and validation logic.
Differentiates between stable (fueling) cars and passing-by vehicles.
"""
import logging
import time
import numpy as np
from typing import List, Optional, Tuple, Dict, Any
from dataclasses import dataclass
from enum import Enum

from .tracker import TrackedVehicle

logger = logging.getLogger(__name__)


class VehicleState(Enum):
    """Vehicle state classification."""
    UNKNOWN = "unknown"
    ENTERING = "entering"
    STABLE = "stable"
    LEAVING = "leaving"
    PASSING_BY = "passing_by"


@dataclass
class ValidationResult:
    """Result of vehicle validation."""
    is_valid: bool
    state: VehicleState
    confidence: float
    reason: str
    should_process: bool = False
    track_id: Optional[int] = None


class StableCarValidator:
    """
    Validates whether a tracked vehicle should be processed through the pipeline.

    Updated for BoT-SORT integration: trusts the sophisticated BoT-SORT tracking
    algorithm for stability determination and focuses on business-logic validation:
    - Duration requirements for processing
    - Confidence thresholds
    - Session management and cooldowns
    - Camera isolation with composite keys
    """

    def __init__(self, config: Optional[Dict] = None):
        """
        Initialize the validator with configuration.

        Args:
            config: Optional configuration dictionary
        """
        self.config = config or {}

        # Validation thresholds.
        # Optimized for a 6 FPS RTSP source with 8 concurrent cameras on one GPU;
        # GPU contention reduces the effective rate to ~3-5 FPS per camera.
        # Reduced from 3.0s to 1.5s to achieve ~2.75s total validation time (was ~4.25s).
        self.min_stable_duration = self.config.get('min_stable_duration', 1.5)  # seconds
        # Reduced from 10 to 5 to align with the tracker requirement and reduce validation time
        self.min_stable_frames = self.config.get('min_stable_frames', 5)
        self.position_variance_threshold = self.config.get('position_variance_threshold', 25.0)  # pixels
        # Reduced from 0.7 to 0.45 to be more permissive under GPU load
        self.min_confidence = self.config.get('min_confidence', 0.45)
        self.velocity_threshold = self.config.get('velocity_threshold', 5.0)  # pixels/frame
        self.entering_zone_ratio = self.config.get('entering_zone_ratio', 0.3)  # 30% of frame
        self.leaving_zone_ratio = self.config.get('leaving_zone_ratio', 0.3)

        # Frame dimensions (updated on the first frame)
        self.frame_width = 1920
        self.frame_height = 1080

        # History for validation
        self.validation_history: Dict[int, List[VehicleState]] = {}
        self.last_processed_vehicles: Dict[int, float] = {}  # track_id -> last_process_time

        logger.info(f"StableCarValidator initialized with min_duration={self.min_stable_duration}s, "
                    f"min_frames={self.min_stable_frames}, position_variance={self.position_variance_threshold}")

    def update_frame_dimensions(self, width: int, height: int):
        """Update the frame dimensions used for zone calculations."""
        self.frame_width = width
        self.frame_height = height
        # Verbose frame-dimension logging intentionally disabled:
        # logger.debug(f"Updated frame dimensions: {width}x{height}")

    def validate_vehicle(self, vehicle: TrackedVehicle, frame_shape: Optional[Tuple] = None) -> ValidationResult:
        """
        Validate whether a tracked vehicle is stable and should be processed.

        Args:
            vehicle: The tracked vehicle to validate
            frame_shape: Optional frame shape (height, width, channels)

        Returns:
            ValidationResult with validation status and reasoning
        """
        # Update frame dimensions if provided
        if frame_shape:
            self.update_frame_dimensions(frame_shape[1], frame_shape[0])

        # Initialize validation history for new vehicles
        if vehicle.track_id not in self.validation_history:
            self.validation_history[vehicle.track_id] = []

        # Check if already processed
        if vehicle.processed_pipeline:
            return ValidationResult(
                is_valid=False,
                state=VehicleState.STABLE,
                confidence=1.0,
                reason="Already processed through pipeline",
                should_process=False,
                track_id=vehicle.track_id
            )

        # Check if recently processed (cooldown period)
        if vehicle.track_id in self.last_processed_vehicles:
            time_since_process = time.time() - self.last_processed_vehicles[vehicle.track_id]
            if time_since_process < 10.0:  # 10 second cooldown
                return ValidationResult(
                    is_valid=False,
                    state=VehicleState.STABLE,
                    confidence=1.0,
                    reason=f"Recently processed ({time_since_process:.1f}s ago)",
                    should_process=False,
                    track_id=vehicle.track_id
                )

        # Determine the vehicle state
        state = self._determine_vehicle_state(vehicle)

        # Update history
        self.validation_history[vehicle.track_id].append(state)
        if len(self.validation_history[vehicle.track_id]) > 20:
            self.validation_history[vehicle.track_id].pop(0)

        # Validate based on state
        if state == VehicleState.STABLE:
            return self._validate_stable_vehicle(vehicle)
        elif state == VehicleState.PASSING_BY:
            return ValidationResult(
                is_valid=False,
                state=state,
                confidence=0.8,
                reason="Vehicle is passing by",
                should_process=False,
                track_id=vehicle.track_id
            )
        elif state == VehicleState.ENTERING:
            return ValidationResult(
                is_valid=False,
                state=state,
                confidence=0.5,
                reason="Vehicle is entering, waiting for stability",
                should_process=False,
                track_id=vehicle.track_id
            )
        elif state == VehicleState.LEAVING:
            return ValidationResult(
                is_valid=False,
                state=state,
                confidence=0.5,
                reason="Vehicle is leaving",
                should_process=False,
                track_id=vehicle.track_id
            )
        else:
            return ValidationResult(
                is_valid=False,
                state=state,
                confidence=0.0,
                reason="Unknown vehicle state",
                should_process=False,
                track_id=vehicle.track_id
            )

    def _determine_vehicle_state(self, vehicle: TrackedVehicle) -> VehicleState:
        """
        Determine the current state of the vehicle based on BoT-SORT tracking results.

        BoT-SORT provides sophisticated tracking, so we trust its stability
        determination and focus on business-logic validation.

        Args:
            vehicle: The tracked vehicle

        Returns:
            Current vehicle state
        """
        # Trust BoT-SORT's stability determination
        if vehicle.is_stable:
            # Check whether it has been stable long enough for processing
            duration = time.time() - vehicle.first_seen
            if duration >= self.min_stable_duration:
                return VehicleState.STABLE
            else:
                return VehicleState.ENTERING

        # For non-stable vehicles, use simplified state determination
        if len(vehicle.last_position_history) < 2:
            return VehicleState.UNKNOWN

        # Calculate velocity for movement classification
        velocity = self._calculate_velocity(vehicle)

        # Basic movement classification
        if velocity > self.velocity_threshold:
            # The vehicle is moving - classify as passing by or entering/leaving
            x_position = vehicle.center[0] / self.frame_width

            # Simple heuristic: vehicles near the edges are entering/leaving, center vehicles are passing
            if x_position < 0.2 or x_position > 0.8:
                return VehicleState.ENTERING
            else:
                return VehicleState.PASSING_BY

        # Low velocity but not marked stable by the tracker - likely entering
        return VehicleState.ENTERING

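    # --- Illustrative only: the edge-zone heuristic above, on the default
    # --- 1920-pixel-wide frame.
    #
    #   center x = 250  ->  x_position = 250 / 1920 ≈ 0.13 < 0.2   -> ENTERING
    #   center x = 960  ->  x_position = 960 / 1920 = 0.50         -> PASSING_BY
    #   (both assume velocity > velocity_threshold, i.e. > 5 px/frame)
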
    def _validate_stable_vehicle(self, vehicle: TrackedVehicle) -> ValidationResult:
        """
        Perform business-logic validation of a stable vehicle.

        Since BoT-SORT has already determined that the vehicle is stable, we focus on:
        - Duration requirements for processing
        - Confidence thresholds
        - Business-logic constraints

        Args:
            vehicle: The stable vehicle to validate

        Returns:
            Detailed validation result
        """
        # Check duration (business requirement)
        duration = time.time() - vehicle.first_seen
        if duration < self.min_stable_duration:
            return ValidationResult(
                is_valid=False,
                state=VehicleState.STABLE,
                confidence=0.6,
                reason=f"Not stable long enough ({duration:.1f}s < {self.min_stable_duration}s)",
                should_process=False,
                track_id=vehicle.track_id
            )

        # Check confidence (business requirement)
        if vehicle.avg_confidence < self.min_confidence:
            return ValidationResult(
                is_valid=False,
                state=VehicleState.STABLE,
                confidence=vehicle.avg_confidence,
                reason=f"Confidence too low ({vehicle.avg_confidence:.2f} < {self.min_confidence})",
                should_process=False,
                track_id=vehicle.track_id
            )

        # Trust BoT-SORT's stability determination and skip the position variance check;
        # its tracking already ensures consistent positioning.

        # Simplified state-history check - just ensure recent stability
        if vehicle.track_id in self.validation_history:
            history = self.validation_history[vehicle.track_id][-3:]  # Last 3 states
            stable_count = sum(1 for s in history if s == VehicleState.STABLE)
            if len(history) >= 2 and stable_count == 0:  # Only fail on clear instability
                return ValidationResult(
                    is_valid=False,
                    state=VehicleState.STABLE,
                    confidence=0.7,
                    reason="Recent state history shows instability",
                    should_process=False,
                    track_id=vehicle.track_id
                )

        # All checks passed - the vehicle is valid for processing
        self.last_processed_vehicles[vehicle.track_id] = time.time()

        return ValidationResult(
            is_valid=True,
            state=VehicleState.STABLE,
            confidence=vehicle.avg_confidence,
            reason="Vehicle is stable and ready for processing (BoT-SORT validated)",
            should_process=True,
            track_id=vehicle.track_id
        )

    def _calculate_velocity(self, vehicle: TrackedVehicle) -> float:
        """
        Calculate the velocity of the vehicle based on its position history.

        Args:
            vehicle: The tracked vehicle

        Returns:
            Velocity in pixels per frame
        """
        if len(vehicle.last_position_history) < 2:
            return 0.0

        positions = np.array(vehicle.last_position_history)

        # Calculate velocity over the last 3 frames
        recent_positions = positions[-min(3, len(positions)):]
        velocities = []

        for i in range(1, len(recent_positions)):
            dx = recent_positions[i][0] - recent_positions[i - 1][0]
            dy = recent_positions[i][1] - recent_positions[i - 1][1]
            velocities.append(np.sqrt(dx**2 + dy**2))

        return np.mean(velocities) if velocities else 0.0

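    # --- Illustrative only: the per-frame velocity computed above, for a short
    # --- position history.
    #
    #   history = [(100, 200), (103, 204), (109, 212)]
    #   step 1: sqrt(3² + 4²) = 5.0 px
    #   step 2: sqrt(6² + 8²) = 10.0 px
    #   velocity = mean([5.0, 10.0]) = 7.5 px/frame  (> 5.0 threshold -> "moving")
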
    def _calculate_position_variance(self, vehicle: TrackedVehicle) -> float:
        """
        Calculate the position variance of the vehicle.

        Args:
            vehicle: The tracked vehicle

        Returns:
            Position variance in pixels
        """
        if len(vehicle.last_position_history) < 2:
            return 0.0

        positions = np.array(vehicle.last_position_history)
        variance_x = np.var(positions[:, 0])
        variance_y = np.var(positions[:, 1])

        return np.sqrt(variance_x + variance_y)

    def should_skip_same_car(self,
                             vehicle: TrackedVehicle,
                             session_cleared: bool = False,
                             permanently_processed: Dict[str, float] = None) -> bool:
        """
        Determine whether to skip processing for the same car after a session clear.

        Args:
            vehicle: The tracked vehicle
            session_cleared: Whether the session was recently cleared
            permanently_processed: Dict of permanently processed vehicles (camera_id:track_id -> time)

        Returns:
            True if this vehicle should be skipped
        """
        # Check if this vehicle was permanently processed (never process it again)
        if permanently_processed:
            # Create a composite key from camera_id and track_id
            permanent_key = f"{vehicle.camera_id}:{vehicle.track_id}"
            if permanent_key in permanently_processed:
                process_time = permanently_processed[permanent_key]
                time_since = time.time() - process_time
                logger.debug(f"Skipping permanently processed vehicle {vehicle.track_id} on camera {vehicle.camera_id} "
                             f"(processed {time_since:.1f}s ago)")
                return True

        # If the vehicle had a session_id that was cleared, skip it for a cooldown period
        if vehicle.session_id is None and vehicle.processed_pipeline and session_cleared:
            # Check if enough time has passed since processing
            if vehicle.track_id in self.last_processed_vehicles:
                time_since = time.time() - self.last_processed_vehicles[vehicle.track_id]
                if time_since < 30.0:  # 30 second cooldown after session clear
                    logger.debug(f"Skipping same car {vehicle.track_id} after session clear "
                                 f"({time_since:.1f}s since processing)")
                    return True

        return False

    def reset_vehicle(self, track_id: int):
        """
        Reset the validation state for a specific vehicle.

        Args:
            track_id: Track ID of the vehicle to reset
        """
        if track_id in self.validation_history:
            del self.validation_history[track_id]
        if track_id in self.last_processed_vehicles:
            del self.last_processed_vehicles[track_id]
        logger.debug(f"Reset validation state for vehicle {track_id}")

    def get_statistics(self) -> Dict:
        """Get validation statistics."""
        return {
            'vehicles_in_history': len(self.validation_history),
            'recently_processed': len(self.last_processed_vehicles),
            'state_distribution': self._get_state_distribution()
        }

    def _get_state_distribution(self) -> Dict[str, int]:
        """Get the distribution of current vehicle states."""
        distribution = {state.value: 0 for state in VehicleState}

        for history in self.validation_history.values():
            if history:
                current_state = history[-1]
                distribution[current_state.value] += 1

        return distribution

@@ -1,214 +0,0 @@
|
|||
"""
|
||||
FFmpeg hardware acceleration detection and configuration
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger("detector_worker")
|
||||
|
||||
|
||||
class FFmpegCapabilities:
|
||||
"""Detect and configure FFmpeg hardware acceleration capabilities."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize FFmpeg capabilities detector."""
|
||||
self.hwaccels = []
|
||||
self.codecs = {}
|
||||
self.nvidia_support = False
|
||||
self.vaapi_support = False
|
||||
self.qsv_support = False
|
||||
|
||||
self._detect_capabilities()
|
||||
|
||||
def _detect_capabilities(self):
|
||||
"""Detect available hardware acceleration methods."""
|
||||
try:
|
||||
# Get hardware accelerators
|
||||
result = subprocess.run(
|
||||
['ffmpeg', '-hide_banner', '-hwaccels'],
|
||||
capture_output=True, text=True, timeout=10
|
||||
)
|
||||
if result.returncode == 0:
|
||||
self.hwaccels = [line.strip() for line in result.stdout.strip().split('\n')[1:] if line.strip()]
|
||||
logger.info(f"Available FFmpeg hardware accelerators: {', '.join(self.hwaccels)}")
|
||||
|
||||
# Check for NVIDIA support
|
||||
self.nvidia_support = any(hw in self.hwaccels for hw in ['cuda', 'cuvid', 'nvdec'])
|
||||
self.vaapi_support = 'vaapi' in self.hwaccels
|
||||
            self.qsv_support = 'qsv' in self.hwaccels

            # Get decoder information
            self._detect_decoders()

            # Log capabilities
            if self.nvidia_support:
                logger.info("NVIDIA hardware acceleration available (CUDA/CUVID/NVDEC)")
                logger.info(f"Detected hardware codecs: {self.codecs}")
            if self.vaapi_support:
                logger.info("VAAPI hardware acceleration available")
            if self.qsv_support:
                logger.info("Intel QuickSync hardware acceleration available")

        except Exception as e:
            logger.warning(f"Failed to detect FFmpeg capabilities: {e}")

    def _detect_decoders(self):
        """Detect available hardware decoders."""
        try:
            result = subprocess.run(
                ['ffmpeg', '-hide_banner', '-decoders'],
                capture_output=True, text=True, timeout=10
            )
            if result.returncode == 0:
                # Parse decoder output to find hardware decoders
                for line in result.stdout.split('\n'):
                    if 'cuvid' in line or 'nvdec' in line:
                        match = re.search(r'(\w+)\s+.*?(\w+(?:_cuvid|_nvdec))', line)
                        if match:
                            codec_type, decoder = match.groups()
                            if 'h264' in decoder:
                                self.codecs['h264_hw'] = decoder
                            elif 'hevc' in decoder or 'h265' in decoder:
                                self.codecs['h265_hw'] = decoder
                    elif 'vaapi' in line:
                        match = re.search(r'(\w+)\s+.*?(\w+_vaapi)', line)
                        if match:
                            codec_type, decoder = match.groups()
                            if 'h264' in decoder:
                                self.codecs['h264_vaapi'] = decoder

        except Exception as e:
            logger.debug(f"Failed to detect decoders: {e}")

    def get_optimal_capture_options(self, codec: str = 'h264') -> Dict[str, str]:
        """
        Get optimal FFmpeg capture options for the given codec.

        Args:
            codec: Video codec (h264, h265, etc.)

        Returns:
            Dictionary of FFmpeg options
        """
        options = {
            'rtsp_transport': 'tcp',
            'buffer_size': '1024k',
            'max_delay': '500000',  # 500ms
            'fflags': '+genpts',
            'flags': '+low_delay',
            'probesize': '32',
            'analyzeduration': '0'
        }

        # Add hardware acceleration if available
        if self.nvidia_support:
            # Force enable CUDA hardware acceleration for H.264 if CUDA is available
            if codec == 'h264':
                options.update({
                    'hwaccel': 'cuda',
                    'hwaccel_device': '0'
                })
                logger.info("Using NVIDIA NVDEC hardware acceleration for H.264")
            elif codec == 'h265':
                options.update({
                    'hwaccel': 'cuda',
                    'hwaccel_device': '0',
                    'video_codec': 'hevc_cuvid',
                    'hwaccel_output_format': 'cuda'
                })
                logger.info("Using NVIDIA CUVID hardware acceleration for H.265")

        elif self.vaapi_support:
            if codec == 'h264':
                options.update({
                    'hwaccel': 'vaapi',
                    'hwaccel_device': '/dev/dri/renderD128',
                    'video_codec': 'h264_vaapi'
                })
                logger.debug("Using VAAPI hardware acceleration")

        return options

    def format_opencv_options(self, options: Dict[str, str]) -> str:
        """
        Format options for OpenCV FFmpeg backend.

        Args:
            options: Dictionary of FFmpeg options

        Returns:
            Formatted options string for OpenCV
        """
        return '|'.join(f"{key};{value}" for key, value in options.items())

    def get_hardware_encoder_options(self, codec: str = 'h264', quality: str = 'fast') -> Dict[str, str]:
        """
        Get optimal hardware encoding options.

        Args:
            codec: Video codec for encoding
            quality: Quality preset (fast, medium, slow)

        Returns:
            Dictionary of encoding options
        """
        options = {}

        if self.nvidia_support:
            if codec == 'h264':
                options.update({
                    'video_codec': 'h264_nvenc',
                    'preset': quality,
                    'tune': 'zerolatency',
                    'gpu': '0',
                    'rc': 'cbr_hq',
                    'surfaces': '64'
                })
            elif codec == 'h265':
                options.update({
                    'video_codec': 'hevc_nvenc',
                    'preset': quality,
                    'tune': 'zerolatency',
                    'gpu': '0'
                })

        elif self.vaapi_support:
            if codec == 'h264':
                options.update({
                    'video_codec': 'h264_vaapi',
                    'vaapi_device': '/dev/dri/renderD128'
                })

        return options


# Global instance
_ffmpeg_caps = None

def get_ffmpeg_capabilities() -> FFmpegCapabilities:
    """Get or create the global FFmpeg capabilities instance."""
    global _ffmpeg_caps
    if _ffmpeg_caps is None:
        _ffmpeg_caps = FFmpegCapabilities()
    return _ffmpeg_caps

def get_optimal_rtsp_options(rtsp_url: str) -> str:
    """
    Get optimal OpenCV FFmpeg options for RTSP streaming.

    Args:
        rtsp_url: RTSP stream URL

    Returns:
        Formatted options string for cv2.VideoCapture
    """
    caps = get_ffmpeg_capabilities()

    # Detect codec from URL or assume H.264
    codec = 'h265' if any(x in rtsp_url.lower() for x in ['h265', 'hevc']) else 'h264'

    options = caps.get_optimal_capture_options(codec)
    return caps.format_opencv_options(options)
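A minimal usage sketch for the helpers above (not part of the diff), assuming ffmpeg and OpenCV's FFmpeg backend are installed; the RTSP URL is a placeholder. OpenCV reads per-capture FFmpeg options from the OPENCV_FFMPEG_CAPTURE_OPTIONS environment variable in the same "key;value|key;value" format that format_opencv_options() emits:

import os
import cv2

rtsp_url = "rtsp://camera.local/stream"  # hypothetical stream URL

# Must be set before the capture is opened; the FFmpeg backend reads it once.
os.environ["OPENCV_FFMPEG_CAPTURE_OPTIONS"] = get_optimal_rtsp_options(rtsp_url)

cap = cv2.VideoCapture(rtsp_url, cv2.CAP_FFMPEG)
ok, frame = cap.read()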
@@ -1,173 +0,0 @@
"""
|
||||
Hardware-accelerated image encoding using NVIDIA NVENC or Intel QuickSync
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
import os
|
||||
|
||||
logger = logging.getLogger("detector_worker")
|
||||
|
||||
|
||||
class HardwareEncoder:
|
||||
"""Hardware-accelerated JPEG encoder using GPU."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize hardware encoder."""
|
||||
self.nvenc_available = False
|
||||
self.vaapi_available = False
|
||||
self.turbojpeg_available = False
|
||||
|
||||
# Check for TurboJPEG (fastest CPU-based option)
|
||||
try:
|
||||
from turbojpeg import TurboJPEG
|
||||
self.turbojpeg = TurboJPEG()
|
||||
self.turbojpeg_available = True
|
||||
logger.info("TurboJPEG accelerated encoding available")
|
||||
except ImportError:
|
||||
logger.debug("TurboJPEG not available")
|
||||
|
||||
# Check for NVIDIA NVENC support
|
||||
try:
|
||||
# Test if we can create an NVENC encoder
|
||||
test_frame = np.zeros((720, 1280, 3), dtype=np.uint8)
|
||||
fourcc = cv2.VideoWriter_fourcc(*'H264')
|
||||
test_writer = cv2.VideoWriter(
|
||||
"test.mp4",
|
||||
fourcc,
|
||||
30,
|
||||
(1280, 720),
|
||||
[cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY]
|
||||
)
|
||||
if test_writer.isOpened():
|
||||
self.nvenc_available = True
|
||||
logger.info("NVENC hardware encoding available")
|
||||
test_writer.release()
|
||||
if os.path.exists("test.mp4"):
|
||||
os.remove("test.mp4")
|
||||
except Exception as e:
|
||||
logger.debug(f"NVENC not available: {e}")
|
||||
|
||||
def encode_jpeg(self, frame: np.ndarray, quality: int = 85) -> Optional[bytes]:
|
||||
"""
|
||||
Encode frame to JPEG using the fastest available method.
|
||||
|
||||
Args:
|
||||
frame: BGR image frame
|
||||
quality: JPEG quality (1-100)
|
||||
|
||||
Returns:
|
||||
Encoded JPEG bytes or None on failure
|
||||
"""
|
||||
try:
|
||||
# Method 1: TurboJPEG (3-5x faster than cv2.imencode)
|
||||
if self.turbojpeg_available:
|
||||
# Convert BGR to RGB for TurboJPEG
|
||||
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
encoded = self.turbojpeg.encode(rgb_frame, quality=quality)
|
||||
return encoded
|
||||
|
||||
# Method 2: Hardware-accelerated encoding via GStreamer (if available)
|
||||
if self.nvenc_available:
|
||||
return self._encode_with_nvenc(frame, quality)
|
||||
|
||||
# Fallback: Standard OpenCV encoding
|
||||
encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
|
||||
success, encoded = cv2.imencode('.jpg', frame, encode_params)
|
||||
if success:
|
||||
return encoded.tobytes()
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to encode frame: {e}")
|
||||
return None
|
||||
|
||||
def _encode_with_nvenc(self, frame: np.ndarray, quality: int) -> Optional[bytes]:
|
||||
"""
|
||||
Encode using NVIDIA NVENC hardware encoder.
|
||||
|
||||
This is complex to implement directly, so we'll use a GStreamer pipeline
|
||||
if available.
|
||||
"""
|
||||
try:
|
||||
# Create a GStreamer pipeline for hardware encoding
|
||||
height, width = frame.shape[:2]
|
||||
gst_pipeline = (
|
||||
f"appsrc ! "
|
||||
f"video/x-raw,format=BGR,width={width},height={height},framerate=30/1 ! "
|
||||
f"videoconvert ! "
|
||||
f"nvvideoconvert ! " # GPU color conversion
|
||||
f"nvjpegenc quality={quality} ! " # Hardware JPEG encoder
|
||||
f"appsink"
|
||||
)
|
||||
|
||||
# This would require GStreamer Python bindings
|
||||
# For now, fall back to TurboJPEG or standard encoding
|
||||
logger.debug("NVENC JPEG encoding not fully implemented, using fallback")
|
||||
encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
|
||||
success, encoded = cv2.imencode('.jpg', frame, encode_params)
|
||||
if success:
|
||||
return encoded.tobytes()
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"NVENC encoding failed: {e}")
|
||||
return None
|
||||
|
||||
def encode_batch(self, frames: list, quality: int = 85) -> list:
|
||||
"""
|
||||
Batch encode multiple frames for better GPU utilization.
|
||||
|
||||
Args:
|
||||
frames: List of BGR frames
|
||||
quality: JPEG quality
|
||||
|
||||
Returns:
|
||||
List of encoded JPEG bytes
|
||||
"""
|
||||
encoded_frames = []
|
||||
|
||||
if self.turbojpeg_available:
|
||||
# TurboJPEG can handle batch encoding efficiently
|
||||
for frame in frames:
|
||||
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
encoded = self.turbojpeg.encode(rgb_frame, quality=quality)
|
||||
encoded_frames.append(encoded)
|
||||
else:
|
||||
# Fallback to sequential encoding
|
||||
for frame in frames:
|
||||
encoded = self.encode_jpeg(frame, quality)
|
||||
encoded_frames.append(encoded)
|
||||
|
||||
return encoded_frames
|
||||
|
||||
|
||||
# Global encoder instance
|
||||
_hardware_encoder = None
|
||||
|
||||
|
||||
def get_hardware_encoder() -> HardwareEncoder:
|
||||
"""Get or create the global hardware encoder instance."""
|
||||
global _hardware_encoder
|
||||
if _hardware_encoder is None:
|
||||
_hardware_encoder = HardwareEncoder()
|
||||
return _hardware_encoder
|
||||
|
||||
|
||||
def encode_frame_hardware(frame: np.ndarray, quality: int = 85) -> Optional[bytes]:
|
||||
"""
|
||||
Convenience function to encode a frame using hardware acceleration.
|
||||
|
||||
Args:
|
||||
frame: BGR image frame
|
||||
quality: JPEG quality (1-100)
|
||||
|
||||
Returns:
|
||||
Encoded JPEG bytes or None on failure
|
||||
"""
|
||||
encoder = get_hardware_encoder()
|
||||
return encoder.encode_jpeg(frame, quality)
|
||||
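A usage sketch for the encoder above (not part of the diff): the convenience function picks TurboJPEG, NVENC, or plain cv2.imencode automatically, so callers only pass a BGR frame. The zero frame stands in for a real capture:

import numpy as np

frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # stand-in BGR frame
jpeg_bytes = encode_frame_hardware(frame, quality=85)
if jpeg_bytes is not None:
    with open("frame.jpg", "wb") as f:
        f.write(jpeg_bytes)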
4 debug/cuda.py Normal file
@@ -0,0 +1,4 @@
import torch
print(torch.cuda.is_available())       # True if CUDA is available
print(torch.cuda.get_device_name(0))   # GPU name
print(torch.version.cuda)              # CUDA version PyTorch was compiled with
1 feeder/note.txt Normal file
@@ -0,0 +1 @@
python simple_track.py --source video/sample.mp4 --show-vid --save-vid --enable-json-log
0 feeder/sender/__init__.py Normal file

21 feeder/sender/base.py Normal file
@@ -0,0 +1,21 @@
import numpy as np
import json

class NumpyArrayEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(NumpyArrayEncoder, self).default(obj)

class BasSender:
    def __init__(self) -> None:
        pass

    def send(self, messages):
        raise NotImplementedError()
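A quick sketch of what the encoder buys you (not part of the diff): standard json.dumps raises TypeError on numpy scalars and arrays, while NumpyArrayEncoder converts them to native Python types first, which is what lets the loggers serialize tracker output directly:

import json
import numpy as np

payload = {"bbox": np.array([10.0, 20.0, 50.0, 80.0]), "id": np.int64(7)}
print(json.dumps(payload, cls=NumpyArrayEncoder))
# {"bbox": [10.0, 20.0, 50.0, 80.0], "id": 7}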
13 feeder/sender/jsonlogger.py Normal file
@@ -0,0 +1,13 @@
from .base import BasSender
from loguru import logger
import json
from .base import NumpyArrayEncoder

class JsonLogger(BasSender):
    def __init__(self, log_filename: str = "tracking.log") -> None:
        super().__init__()
        self.logger = logger
        self.logger.add(log_filename, format="{message}", level="INFO")

    def send(self, messages):
        self.logger.info(json.dumps(messages, cls=NumpyArrayEncoder))
14 feeder/sender/szmq.py Normal file
@@ -0,0 +1,14 @@
from .base import BasSender, NumpyArrayEncoder
import zmq
import json


class ZmqLogger(BasSender):
    def __init__(self, ip_addr: str = "localhost", port: int = 5555) -> None:
        super().__init__()
        self.context = zmq.Context()
        self.producer = self.context.socket(zmq.PUB)
        self.producer.connect(f"tcp://{ip_addr}:{port}")

    def send(self, messages):
        self.producer.send_string(json.dumps(messages, cls=NumpyArrayEncoder))
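A matching consumer sketch (not part of the diff): since ZmqLogger uses a PUB socket that connect()s, the receiving side binds a SUB socket and subscribes to everything. Note that PUB/SUB silently drops messages published before a subscriber is attached:

import zmq

context = zmq.Context()
consumer = context.socket(zmq.SUB)
consumer.bind("tcp://*:5555")                   # ZmqLogger connect()s, so the consumer binds
consumer.setsockopt_string(zmq.SUBSCRIBE, "")   # subscribe to all messages

while True:
    track_data = consumer.recv_string()
    print(track_data)  # one JSON document per tracked detection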
245 feeder/simple_track.py Normal file
@@ -0,0 +1,245 @@
import argparse
import cv2
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

import sys
import numpy as np
from pathlib import Path
import torch

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]
WEIGHTS = ROOT / 'weights'

if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
if str(ROOT / 'trackers' / 'strongsort') not in sys.path:
    sys.path.append(str(ROOT / 'trackers' / 'strongsort'))

from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
from ultralytics.yolo.data.utils import VID_FORMATS
from ultralytics.yolo.utils import LOGGER, colorstr
from ultralytics.yolo.utils.checks import check_file, check_imgsz
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device
from ultralytics.yolo.utils.ops import Profile, non_max_suppression, scale_boxes
from ultralytics.yolo.utils.plotting import Annotator, colors

from trackers.multi_tracker_zoo import create_tracker
from sender.jsonlogger import JsonLogger
from sender.szmq import ZmqLogger

@torch.no_grad()
def run(
    source='0',
    yolo_weights=WEIGHTS / 'yolov8n.pt',
    reid_weights=WEIGHTS / 'osnet_x0_25_msmt17.pt',
    imgsz=(640, 640),
    conf_thres=0.7,
    iou_thres=0.45,
    max_det=1000,
    device='',
    show_vid=True,
    save_vid=True,
    project=ROOT / 'runs' / 'track',
    name='exp',
    exist_ok=False,
    line_thickness=2,
    hide_labels=False,
    hide_conf=False,
    half=False,
    vid_stride=1,
    enable_json_log=False,
    enable_zmq=False,
    zmq_ip='localhost',
    zmq_port=5555,
):
    source = str(source)
    is_file = Path(source).suffix[1:] in (VID_FORMATS)

    if is_file:
        source = check_file(source)

    device = select_device(device)

    model = AutoBackend(yolo_weights, device=device, dnn=False, fp16=half)
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_imgsz(imgsz, stride=stride)

    dataset = LoadImages(
        source,
        imgsz=imgsz,
        stride=stride,
        auto=pt,
        transforms=getattr(model.model, 'transforms', None),
        vid_stride=vid_stride
    )
    bs = len(dataset)

    tracking_config = ROOT / 'trackers' / 'strongsort' / 'configs' / 'strongsort.yaml'
    tracker = create_tracker('strongsort', tracking_config, reid_weights, device, half)

    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)
    (save_dir / 'tracks').mkdir(parents=True, exist_ok=True)

    # Initialize loggers
    json_logger = JsonLogger(f"{source}-strongsort.log") if enable_json_log else None
    zmq_logger = ZmqLogger(zmq_ip, zmq_port) if enable_zmq else None

    vid_path, vid_writer = [None] * bs, [None] * bs
    dt = (Profile(), Profile(), Profile())

    for frame_idx, (path, im, im0s, vid_cap, s) in enumerate(dataset):

        with dt[0]:
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()
            im /= 255.0
            if len(im.shape) == 3:
                im = im[None]

        with dt[1]:
            pred = model(im, augment=False, visualize=False)

        with dt[2]:
            pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=max_det)

        for i, det in enumerate(pred):
            seen = 0
            p, im0, _ = path, im0s.copy(), dataset.count
            p = Path(p)

            annotator = Annotator(im0, line_width=line_thickness, example=str(names))

            if len(det):
                # Filter detections for 'car' class only (class 2 in COCO dataset)
                car_mask = det[:, 5] == 2  # car class index is 2
                det = det[car_mask]

                if len(det):
                    det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                    for *xyxy, conf, cls in reversed(det):
                        c = int(cls)
                        id = f'{c}'
                        label = None if hide_labels else (f'{id} {names[c]}' if hide_conf else f'{id} {names[c]} {conf:.2f}')
                        annotator.box_label(xyxy, label, color=colors(c, True))

                    t_outputs = tracker.update(det.cpu(), im0)

                    if len(t_outputs) > 0:
                        for j, (output) in enumerate(t_outputs):
                            bbox = output[0:4]
                            id = output[4]
                            cls = output[5]
                            conf = output[6]

                            # Log tracking data
                            if json_logger or zmq_logger:
                                track_data = {
                                    'bbox': bbox.tolist() if hasattr(bbox, 'tolist') else list(bbox),
                                    'id': int(id),
                                    'cls': int(cls),
                                    'conf': float(conf),
                                    'frame_idx': frame_idx,
                                    'source': source,
                                    'class_name': names[int(cls)]
                                }

                                if json_logger:
                                    json_logger.send(track_data)
                                if zmq_logger:
                                    zmq_logger.send(track_data)

                            if save_vid or show_vid:
                                c = int(cls)
                                id = int(id)
                                label = f'{id} {names[c]}' if not hide_labels else f'{id}'
                                if not hide_conf:
                                    label += f' {conf:.2f}'
                                annotator.box_label(bbox, label, color=colors(c, True))

            im0 = annotator.result()

            if show_vid:
                cv2.imshow(str(p), im0)
                if cv2.waitKey(1) == ord('q'):
                    break

            if save_vid:
                if vid_path[i] != str(save_dir / p.name):
                    vid_path[i] = str(save_dir / p.name)
                    if isinstance(vid_writer[i], cv2.VideoWriter):
                        vid_writer[i].release()

                    if vid_cap:
                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    else:
                        fps, w, h = 30, im0.shape[1], im0.shape[0]

                    vid_writer[i] = cv2.VideoWriter(vid_path[i], cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))

                vid_writer[i].write(im0)

        LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")

    for i, vid_writer_obj in enumerate(vid_writer):
        if isinstance(vid_writer_obj, cv2.VideoWriter):
            vid_writer_obj.release()

    cv2.destroyAllWindows()

    LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y

def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--source', type=str, default='0', help='file/dir/URL/glob, 0 for webcam')
    parser.add_argument('--yolo-weights', nargs='+', type=str, default=WEIGHTS / 'yolov8n.pt', help='model path')
    parser.add_argument('--reid-weights', type=str, default=WEIGHTS / 'osnet_x0_25_msmt17.pt')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
    parser.add_argument('--conf-thres', type=float, default=0.7, help='confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--show-vid', action='store_true', help='display results')
    parser.add_argument('--save-vid', action='store_true', help='save video tracking results')
    parser.add_argument('--project', default=ROOT / 'runs' / 'track', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--line-thickness', default=2, type=int, help='bounding box thickness (pixels)')
    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
    parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
    parser.add_argument('--enable-json-log', action='store_true', help='enable JSON file logging')
    parser.add_argument('--enable-zmq', action='store_true', help='enable ZMQ messaging')
    parser.add_argument('--zmq-ip', type=str, default='localhost', help='ZMQ server IP')
    parser.add_argument('--zmq-port', type=int, default=5555, help='ZMQ server port')

    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1
    return opt

def main(opt):
    run(**vars(opt))

if __name__ == "__main__":
    opt = parse_opt()
    main(opt)
0 feeder/trackers/__init__.py Normal file

60 feeder/trackers/botsort/basetrack.py Normal file
@@ -0,0 +1,60 @@
import numpy as np
from collections import OrderedDict


class TrackState(object):
    New = 0
    Tracked = 1
    Lost = 2
    LongLost = 3
    Removed = 4


class BaseTrack(object):
    _count = 0

    track_id = 0
    is_activated = False
    state = TrackState.New

    history = OrderedDict()
    features = []
    curr_feature = None
    score = 0
    start_frame = 0
    frame_id = 0
    time_since_update = 0

    # multi-camera
    location = (np.inf, np.inf)

    @property
    def end_frame(self):
        return self.frame_id

    @staticmethod
    def next_id():
        BaseTrack._count += 1
        return BaseTrack._count

    def activate(self, *args):
        raise NotImplementedError

    def predict(self):
        raise NotImplementedError

    def update(self, *args, **kwargs):
        raise NotImplementedError

    def mark_lost(self):
        self.state = TrackState.Lost

    def mark_long_lost(self):
        self.state = TrackState.LongLost

    def mark_removed(self):
        self.state = TrackState.Removed

    @staticmethod
    def clear_count():
        BaseTrack._count = 0
534 feeder/trackers/botsort/bot_sort.py Normal file
@@ -0,0 +1,534 @@
import cv2
import matplotlib.pyplot as plt
import numpy as np
from collections import deque

from trackers.botsort import matching
from trackers.botsort.gmc import GMC
from trackers.botsort.basetrack import BaseTrack, TrackState
from trackers.botsort.kalman_filter import KalmanFilter

# from fast_reid.fast_reid_interfece import FastReIDInterface

from reid_multibackend import ReIDDetectMultiBackend
from ultralytics.yolo.utils.ops import xyxy2xywh, xywh2xyxy


class STrack(BaseTrack):
    shared_kalman = KalmanFilter()

    def __init__(self, tlwh, score, cls, feat=None, feat_history=50):

        # wait activate
        self._tlwh = np.asarray(tlwh, dtype=np.float32)
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False

        self.cls = -1
        self.cls_hist = []  # (cls id, freq)
        self.update_cls(cls, score)

        self.score = score
        self.tracklet_len = 0

        self.smooth_feat = None
        self.curr_feat = None
        if feat is not None:
            self.update_features(feat)
        self.features = deque([], maxlen=feat_history)
        self.alpha = 0.9

    def update_features(self, feat):
        feat /= np.linalg.norm(feat)
        self.curr_feat = feat
        if self.smooth_feat is None:
            self.smooth_feat = feat
        else:
            self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
        self.features.append(feat)
        self.smooth_feat /= np.linalg.norm(self.smooth_feat)

    def update_cls(self, cls, score):
        if len(self.cls_hist) > 0:
            max_freq = 0
            found = False
            for c in self.cls_hist:
                if cls == c[0]:
                    c[1] += score
                    found = True

                if c[1] > max_freq:
                    max_freq = c[1]
                    self.cls = c[0]
            if not found:
                self.cls_hist.append([cls, score])
                self.cls = cls
        else:
            self.cls_hist.append([cls, score])
            self.cls = cls

    def predict(self):
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            mean_state[6] = 0
            mean_state[7] = 0

        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

    @staticmethod
    def multi_predict(stracks):
        if len(stracks) > 0:
            multi_mean = np.asarray([st.mean.copy() for st in stracks])
            multi_covariance = np.asarray([st.covariance for st in stracks])
            for i, st in enumerate(stracks):
                if st.state != TrackState.Tracked:
                    multi_mean[i][6] = 0
                    multi_mean[i][7] = 0
            multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
            for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
                stracks[i].mean = mean
                stracks[i].covariance = cov

    @staticmethod
    def multi_gmc(stracks, H=np.eye(2, 3)):
        if len(stracks) > 0:
            multi_mean = np.asarray([st.mean.copy() for st in stracks])
            multi_covariance = np.asarray([st.covariance for st in stracks])

            R = H[:2, :2]
            R8x8 = np.kron(np.eye(4, dtype=float), R)
            t = H[:2, 2]

            for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
                mean = R8x8.dot(mean)
                mean[:2] += t
                cov = R8x8.dot(cov).dot(R8x8.transpose())

                stracks[i].mean = mean
                stracks[i].covariance = cov

    def activate(self, kalman_filter, frame_id):
        """Start a new tracklet"""
        self.kalman_filter = kalman_filter
        self.track_id = self.next_id()

        self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xywh(self._tlwh))

        self.tracklet_len = 0
        self.state = TrackState.Tracked
        if frame_id == 1:
            self.is_activated = True
        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track, frame_id, new_id=False):

        self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, self.tlwh_to_xywh(new_track.tlwh))
        if new_track.curr_feat is not None:
            self.update_features(new_track.curr_feat)
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        if new_id:
            self.track_id = self.next_id()
        self.score = new_track.score

        self.update_cls(new_track.cls, new_track.score)

    def update(self, new_track, frame_id):
        """
        Update a matched track
        :type new_track: STrack
        :type frame_id: int
        :type update_feature: bool
        :return:
        """
        self.frame_id = frame_id
        self.tracklet_len += 1

        new_tlwh = new_track.tlwh

        self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, self.tlwh_to_xywh(new_tlwh))

        if new_track.curr_feat is not None:
            self.update_features(new_track.curr_feat)

        self.state = TrackState.Tracked
        self.is_activated = True

        self.score = new_track.score
        self.update_cls(new_track.cls, new_track.score)

    @property
    def tlwh(self):
        """Get current position in bounding box format `(top left x, top left y,
        width, height)`.
        """
        if self.mean is None:
            return self._tlwh.copy()
        ret = self.mean[:4].copy()
        ret[:2] -= ret[2:] / 2
        return ret

    @property
    def tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    @property
    def xywh(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[:2] += ret[2:] / 2.0
        return ret

    @staticmethod
    def tlwh_to_xyah(tlwh):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.
        """
        ret = np.asarray(tlwh).copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

    @staticmethod
    def tlwh_to_xywh(tlwh):
        """Convert bounding box to format `(center x, center y, width,
        height)`.
        """
        ret = np.asarray(tlwh).copy()
        ret[:2] += ret[2:] / 2
        return ret

    def to_xywh(self):
        return self.tlwh_to_xywh(self.tlwh)

    @staticmethod
    def tlbr_to_tlwh(tlbr):
        ret = np.asarray(tlbr).copy()
        ret[2:] -= ret[:2]
        return ret

    @staticmethod
    def tlwh_to_tlbr(tlwh):
        ret = np.asarray(tlwh).copy()
        ret[2:] += ret[:2]
        return ret

    def __repr__(self):
        return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)


class BoTSORT(object):
    def __init__(self,
                 model_weights,
                 device,
                 fp16,
                 track_high_thresh: float = 0.45,
                 new_track_thresh: float = 0.6,
                 track_buffer: int = 30,
                 match_thresh: float = 0.8,
                 proximity_thresh: float = 0.5,
                 appearance_thresh: float = 0.25,
                 cmc_method: str = 'sparseOptFlow',
                 frame_rate=30,
                 lambda_=0.985
                 ):

        self.tracked_stracks = []  # type: list[STrack]
        self.lost_stracks = []  # type: list[STrack]
        self.removed_stracks = []  # type: list[STrack]
        BaseTrack.clear_count()

        self.frame_id = 0

        self.lambda_ = lambda_
        self.track_high_thresh = track_high_thresh
        self.new_track_thresh = new_track_thresh

        self.buffer_size = int(frame_rate / 30.0 * track_buffer)
        self.max_time_lost = self.buffer_size
        self.kalman_filter = KalmanFilter()

        # ReID module
        self.proximity_thresh = proximity_thresh
        self.appearance_thresh = appearance_thresh
        self.match_thresh = match_thresh

        self.model = ReIDDetectMultiBackend(weights=model_weights, device=device, fp16=fp16)

        self.gmc = GMC(method=cmc_method, verbose=[None, False])

    def update(self, output_results, img):
        self.frame_id += 1
        activated_starcks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []

        xyxys = output_results[:, 0:4]
        xywh = xyxy2xywh(xyxys.numpy())
        confs = output_results[:, 4]
        clss = output_results[:, 5]

        classes = clss.numpy()
        xyxys = xyxys.numpy()
        confs = confs.numpy()

        remain_inds = confs > self.track_high_thresh
        inds_low = confs > 0.1
        inds_high = confs < self.track_high_thresh

        inds_second = np.logical_and(inds_low, inds_high)

        dets_second = xywh[inds_second]
        dets = xywh[remain_inds]

        scores_keep = confs[remain_inds]
        scores_second = confs[inds_second]

        classes_keep = classes[remain_inds]
        clss_second = classes[inds_second]

        self.height, self.width = img.shape[:2]

        '''Extract embeddings '''
        features_keep = self._get_features(dets, img)

        if len(dets) > 0:
            '''Detections'''

            detections = [STrack(xyxy, s, c, f.cpu().numpy()) for
                          (xyxy, s, c, f) in zip(dets, scores_keep, classes_keep, features_keep)]
        else:
            detections = []

        ''' Add newly detected tracklets to tracked_stracks'''
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)

        ''' Step 2: First association, with high score detection boxes'''
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)

        # Predict the current location with KF
        STrack.multi_predict(strack_pool)

        # Fix camera motion
        warp = self.gmc.apply(img, dets)
        STrack.multi_gmc(strack_pool, warp)
        STrack.multi_gmc(unconfirmed, warp)

        # Associate with high score detection boxes
        raw_emb_dists = matching.embedding_distance(strack_pool, detections)
        dists = matching.fuse_motion(self.kalman_filter, raw_emb_dists, strack_pool, detections, only_position=False, lambda_=self.lambda_)

        # ious_dists = matching.iou_distance(strack_pool, detections)
        # ious_dists_mask = (ious_dists > self.proximity_thresh)

        # ious_dists = matching.fuse_score(ious_dists, detections)

        # emb_dists = matching.embedding_distance(strack_pool, detections) / 2.0
        # raw_emb_dists = emb_dists.copy()
        # emb_dists[emb_dists > self.appearance_thresh] = 1.0
        # emb_dists[ious_dists_mask] = 1.0
        # dists = np.minimum(ious_dists, emb_dists)

        # Popular ReID method (JDE / FairMOT)
        # raw_emb_dists = matching.embedding_distance(strack_pool, detections)
        # dists = matching.fuse_motion(self.kalman_filter, raw_emb_dists, strack_pool, detections)
        # emb_dists = dists

        # IoU making ReID
        # dists = matching.embedding_distance(strack_pool, detections)
        # dists[ious_dists_mask] = 1.0

        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.match_thresh)

        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(detections[idet], self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        ''' Step 3: Second association, with low score detection boxes'''
        # if len(scores):
        #     inds_high = scores < self.track_high_thresh
        #     inds_low = scores > self.track_low_thresh
        #     inds_second = np.logical_and(inds_low, inds_high)
        #     dets_second = bboxes[inds_second]
        #     scores_second = scores[inds_second]
        #     classes_second = classes[inds_second]
        # else:
        #     dets_second = []
        #     scores_second = []
        #     classes_second = []

        # association the untrack to the low score detections
        if len(dets_second) > 0:
            '''Detections'''
            detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s, c) for
                                 (tlbr, s, c) in zip(dets_second, scores_second, clss_second)]
        else:
            detections_second = []

        r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
        dists = matching.iou_distance(r_tracked_stracks, detections_second)
        matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections_second[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)

        '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
        detections = [detections[i] for i in u_detection]
        ious_dists = matching.iou_distance(unconfirmed, detections)
        ious_dists_mask = (ious_dists > self.proximity_thresh)

        ious_dists = matching.fuse_score(ious_dists, detections)

        emb_dists = matching.embedding_distance(unconfirmed, detections) / 2.0
        raw_emb_dists = emb_dists.copy()
        emb_dists[emb_dists > self.appearance_thresh] = 1.0
        emb_dists[ious_dists_mask] = 1.0
        dists = np.minimum(ious_dists, emb_dists)

        matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_starcks.append(unconfirmed[itracked])
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        """ Step 4: Init new stracks"""
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.new_track_thresh:
                continue

            track.activate(self.kalman_filter, self.frame_id)
            activated_starcks.append(track)

        """ Step 5: Update state"""
        for track in self.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        """ Merge """
        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)

        # output_stracks = [track for track in self.tracked_stracks if track.is_activated]
        output_stracks = [track for track in self.tracked_stracks if track.is_activated]
        outputs = []
        for t in output_stracks:
            output = []
            tlwh = t.tlwh
            tid = t.track_id
            tlwh = np.expand_dims(tlwh, axis=0)
            xyxy = xywh2xyxy(tlwh)
            xyxy = np.squeeze(xyxy, axis=0)
            output.extend(xyxy)
            output.append(tid)
            output.append(t.cls)
            output.append(t.score)
            outputs.append(output)

        return outputs

    def _xywh_to_xyxy(self, bbox_xywh):
        x, y, w, h = bbox_xywh
        x1 = max(int(x - w / 2), 0)
        x2 = min(int(x + w / 2), self.width - 1)
        y1 = max(int(y - h / 2), 0)
        y2 = min(int(y + h / 2), self.height - 1)
        return x1, y1, x2, y2

    def _get_features(self, bbox_xywh, ori_img):
        im_crops = []
        for box in bbox_xywh:
            x1, y1, x2, y2 = self._xywh_to_xyxy(box)
            im = ori_img[y1:y2, x1:x2]
            im_crops.append(im)
        if im_crops:
            features = self.model(im_crops)
        else:
            features = np.array([])
        return features

def joint_stracks(tlista, tlistb):
    exists = {}
    res = []
    for t in tlista:
        exists[t.track_id] = 1
        res.append(t)
    for t in tlistb:
        tid = t.track_id
        if not exists.get(tid, 0):
            exists[tid] = 1
            res.append(t)
    return res


def sub_stracks(tlista, tlistb):
    stracks = {}
    for t in tlista:
        stracks[t.track_id] = t
    for t in tlistb:
        tid = t.track_id
        if stracks.get(tid, 0):
            del stracks[tid]
    return list(stracks.values())


def remove_duplicate_stracks(stracksa, stracksb):
    pdist = matching.iou_distance(stracksa, stracksb)
    pairs = np.where(pdist < 0.15)
    dupa, dupb = list(), list()
    for p, q in zip(*pairs):
        timep = stracksa[p].frame_id - stracksa[p].start_frame
        timeq = stracksb[q].frame_id - stracksb[q].start_frame
        if timep > timeq:
            dupb.append(q)
        else:
            dupa.append(p)
    resa = [t for i, t in enumerate(stracksa) if not i in dupa]
    resb = [t for i, t in enumerate(stracksb) if not i in dupb]
    return resa, resb
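A call-shape sketch for the tracker above (not part of the diff), assuming the ReID weights file exists and matching/reid_multibackend are importable. update() expects a CPU torch tensor with one [x1, y1, x2, y2, conf, cls] row per detection plus the BGR frame; the zero frame here is a stand-in:

import numpy as np
import torch
from trackers.botsort.bot_sort import BoTSORT

tracker = BoTSORT(model_weights='weights/osnet_x0_25_msmt17.pt',  # placeholder ReID weights
                  device='cuda:0', fp16=False)

frame = np.zeros((720, 1280, 3), dtype=np.uint8)           # stand-in BGR frame
dets = torch.tensor([[100., 150., 300., 400., 0.92, 2.]])  # [x1, y1, x2, y2, conf, cls]
tracks = tracker.update(dets, frame)
# each entry: [x1, y1, x2, y2, track_id, cls, conf]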
13 feeder/trackers/botsort/configs/botsort.yaml Normal file
@@ -0,0 +1,13 @@
# Trial number: 232
# HOTA, MOTA, IDF1: [45.31]
botsort:
  appearance_thresh: 0.4818211117541298
  cmc_method: sparseOptFlow
  conf_thres: 0.3501265956918775
  frame_rate: 30
  lambda_: 0.9896143462366406
  match_thresh: 0.22734550911325851
  new_track_thresh: 0.21144301345190655
  proximity_thresh: 0.5945380911899254
  track_buffer: 60
  track_high_thresh: 0.33824964456239337
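A sketch of wiring this config into the constructor above (not part of the diff); the weights path is a placeholder, and conf_thres is applied at detection time rather than being a BoTSORT argument:

import yaml
from trackers.botsort.bot_sort import BoTSORT

with open('feeder/trackers/botsort/configs/botsort.yaml') as f:
    cfg = yaml.safe_load(f)['botsort']

tracker = BoTSORT(
    model_weights='weights/osnet_x0_25_msmt17.pt',  # placeholder ReID weights
    device='cuda:0',
    fp16=False,
    track_high_thresh=cfg['track_high_thresh'],
    new_track_thresh=cfg['new_track_thresh'],
    track_buffer=cfg['track_buffer'],
    match_thresh=cfg['match_thresh'],
    proximity_thresh=cfg['proximity_thresh'],
    appearance_thresh=cfg['appearance_thresh'],
    cmc_method=cfg['cmc_method'],
    frame_rate=cfg['frame_rate'],
    lambda_=cfg['lambda_'],
)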
316 feeder/trackers/botsort/gmc.py Normal file
@@ -0,0 +1,316 @@
import cv2
import matplotlib.pyplot as plt
import numpy as np
import copy
import time


class GMC:
    def __init__(self, method='sparseOptFlow', downscale=2, verbose=None):
        super(GMC, self).__init__()

        self.method = method
        self.downscale = max(1, int(downscale))

        if self.method == 'orb':
            self.detector = cv2.FastFeatureDetector_create(20)
            self.extractor = cv2.ORB_create()
            self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)

        elif self.method == 'sift':
            self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
            self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
            self.matcher = cv2.BFMatcher(cv2.NORM_L2)

        elif self.method == 'ecc':
            number_of_iterations = 5000
            termination_eps = 1e-6
            self.warp_mode = cv2.MOTION_EUCLIDEAN
            self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)

        elif self.method == 'sparseOptFlow':
            self.feature_params = dict(maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3,
                                       useHarrisDetector=False, k=0.04)
            # self.gmc_file = open('GMC_results.txt', 'w')

        elif self.method == 'file' or self.method == 'files':
            seqName = verbose[0]
            ablation = verbose[1]
            if ablation:
                filePath = r'tracker/GMC_files/MOT17_ablation'
            else:
                filePath = r'tracker/GMC_files/MOTChallenge'

            if '-FRCNN' in seqName:
                seqName = seqName[:-6]
            elif '-DPM' in seqName:
                seqName = seqName[:-4]
            elif '-SDP' in seqName:
                seqName = seqName[:-4]

            self.gmcFile = open(filePath + "/GMC-" + seqName + ".txt", 'r')

            if self.gmcFile is None:
                raise ValueError("Error: Unable to open GMC file in directory:" + filePath)
        elif self.method == 'none' or self.method == 'None':
            self.method = 'none'
        else:
            raise ValueError("Error: Unknown CMC method:" + method)

        self.prevFrame = None
        self.prevKeyPoints = None
        self.prevDescriptors = None

        self.initializedFirstFrame = False

    def apply(self, raw_frame, detections=None):
        if self.method == 'orb' or self.method == 'sift':
            return self.applyFeaures(raw_frame, detections)
        elif self.method == 'ecc':
            return self.applyEcc(raw_frame, detections)
        elif self.method == 'sparseOptFlow':
            return self.applySparseOptFlow(raw_frame, detections)
        elif self.method == 'file':
            return self.applyFile(raw_frame, detections)
        elif self.method == 'none':
            return np.eye(2, 3)
        else:
            return np.eye(2, 3)

    def applyEcc(self, raw_frame, detections=None):

        # Initialize
        height, width, _ = raw_frame.shape
        frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
        H = np.eye(2, 3, dtype=np.float32)

        # Downscale image (TODO: consider using pyramids)
        if self.downscale > 1.0:
            frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
            frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
            width = width // self.downscale
            height = height // self.downscale

        # Handle first frame
        if not self.initializedFirstFrame:
            # Initialize data
            self.prevFrame = frame.copy()

            # Initialization done
            self.initializedFirstFrame = True

            return H

        # Run the ECC algorithm. The results are stored in warp_matrix.
        # (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
        try:
            (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
        except:
            print('Warning: find transform failed. Set warp as identity')

        return H

    def applyFeaures(self, raw_frame, detections=None):

        # Initialize
        height, width, _ = raw_frame.shape
        frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
        H = np.eye(2, 3)

        # Downscale image (TODO: consider using pyramids)
        if self.downscale > 1.0:
            # frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
            frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
            width = width // self.downscale
            height = height // self.downscale

        # find the keypoints
        mask = np.zeros_like(frame)
        # mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
        mask[int(0.02 * height): int(0.98 * height), int(0.02 * width): int(0.98 * width)] = 255
        if detections is not None:
            for det in detections:
                tlbr = (det[:4] / self.downscale).astype(np.int_)
                mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0

        keypoints = self.detector.detect(frame, mask)

        # compute the descriptors
        keypoints, descriptors = self.extractor.compute(frame, keypoints)

        # Handle first frame
        if not self.initializedFirstFrame:
            # Initialize data
            self.prevFrame = frame.copy()
            self.prevKeyPoints = copy.copy(keypoints)
            self.prevDescriptors = copy.copy(descriptors)

            # Initialization done
            self.initializedFirstFrame = True

            return H

        # Match descriptors.
        knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)

        # Filtered matches based on smallest spatial distance
        matches = []
        spatialDistances = []

        maxSpatialDistance = 0.25 * np.array([width, height])

        # Handle empty matches case
        if len(knnMatches) == 0:
            # Store to next iteration
            self.prevFrame = frame.copy()
            self.prevKeyPoints = copy.copy(keypoints)
            self.prevDescriptors = copy.copy(descriptors)

            return H

        for m, n in knnMatches:
            if m.distance < 0.9 * n.distance:
                prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
                currKeyPointLocation = keypoints[m.trainIdx].pt

                spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0],
                                   prevKeyPointLocation[1] - currKeyPointLocation[1])

                if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
                        (np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
                    spatialDistances.append(spatialDistance)
                    matches.append(m)

        meanSpatialDistances = np.mean(spatialDistances, 0)
        stdSpatialDistances = np.std(spatialDistances, 0)

        inliesrs = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances

        goodMatches = []
        prevPoints = []
        currPoints = []
        for i in range(len(matches)):
            if inliesrs[i, 0] and inliesrs[i, 1]:
                goodMatches.append(matches[i])
                prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt)
                currPoints.append(keypoints[matches[i].trainIdx].pt)

        prevPoints = np.array(prevPoints)
        currPoints = np.array(currPoints)

        # Draw the keypoint matches on the output image
        if 0:
            matches_img = np.hstack((self.prevFrame, frame))
            matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
            W = np.size(self.prevFrame, 1)
            for m in goodMatches:
                prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
                curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
                curr_pt[0] += W
                color = np.random.randint(0, 255, (3,))
                color = (int(color[0]), int(color[1]), int(color[2]))

                matches_img = cv2.line(matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA)
                matches_img = cv2.circle(matches_img, prev_pt, 2, tuple(color), -1)
                matches_img = cv2.circle(matches_img, curr_pt, 2, tuple(color), -1)

            plt.figure()
            plt.imshow(matches_img)
            plt.show()

        # Find rigid matrix
        if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)):
            H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)

            # Handle downscale
            if self.downscale > 1.0:
                H[0, 2] *= self.downscale
                H[1, 2] *= self.downscale
        else:
            print('Warning: not enough matching points')

        # Store to next iteration
        self.prevFrame = frame.copy()
        self.prevKeyPoints = copy.copy(keypoints)
        self.prevDescriptors = copy.copy(descriptors)

        return H

    def applySparseOptFlow(self, raw_frame, detections=None):

        t0 = time.time()

        # Initialize
        height, width, _ = raw_frame.shape
        frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
        H = np.eye(2, 3)

        # Downscale image
        if self.downscale > 1.0:
            # frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
            frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))

        # find the keypoints
        keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)

        # Handle first frame
        if not self.initializedFirstFrame:
            # Initialize data
            self.prevFrame = frame.copy()
            self.prevKeyPoints = copy.copy(keypoints)

            # Initialization done
            self.initializedFirstFrame = True

            return H

        # find correspondences
        matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)

        # leave good correspondences only
        prevPoints = []
        currPoints = []

        for i in range(len(status)):
            if status[i]:
                prevPoints.append(self.prevKeyPoints[i])
                currPoints.append(matchedKeypoints[i])

        prevPoints = np.array(prevPoints)
        currPoints = np.array(currPoints)

        # Find rigid matrix
        if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)):
            H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)

            # Handle downscale
            if self.downscale > 1.0:
                H[0, 2] *= self.downscale
                H[1, 2] *= self.downscale
        else:
            print('Warning: not enough matching points')

        # Store to next iteration
        self.prevFrame = frame.copy()
        self.prevKeyPoints = copy.copy(keypoints)

        t1 = time.time()

        # gmc_line = str(1000 * (t1 - t0)) + "\t" + str(H[0, 0]) + "\t" + str(H[0, 1]) + "\t" + str(
        #     H[0, 2]) + "\t" + str(H[1, 0]) + "\t" + str(H[1, 1]) + "\t" + str(H[1, 2]) + "\n"
        # self.gmc_file.write(gmc_line)

        return H

    def applyFile(self, raw_frame, detections=None):
        line = self.gmcFile.readline()
        tokens = line.split("\t")
        H = np.eye(2, 3, dtype=np.float_)
        H[0, 0] = float(tokens[1])
        H[0, 1] = float(tokens[2])
        H[0, 2] = float(tokens[3])
        H[1, 0] = float(tokens[4])
        H[1, 1] = float(tokens[5])
        H[1, 2] = float(tokens[6])

        return H
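A usage sketch for the camera-motion compensator above (not part of the diff); the clip path is a placeholder. apply() is stateful: it returns the identity on the first frame and a 2x3 affine warp between consecutive frames afterwards, which BoTSORT then feeds to STrack.multi_gmc():

import cv2
from trackers.botsort.gmc import GMC

gmc = GMC(method='sparseOptFlow', downscale=2)

cap = cv2.VideoCapture('video/sample.mp4')  # placeholder clip
ok, frame = cap.read()
while ok:
    H = gmc.apply(frame)  # 2x3 affine warp; identity for the first frame
    ok, frame = cap.read()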
269 feeder/trackers/botsort/kalman_filter.py Normal file
@@ -0,0 +1,269 @@
# vim: expandtab:ts=4:sw=4
import numpy as np
import scipy.linalg


"""
Table for the 0.95 quantile of the chi-square distribution with N degrees of
freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
function and used as Mahalanobis gating threshold.
"""
chi2inv95 = {
    1: 3.8415,
    2: 5.9915,
    3: 7.8147,
    4: 9.4877,
    5: 11.070,
    6: 12.592,
    7: 14.067,
    8: 15.507,
    9: 16.919}


class KalmanFilter(object):
    """
    A simple Kalman filter for tracking bounding boxes in image space.

    The 8-dimensional state space

        x, y, w, h, vx, vy, vw, vh

    contains the bounding box center position (x, y), width w, height h,
    and their respective velocities.

    Object motion follows a constant velocity model. The bounding box location
    (x, y, w, h) is taken as direct observation of the state space (linear
    observation model).

    """

    def __init__(self):
        ndim, dt = 4, 1.

        # Create Kalman filter model matrices.
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt
        self._update_mat = np.eye(ndim, 2 * ndim)

        # Motion and observation uncertainty are chosen relative to the current
        # state estimate. These weights control the amount of uncertainty in
        # the model. This is a bit hacky.
        self._std_weight_position = 1. / 20
        self._std_weight_velocity = 1. / 160

    def initiate(self, measurement):
        """Create track from unassociated measurement.

        Parameters
        ----------
        measurement : ndarray
            Bounding box coordinates (x, y, w, h) with center position (x, y),
            width w, and height h.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector (8 dimensional) and covariance matrix (8x8
            dimensional) of the new track. Unobserved velocities are initialized
            to 0 mean.

        """
        mean_pos = measurement
        mean_vel = np.zeros_like(mean_pos)
        mean = np.r_[mean_pos, mean_vel]

        std = [
            2 * self._std_weight_position * measurement[2],
            2 * self._std_weight_position * measurement[3],
            2 * self._std_weight_position * measurement[2],
            2 * self._std_weight_position * measurement[3],
            10 * self._std_weight_velocity * measurement[2],
            10 * self._std_weight_velocity * measurement[3],
            10 * self._std_weight_velocity * measurement[2],
            10 * self._std_weight_velocity * measurement[3]]
        covariance = np.diag(np.square(std))
        return mean, covariance

    def predict(self, mean, covariance):
        """Run Kalman filter prediction step.

        Parameters
        ----------
        mean : ndarray
            The 8 dimensional mean vector of the object state at the previous
            time step.
        covariance : ndarray
            The 8x8 dimensional covariance matrix of the object state at the
            previous time step.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector and covariance matrix of the predicted
            state. Unobserved velocities are initialized to 0 mean.

        """
        std_pos = [
            self._std_weight_position * mean[2],
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[2],
            self._std_weight_position * mean[3]]
        std_vel = [
            self._std_weight_velocity * mean[2],
            self._std_weight_velocity * mean[3],
            self._std_weight_velocity * mean[2],
            self._std_weight_velocity * mean[3]]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))

        mean = np.dot(mean, self._motion_mat.T)
        covariance = np.linalg.multi_dot((
            self._motion_mat, covariance, self._motion_mat.T)) + motion_cov

        return mean, covariance

    def project(self, mean, covariance):
        """Project state distribution to measurement space.

        Parameters
        ----------
        mean : ndarray
            The state's mean vector (8 dimensional array).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).

        Returns
        -------
        (ndarray, ndarray)
            Returns the projected mean and covariance matrix of the given state
            estimate.

        """
        std = [
            self._std_weight_position * mean[2],
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[2],
            self._std_weight_position * mean[3]]
        innovation_cov = np.diag(np.square(std))

        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot((
            self._update_mat, covariance, self._update_mat.T))
        return mean, covariance + innovation_cov

    def multi_predict(self, mean, covariance):
        """Run Kalman filter prediction step (Vectorized version).
        Parameters
        ----------
        mean : ndarray
            The Nx8 dimensional mean matrix of the object states at the previous
            time step.
        covariance : ndarray
            The Nx8x8 dimensional covariance matrics of the object states at the
            previous time step.
        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector and covariance matrix of the predicted
            state. Unobserved velocities are initialized to 0 mean.
        """
        std_pos = [
            self._std_weight_position * mean[:, 2],
            self._std_weight_position * mean[:, 3],
            self._std_weight_position * mean[:, 2],
            self._std_weight_position * mean[:, 3]]
        std_vel = [
            self._std_weight_velocity * mean[:, 2],
            self._std_weight_velocity * mean[:, 3],
            self._std_weight_velocity * mean[:, 2],
            self._std_weight_velocity * mean[:, 3]]
        sqr = np.square(np.r_[std_pos, std_vel]).T

        motion_cov = []
        for i in range(len(mean)):
            motion_cov.append(np.diag(sqr[i]))
        motion_cov = np.asarray(motion_cov)

        mean = np.dot(mean, self._motion_mat.T)
        left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
        covariance = np.dot(left, self._motion_mat.T) + motion_cov

        return mean, covariance

    def update(self, mean, covariance, measurement):
        """Run Kalman filter correction step.

        Parameters
        ----------
        mean : ndarray
            The predicted state's mean vector (8 dimensional).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).
        measurement : ndarray
            The 4 dimensional measurement vector (x, y, w, h), where (x, y)
            is the center position, w the width, and h the height of the
            bounding box.

        Returns
        -------
        (ndarray, ndarray)
            Returns the measurement-corrected state distribution.

        """
        projected_mean, projected_cov = self.project(mean, covariance)

        chol_factor, lower = scipy.linalg.cho_factor(
            projected_cov, lower=True, check_finite=False)
        kalman_gain = scipy.linalg.cho_solve(
            (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
            check_finite=False).T
        innovation = measurement - projected_mean

        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot((
            kalman_gain, projected_cov, kalman_gain.T))
        return new_mean, new_covariance

    def gating_distance(self, mean, covariance, measurements,
                        only_position=False, metric='maha'):
        """Compute gating distance between state distribution and measurements.
        A suitable distance threshold can be obtained from `chi2inv95`. If
        `only_position` is False, the chi-square distribution has 4 degrees of
        freedom, otherwise 2.
        Parameters
        ----------
        mean : ndarray
            Mean vector over the state distribution (8 dimensional).
        covariance : ndarray
            Covariance of the state distribution (8x8 dimensional).
        measurements : ndarray
            An Nx4 dimensional matrix of N measurements, each in
            format (x, y, a, h) where (x, y) is the bounding box center
            position, a the aspect ratio, and h the height.
        only_position : Optional[bool]
            If True, distance computation is done with respect to the bounding
            box center position only.
        Returns
        -------
        ndarray
            Returns an array of length N, where the i-th element contains the
|
||||
squared Mahalanobis distance between (mean, covariance) and
|
||||
`measurements[i]`.
|
||||
"""
|
||||
mean, covariance = self.project(mean, covariance)
|
||||
if only_position:
|
||||
mean, covariance = mean[:2], covariance[:2, :2]
|
||||
measurements = measurements[:, :2]
|
||||
|
||||
d = measurements - mean
|
||||
if metric == 'gaussian':
|
||||
return np.sum(d * d, axis=1)
|
||||
elif metric == 'maha':
|
||||
cholesky_factor = np.linalg.cholesky(covariance)
|
||||
z = scipy.linalg.solve_triangular(
|
||||
cholesky_factor, d.T, lower=True, check_finite=False,
|
||||
overwrite_b=True)
|
||||
squared_maha = np.sum(z * z, axis=0)
|
||||
return squared_maha
|
||||
else:
|
||||
raise ValueError('invalid distance metric')
|
||||
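For orientation, a minimal usage sketch of the filter above (not part of the diff). The import path and the `KalmanFilter` class name are assumptions, inferred from how matching.py imports this module:

import numpy as np
from trackers.botsort.kalman_filter import KalmanFilter  # assumed name/path

kf = KalmanFilter()
mean, cov = kf.initiate(np.array([320., 240., 80., 160.]))   # first (x, y, w, h) detection
mean, cov = kf.predict(mean, cov)                            # propagate one frame
mean, cov = kf.update(mean, cov, np.array([324., 238., 82., 158.]))  # correct with a new box
print(mean[:4])  # filtered (x, y, w, h)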
234 feeder/trackers/botsort/matching.py Normal file
@ -0,0 +1,234 @@
import numpy as np
import scipy
import lap
from scipy.spatial.distance import cdist

from trackers.botsort import kalman_filter


def merge_matches(m1, m2, shape):
    O, P, Q = shape
    m1 = np.asarray(m1)
    m2 = np.asarray(m2)

    M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
    M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))

    mask = M1 * M2
    match = mask.nonzero()
    match = list(zip(match[0], match[1]))
    unmatched_O = tuple(set(range(O)) - set([i for i, j in match]))
    unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match]))

    return match, unmatched_O, unmatched_Q


def _indices_to_matches(cost_matrix, indices, thresh):
    matched_cost = cost_matrix[tuple(zip(*indices))]
    matched_mask = (matched_cost <= thresh)

    matches = indices[matched_mask]
    unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
    unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))

    return matches, unmatched_a, unmatched_b


def linear_assignment(cost_matrix, thresh):
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
    matches, unmatched_a, unmatched_b = [], [], []
    cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
    for ix, mx in enumerate(x):
        if mx >= 0:
            matches.append([ix, mx])
    unmatched_a = np.where(x < 0)[0]
    unmatched_b = np.where(y < 0)[0]
    matches = np.asarray(matches)
    return matches, unmatched_a, unmatched_b


def ious(atlbrs, btlbrs):
    """
    Compute cost based on IoU
    :type atlbrs: list[tlbr] | np.ndarray
    :type btlbrs: list[tlbr] | np.ndarray

    :rtype ious np.ndarray
    """
    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
    if ious.size == 0:
        return ious

    ious = bbox_ious(
        np.ascontiguousarray(atlbrs, dtype=np.float32),
        np.ascontiguousarray(btlbrs, dtype=np.float32)
    )

    return ious


def tlbr_expand(tlbr, scale=1.2):
    w = tlbr[2] - tlbr[0]
    h = tlbr[3] - tlbr[1]

    half_scale = 0.5 * scale

    tlbr[0] -= half_scale * w
    tlbr[1] -= half_scale * h
    tlbr[2] += half_scale * w
    tlbr[3] += half_scale * h

    return tlbr


def iou_distance(atracks, btracks):
    """
    Compute cost based on IoU
    :type atracks: list[STrack]
    :type btracks: list[STrack]

    :rtype cost_matrix np.ndarray
    """

    if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
        atlbrs = atracks
        btlbrs = btracks
    else:
        atlbrs = [track.tlbr for track in atracks]
        btlbrs = [track.tlbr for track in btracks]
    _ious = ious(atlbrs, btlbrs)
    cost_matrix = 1 - _ious

    return cost_matrix


def v_iou_distance(atracks, btracks):
    """
    Compute cost based on IoU
    :type atracks: list[STrack]
    :type btracks: list[STrack]

    :rtype cost_matrix np.ndarray
    """

    if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
        atlbrs = atracks
        btlbrs = btracks
    else:
        atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks]
        btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks]
    _ious = ious(atlbrs, btlbrs)
    cost_matrix = 1 - _ious

    return cost_matrix


def embedding_distance(tracks, detections, metric='cosine'):
    """
    :param tracks: list[STrack]
    :param detections: list[BaseTrack]
    :param metric:
    :return: cost_matrix np.ndarray
    """

    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
    if cost_matrix.size == 0:
        return cost_matrix
    det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32)
    track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)

    cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric))  # / 2.0  # Normalized features
    return cost_matrix


def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
    if cost_matrix.size == 0:
        return cost_matrix
    gating_dim = 2 if only_position else 4
    gating_threshold = kalman_filter.chi2inv95[gating_dim]
    # measurements = np.asarray([det.to_xyah() for det in detections])
    measurements = np.asarray([det.to_xywh() for det in detections])
    for row, track in enumerate(tracks):
        gating_distance = kf.gating_distance(
            track.mean, track.covariance, measurements, only_position)
        cost_matrix[row, gating_distance > gating_threshold] = np.inf
    return cost_matrix


def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
    if cost_matrix.size == 0:
        return cost_matrix
    gating_dim = 2 if only_position else 4
    gating_threshold = kalman_filter.chi2inv95[gating_dim]
    # measurements = np.asarray([det.to_xyah() for det in detections])
    measurements = np.asarray([det.to_xywh() for det in detections])
    for row, track in enumerate(tracks):
        gating_distance = kf.gating_distance(
            track.mean, track.covariance, measurements, only_position, metric='maha')
        cost_matrix[row, gating_distance > gating_threshold] = np.inf
        cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
    return cost_matrix


def fuse_iou(cost_matrix, tracks, detections):
    if cost_matrix.size == 0:
        return cost_matrix
    reid_sim = 1 - cost_matrix
    iou_dist = iou_distance(tracks, detections)
    iou_sim = 1 - iou_dist
    fuse_sim = reid_sim * (1 + iou_sim) / 2
    det_scores = np.array([det.score for det in detections])
    det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
    # fuse_sim = fuse_sim * (1 + det_scores) / 2
    fuse_cost = 1 - fuse_sim
    return fuse_cost


def fuse_score(cost_matrix, detections):
    if cost_matrix.size == 0:
        return cost_matrix
    iou_sim = 1 - cost_matrix
    det_scores = np.array([det.score for det in detections])
    det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
    fuse_sim = iou_sim * det_scores
    fuse_cost = 1 - fuse_sim
    return fuse_cost


def bbox_ious(boxes, query_boxes):
    """
    Parameters
    ----------
    boxes: (N, 4) ndarray of float
    query_boxes: (K, 4) ndarray of float
    Returns
    -------
    overlaps: (N, K) ndarray of overlap between boxes and query_boxes
    """
    N = boxes.shape[0]
    K = query_boxes.shape[0]
    overlaps = np.zeros((N, K), dtype=np.float32)

    for k in range(K):
        box_area = (
            (query_boxes[k, 2] - query_boxes[k, 0] + 1) *
            (query_boxes[k, 3] - query_boxes[k, 1] + 1)
        )
        for n in range(N):
            iw = (
                min(boxes[n, 2], query_boxes[k, 2]) -
                max(boxes[n, 0], query_boxes[k, 0]) + 1
            )
            if iw > 0:
                ih = (
                    min(boxes[n, 3], query_boxes[k, 3]) -
                    max(boxes[n, 1], query_boxes[k, 1]) + 1
                )
                if ih > 0:
                    ua = float(
                        (boxes[n, 2] - boxes[n, 0] + 1) *
                        (boxes[n, 3] - boxes[n, 1] + 1) +
                        box_area - iw * ih
                    )
                    overlaps[n, k] = iw * ih / ua
    return overlaps
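Illustrative only (not in the diff): pairing two sets of raw tlbr boxes with the helpers above. iou_distance accepts plain ndarrays as well as STrack objects, and the import path is an assumption:

import numpy as np
from trackers.botsort import matching  # assumed import path

a = np.array([[0., 0., 10., 10.], [20., 20., 30., 30.]], dtype=np.float32)
b = np.array([[1., 1., 11., 11.], [100., 100., 110., 110.]], dtype=np.float32)
cost = matching.iou_distance(list(a), list(b))           # 1 - IoU cost matrix
matches, u_a, u_b = matching.linear_assignment(cost, thresh=0.8)
print(matches)  # -> [[0 0]]; the second box in each set stays unmatched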
52 feeder/trackers/bytetrack/basetrack.py Normal file
@ -0,0 +1,52 @@
import numpy as np
from collections import OrderedDict


class TrackState(object):
    New = 0
    Tracked = 1
    Lost = 2
    Removed = 3


class BaseTrack(object):
    _count = 0

    track_id = 0
    is_activated = False
    state = TrackState.New

    history = OrderedDict()
    features = []
    curr_feature = None
    score = 0
    start_frame = 0
    frame_id = 0
    time_since_update = 0

    # multi-camera
    location = (np.inf, np.inf)

    @property
    def end_frame(self):
        return self.frame_id

    @staticmethod
    def next_id():
        BaseTrack._count += 1
        return BaseTrack._count

    def activate(self, *args):
        raise NotImplementedError

    def predict(self):
        raise NotImplementedError

    def update(self, *args, **kwargs):
        raise NotImplementedError

    def mark_lost(self):
        self.state = TrackState.Lost

    def mark_removed(self):
        self.state = TrackState.Removed
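A tiny sketch (not in the diff) of the lifecycle this base class provides; a concrete track only needs to override the abstract methods, and the import path is assumed:

from trackers.bytetrack.basetrack import BaseTrack, TrackState  # assumed path

class DummyTrack(BaseTrack):
    def activate(self, *args):
        self.track_id = self.next_id()  # ids increase monotonically across all tracks
        self.state = TrackState.Tracked

t = DummyTrack()
t.activate()
t.mark_lost()
assert t.state == TrackState.Lost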
348 feeder/trackers/bytetrack/byte_tracker.py Normal file
@ -0,0 +1,348 @@
import numpy as np

from ultralytics.yolo.utils.ops import xywh2xyxy, xyxy2xywh


from trackers.bytetrack.kalman_filter import KalmanFilter
from trackers.bytetrack import matching
from trackers.bytetrack.basetrack import BaseTrack, TrackState


class STrack(BaseTrack):
    shared_kalman = KalmanFilter()

    def __init__(self, tlwh, score, cls):

        # wait activate
        self._tlwh = np.asarray(tlwh, dtype=np.float32)
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False

        self.score = score
        self.tracklet_len = 0
        self.cls = cls

    def predict(self):
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            mean_state[7] = 0
        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

    @staticmethod
    def multi_predict(stracks):
        if len(stracks) > 0:
            multi_mean = np.asarray([st.mean.copy() for st in stracks])
            multi_covariance = np.asarray([st.covariance for st in stracks])
            for i, st in enumerate(stracks):
                if st.state != TrackState.Tracked:
                    multi_mean[i][7] = 0
            multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
            for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
                stracks[i].mean = mean
                stracks[i].covariance = cov

    def activate(self, kalman_filter, frame_id):
        """Start a new tracklet"""
        self.kalman_filter = kalman_filter
        self.track_id = self.next_id()
        self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))

        self.tracklet_len = 0
        self.state = TrackState.Tracked
        if frame_id == 1:
            self.is_activated = True
        # self.is_activated = True
        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track, frame_id, new_id=False):
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
        )
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        if new_id:
            self.track_id = self.next_id()
        self.score = new_track.score
        self.cls = new_track.cls

    def update(self, new_track, frame_id):
        """
        Update a matched track
        :type new_track: STrack
        :type frame_id: int
        :type update_feature: bool
        :return:
        """
        self.frame_id = frame_id
        self.tracklet_len += 1
        # self.cls = cls

        new_tlwh = new_track.tlwh
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
        self.state = TrackState.Tracked
        self.is_activated = True

        self.score = new_track.score

    @property
    # @jit(nopython=True)
    def tlwh(self):
        """Get current position in bounding box format `(top left x, top left y,
        width, height)`.
        """
        if self.mean is None:
            return self._tlwh.copy()
        ret = self.mean[:4].copy()
        ret[2] *= ret[3]
        ret[:2] -= ret[2:] / 2
        return ret

    @property
    # @jit(nopython=True)
    def tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    @staticmethod
    # @jit(nopython=True)
    def tlwh_to_xyah(tlwh):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.
        """
        ret = np.asarray(tlwh).copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

    def to_xyah(self):
        return self.tlwh_to_xyah(self.tlwh)

    @staticmethod
    # @jit(nopython=True)
    def tlbr_to_tlwh(tlbr):
        ret = np.asarray(tlbr).copy()
        ret[2:] -= ret[:2]
        return ret

    @staticmethod
    # @jit(nopython=True)
    def tlwh_to_tlbr(tlwh):
        ret = np.asarray(tlwh).copy()
        ret[2:] += ret[:2]
        return ret

    def __repr__(self):
        return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)


class BYTETracker(object):
    def __init__(self, track_thresh=0.45, match_thresh=0.8, track_buffer=25, frame_rate=30):
        self.tracked_stracks = []  # type: list[STrack]
        self.lost_stracks = []  # type: list[STrack]
        self.removed_stracks = []  # type: list[STrack]

        self.frame_id = 0
        self.track_buffer = track_buffer

        self.track_thresh = track_thresh
        self.match_thresh = match_thresh
        self.det_thresh = track_thresh + 0.1
        self.buffer_size = int(frame_rate / 30.0 * track_buffer)
        self.max_time_lost = self.buffer_size
        self.kalman_filter = KalmanFilter()

    def update(self, dets, _):
        self.frame_id += 1
        activated_starcks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []

        xyxys = dets[:, 0:4]
        xywh = xyxy2xywh(xyxys.numpy())
        confs = dets[:, 4]
        clss = dets[:, 5]

        classes = clss.numpy()
        xyxys = xyxys.numpy()
        confs = confs.numpy()

        remain_inds = confs > self.track_thresh
        inds_low = confs > 0.1
        inds_high = confs < self.track_thresh

        inds_second = np.logical_and(inds_low, inds_high)

        dets_second = xywh[inds_second]
        dets = xywh[remain_inds]

        scores_keep = confs[remain_inds]
        scores_second = confs[inds_second]

        clss_keep = classes[remain_inds]
        clss_second = classes[inds_second]

        if len(dets) > 0:
            '''Detections'''
            detections = [STrack(xyxy, s, c) for
                          (xyxy, s, c) in zip(dets, scores_keep, clss_keep)]
        else:
            detections = []

        ''' Add newly detected tracklets to tracked_stracks'''
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)

        ''' Step 2: First association, with high score detection boxes'''
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location with KF
        STrack.multi_predict(strack_pool)
        dists = matching.iou_distance(strack_pool, detections)
        # if not self.args.mot20:
        dists = matching.fuse_score(dists, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.match_thresh)

        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(detections[idet], self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        ''' Step 3: Second association, with low score detection boxes'''
        # associate the untracked tracks to the low score detections
        if len(dets_second) > 0:
            '''Detections'''
            detections_second = [STrack(xywh, s, c) for (xywh, s, c) in zip(dets_second, scores_second, clss_second)]
        else:
            detections_second = []
        r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
        dists = matching.iou_distance(r_tracked_stracks, detections_second)
        matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections_second[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)

        '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        # if not self.args.mot20:
        dists = matching.fuse_score(dists, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_starcks.append(unconfirmed[itracked])
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        """ Step 4: Init new stracks"""
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.det_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_starcks.append(track)
        """ Step 5: Update state"""
        for track in self.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        # print('Remained match {} s'.format(t4-t3))

        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
        # get scores of lost tracks
        output_stracks = [track for track in self.tracked_stracks if track.is_activated]
        outputs = []
        for t in output_stracks:
            output = []
            tlwh = t.tlwh
            tid = t.track_id
            tlwh = np.expand_dims(tlwh, axis=0)
            xyxy = xywh2xyxy(tlwh)
            xyxy = np.squeeze(xyxy, axis=0)
            output.extend(xyxy)
            output.append(tid)
            output.append(t.cls)
            output.append(t.score)
            outputs.append(output)

        return outputs
        # each output row: x1, y1, x2, y2, track_id, class_id, conf


def joint_stracks(tlista, tlistb):
    exists = {}
    res = []
    for t in tlista:
        exists[t.track_id] = 1
        res.append(t)
    for t in tlistb:
        tid = t.track_id
        if not exists.get(tid, 0):
            exists[tid] = 1
            res.append(t)
    return res


def sub_stracks(tlista, tlistb):
    stracks = {}
    for t in tlista:
        stracks[t.track_id] = t
    for t in tlistb:
        tid = t.track_id
        if stracks.get(tid, 0):
            del stracks[tid]
    return list(stracks.values())


def remove_duplicate_stracks(stracksa, stracksb):
    pdist = matching.iou_distance(stracksa, stracksb)
    pairs = np.where(pdist < 0.15)
    dupa, dupb = list(), list()
    for p, q in zip(*pairs):
        timep = stracksa[p].frame_id - stracksa[p].start_frame
        timeq = stracksb[q].frame_id - stracksb[q].start_frame
        if timep > timeq:
            dupb.append(q)
        else:
            dupa.append(p)
    resa = [t for i, t in enumerate(stracksa) if i not in dupa]
    resb = [t for i, t in enumerate(stracksb) if i not in dupb]
    return resa, resb
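End-to-end sketch (not in the diff), assuming detections arrive as a torch tensor of rows [x1, y1, x2, y2, conf, cls], which is the layout update() indexes:

import torch
from trackers.bytetrack.byte_tracker import BYTETracker  # assumed import path

tracker = BYTETracker(track_thresh=0.45, match_thresh=0.8, track_buffer=25, frame_rate=30)
dets = torch.tensor([[100., 100., 150., 200., 0.9, 0.],
                     [300., 120., 360., 240., 0.3, 0.]])
for _ in range(2):                       # feed the same frame twice for brevity
    online = tracker.update(dets, None)  # second argument is unused here
print(online)  # rows of [x1, y1, x2, y2, track_id, class_id, conf]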
7 feeder/trackers/bytetrack/configs/bytetrack.yaml Normal file
@ -0,0 +1,7 @@
bytetrack:
  track_thresh: 0.6  # tracking confidence threshold
  track_buffer: 30   # frames to keep lost tracks alive
  match_thresh: 0.8  # matching threshold for tracking
  frame_rate: 30     # FPS
  conf_thres: 0.5122620708221085
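A sketch (not part of the diff) of how a config like this can be mapped onto BYTETracker's constructor; PyYAML usage and the file path are assumptions:

import yaml
from trackers.bytetrack.byte_tracker import BYTETracker  # assumed import path

with open('feeder/trackers/bytetrack/configs/bytetrack.yaml') as f:
    cfg = yaml.safe_load(f)['bytetrack']

tracker = BYTETracker(track_thresh=cfg['track_thresh'],
                      match_thresh=cfg['match_thresh'],
                      track_buffer=cfg['track_buffer'],
                      frame_rate=cfg['frame_rate'])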
270 feeder/trackers/bytetrack/kalman_filter.py Normal file
@ -0,0 +1,270 @@
# vim: expandtab:ts=4:sw=4
import numpy as np
import scipy.linalg


"""
Table for the 0.95 quantile of the chi-square distribution with N degrees of
freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
function and used as Mahalanobis gating threshold.
"""
chi2inv95 = {
    1: 3.8415,
    2: 5.9915,
    3: 7.8147,
    4: 9.4877,
    5: 11.070,
    6: 12.592,
    7: 14.067,
    8: 15.507,
    9: 16.919}


class KalmanFilter(object):
    """
    A simple Kalman filter for tracking bounding boxes in image space.

    The 8-dimensional state space

        x, y, a, h, vx, vy, va, vh

    contains the bounding box center position (x, y), aspect ratio a, height h,
    and their respective velocities.

    Object motion follows a constant velocity model. The bounding box location
    (x, y, a, h) is taken as direct observation of the state space (linear
    observation model).

    """

    def __init__(self):
        ndim, dt = 4, 1.

        # Create Kalman filter model matrices.
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt
        self._update_mat = np.eye(ndim, 2 * ndim)

        # Motion and observation uncertainty are chosen relative to the current
        # state estimate. These weights control the amount of uncertainty in
        # the model. This is a bit hacky.
        self._std_weight_position = 1. / 20
        self._std_weight_velocity = 1. / 160

    def initiate(self, measurement):
        """Create track from unassociated measurement.

        Parameters
        ----------
        measurement : ndarray
            Bounding box coordinates (x, y, a, h) with center position (x, y),
            aspect ratio a, and height h.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector (8 dimensional) and covariance matrix (8x8
            dimensional) of the new track. Unobserved velocities are initialized
            to 0 mean.

        """
        mean_pos = measurement
        mean_vel = np.zeros_like(mean_pos)
        mean = np.r_[mean_pos, mean_vel]

        std = [
            2 * self._std_weight_position * measurement[3],
            2 * self._std_weight_position * measurement[3],
            1e-2,
            2 * self._std_weight_position * measurement[3],
            10 * self._std_weight_velocity * measurement[3],
            10 * self._std_weight_velocity * measurement[3],
            1e-5,
            10 * self._std_weight_velocity * measurement[3]]
        covariance = np.diag(np.square(std))
        return mean, covariance

    def predict(self, mean, covariance):
        """Run Kalman filter prediction step.

        Parameters
        ----------
        mean : ndarray
            The 8 dimensional mean vector of the object state at the previous
            time step.
        covariance : ndarray
            The 8x8 dimensional covariance matrix of the object state at the
            previous time step.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector and covariance matrix of the predicted
            state. Unobserved velocities are initialized to 0 mean.

        """
        std_pos = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-2,
            self._std_weight_position * mean[3]]
        std_vel = [
            self._std_weight_velocity * mean[3],
            self._std_weight_velocity * mean[3],
            1e-5,
            self._std_weight_velocity * mean[3]]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))

        # mean = np.dot(self._motion_mat, mean)
        mean = np.dot(mean, self._motion_mat.T)
        covariance = np.linalg.multi_dot((
            self._motion_mat, covariance, self._motion_mat.T)) + motion_cov

        return mean, covariance

    def project(self, mean, covariance):
        """Project state distribution to measurement space.

        Parameters
        ----------
        mean : ndarray
            The state's mean vector (8 dimensional array).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).

        Returns
        -------
        (ndarray, ndarray)
            Returns the projected mean and covariance matrix of the given state
            estimate.

        """
        std = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-1,
            self._std_weight_position * mean[3]]
        innovation_cov = np.diag(np.square(std))

        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot((
            self._update_mat, covariance, self._update_mat.T))
        return mean, covariance + innovation_cov

    def multi_predict(self, mean, covariance):
        """Run Kalman filter prediction step (Vectorized version).

        Parameters
        ----------
        mean : ndarray
            The Nx8 dimensional mean matrix of the object states at the
            previous time step.
        covariance : ndarray
            The Nx8x8 dimensional covariance matrices of the object states at
            the previous time step.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector and covariance matrix of the predicted
            state. Unobserved velocities are initialized to 0 mean.
        """
        std_pos = [
            self._std_weight_position * mean[:, 3],
            self._std_weight_position * mean[:, 3],
            1e-2 * np.ones_like(mean[:, 3]),
            self._std_weight_position * mean[:, 3]]
        std_vel = [
            self._std_weight_velocity * mean[:, 3],
            self._std_weight_velocity * mean[:, 3],
            1e-5 * np.ones_like(mean[:, 3]),
            self._std_weight_velocity * mean[:, 3]]
        sqr = np.square(np.r_[std_pos, std_vel]).T

        motion_cov = []
        for i in range(len(mean)):
            motion_cov.append(np.diag(sqr[i]))
        motion_cov = np.asarray(motion_cov)

        mean = np.dot(mean, self._motion_mat.T)
        left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
        covariance = np.dot(left, self._motion_mat.T) + motion_cov

        return mean, covariance

    def update(self, mean, covariance, measurement):
        """Run Kalman filter correction step.

        Parameters
        ----------
        mean : ndarray
            The predicted state's mean vector (8 dimensional).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).
        measurement : ndarray
            The 4 dimensional measurement vector (x, y, a, h), where (x, y)
            is the center position, a the aspect ratio, and h the height of the
            bounding box.

        Returns
        -------
        (ndarray, ndarray)
            Returns the measurement-corrected state distribution.

        """
        projected_mean, projected_cov = self.project(mean, covariance)

        chol_factor, lower = scipy.linalg.cho_factor(
            projected_cov, lower=True, check_finite=False)
        kalman_gain = scipy.linalg.cho_solve(
            (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
            check_finite=False).T
        innovation = measurement - projected_mean

        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot((
            kalman_gain, projected_cov, kalman_gain.T))
        return new_mean, new_covariance

    def gating_distance(self, mean, covariance, measurements,
                        only_position=False, metric='maha'):
        """Compute gating distance between state distribution and measurements.

        A suitable distance threshold can be obtained from `chi2inv95`. If
        `only_position` is False, the chi-square distribution has 4 degrees of
        freedom, otherwise 2.

        Parameters
        ----------
        mean : ndarray
            Mean vector over the state distribution (8 dimensional).
        covariance : ndarray
            Covariance of the state distribution (8x8 dimensional).
        measurements : ndarray
            An Nx4 dimensional matrix of N measurements, each in
            format (x, y, a, h) where (x, y) is the bounding box center
            position, a the aspect ratio, and h the height.
        only_position : Optional[bool]
            If True, distance computation is done with respect to the bounding
            box center position only.

        Returns
        -------
        ndarray
            Returns an array of length N, where the i-th element contains the
            squared Mahalanobis distance between (mean, covariance) and
            `measurements[i]`.
        """
        mean, covariance = self.project(mean, covariance)
        if only_position:
            mean, covariance = mean[:2], covariance[:2, :2]
            measurements = measurements[:, :2]

        d = measurements - mean
        if metric == 'gaussian':
            return np.sum(d * d, axis=1)
        elif metric == 'maha':
            cholesky_factor = np.linalg.cholesky(covariance)
            z = scipy.linalg.solve_triangular(
                cholesky_factor, d.T, lower=True, check_finite=False,
                overwrite_b=True)
            squared_maha = np.sum(z * z, axis=0)
            return squared_maha
        else:
            raise ValueError('invalid distance metric')
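A sketch (not in the diff) of Mahalanobis gating with chi2inv95, mirroring how matching.gate_cost_matrix uses this filter; the numbers are made up:

import numpy as np
from trackers.bytetrack.kalman_filter import KalmanFilter, chi2inv95  # assumed path

kf = KalmanFilter()
mean, cov = kf.initiate(np.array([320., 240., 0.5, 160.]))  # (x, y, a, h)
mean, cov = kf.predict(mean, cov)
candidates = np.array([[322., 241., 0.5, 161.],   # plausible continuation
                       [600., 400., 0.5, 160.]])  # far away
d2 = kf.gating_distance(mean, cov, candidates)
print(d2 <= chi2inv95[4])  # -> [ True False ], gating at 4 degrees of freedom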
219 feeder/trackers/bytetrack/matching.py Normal file
@ -0,0 +1,219 @@
import cv2
import numpy as np
import scipy
import lap
from scipy.spatial.distance import cdist

from trackers.bytetrack import kalman_filter
import time

def merge_matches(m1, m2, shape):
    O, P, Q = shape
    m1 = np.asarray(m1)
    m2 = np.asarray(m2)

    M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
    M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))

    mask = M1 * M2
    match = mask.nonzero()
    match = list(zip(match[0], match[1]))
    unmatched_O = tuple(set(range(O)) - set([i for i, j in match]))
    unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match]))

    return match, unmatched_O, unmatched_Q


def _indices_to_matches(cost_matrix, indices, thresh):
    matched_cost = cost_matrix[tuple(zip(*indices))]
    matched_mask = (matched_cost <= thresh)

    matches = indices[matched_mask]
    unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
    unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))

    return matches, unmatched_a, unmatched_b


def linear_assignment(cost_matrix, thresh):
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
    matches, unmatched_a, unmatched_b = [], [], []
    cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
    for ix, mx in enumerate(x):
        if mx >= 0:
            matches.append([ix, mx])
    unmatched_a = np.where(x < 0)[0]
    unmatched_b = np.where(y < 0)[0]
    matches = np.asarray(matches)
    return matches, unmatched_a, unmatched_b


def ious(atlbrs, btlbrs):
    """
    Compute cost based on IoU
    :type atlbrs: list[tlbr] | np.ndarray
    :type btlbrs: list[tlbr] | np.ndarray

    :rtype ious np.ndarray
    """
    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
    if ious.size == 0:
        return ious

    ious = bbox_ious(
        np.ascontiguousarray(atlbrs, dtype=np.float32),
        np.ascontiguousarray(btlbrs, dtype=np.float32)
    )

    return ious


def iou_distance(atracks, btracks):
    """
    Compute cost based on IoU
    :type atracks: list[STrack]
    :type btracks: list[STrack]

    :rtype cost_matrix np.ndarray
    """

    if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
        atlbrs = atracks
        btlbrs = btracks
    else:
        atlbrs = [track.tlbr for track in atracks]
        btlbrs = [track.tlbr for track in btracks]
    _ious = ious(atlbrs, btlbrs)
    cost_matrix = 1 - _ious

    return cost_matrix

def v_iou_distance(atracks, btracks):
    """
    Compute cost based on IoU
    :type atracks: list[STrack]
    :type btracks: list[STrack]

    :rtype cost_matrix np.ndarray
    """

    if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
        atlbrs = atracks
        btlbrs = btracks
    else:
        atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks]
        btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks]
    _ious = ious(atlbrs, btlbrs)
    cost_matrix = 1 - _ious

    return cost_matrix

def embedding_distance(tracks, detections, metric='cosine'):
    """
    :param tracks: list[STrack]
    :param detections: list[BaseTrack]
    :param metric:
    :return: cost_matrix np.ndarray
    """

    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
    if cost_matrix.size == 0:
        return cost_matrix
    det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32)
    # for i, track in enumerate(tracks):
    #     cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric))
    track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)
    cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric))  # Normalized features
    return cost_matrix


def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
    if cost_matrix.size == 0:
        return cost_matrix
    gating_dim = 2 if only_position else 4
    gating_threshold = kalman_filter.chi2inv95[gating_dim]
    measurements = np.asarray([det.to_xyah() for det in detections])
    for row, track in enumerate(tracks):
        gating_distance = kf.gating_distance(
            track.mean, track.covariance, measurements, only_position)
        cost_matrix[row, gating_distance > gating_threshold] = np.inf
    return cost_matrix


def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
    if cost_matrix.size == 0:
        return cost_matrix
    gating_dim = 2 if only_position else 4
    gating_threshold = kalman_filter.chi2inv95[gating_dim]
    measurements = np.asarray([det.to_xyah() for det in detections])
    for row, track in enumerate(tracks):
        gating_distance = kf.gating_distance(
            track.mean, track.covariance, measurements, only_position, metric='maha')
        cost_matrix[row, gating_distance > gating_threshold] = np.inf
        cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
    return cost_matrix


def fuse_iou(cost_matrix, tracks, detections):
    if cost_matrix.size == 0:
        return cost_matrix
    reid_sim = 1 - cost_matrix
    iou_dist = iou_distance(tracks, detections)
    iou_sim = 1 - iou_dist
    fuse_sim = reid_sim * (1 + iou_sim) / 2
    det_scores = np.array([det.score for det in detections])
    det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
    # fuse_sim = fuse_sim * (1 + det_scores) / 2
    fuse_cost = 1 - fuse_sim
    return fuse_cost


def fuse_score(cost_matrix, detections):
    if cost_matrix.size == 0:
        return cost_matrix
    iou_sim = 1 - cost_matrix
    det_scores = np.array([det.score for det in detections])
    det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
    fuse_sim = iou_sim * det_scores
    fuse_cost = 1 - fuse_sim
    return fuse_cost


def bbox_ious(boxes, query_boxes):
    """
    Parameters
    ----------
    boxes: (N, 4) ndarray of float
    query_boxes: (K, 4) ndarray of float
    Returns
    -------
    overlaps: (N, K) ndarray of overlap between boxes and query_boxes
    """
    N = boxes.shape[0]
    K = query_boxes.shape[0]
    overlaps = np.zeros((N, K), dtype=np.float32)

    for k in range(K):
        box_area = (
            (query_boxes[k, 2] - query_boxes[k, 0] + 1) *
            (query_boxes[k, 3] - query_boxes[k, 1] + 1)
        )
        for n in range(N):
            iw = (
                min(boxes[n, 2], query_boxes[k, 2]) -
                max(boxes[n, 0], query_boxes[k, 0]) + 1
            )
            if iw > 0:
                ih = (
                    min(boxes[n, 3], query_boxes[k, 3]) -
                    max(boxes[n, 1], query_boxes[k, 1]) + 1
                )
                if ih > 0:
                    ua = float(
                        (boxes[n, 2] - boxes[n, 0] + 1) *
                        (boxes[n, 3] - boxes[n, 1] + 1) +
                        box_area - iw * ih
                    )
                    overlaps[n, k] = iw * ih / ua
    return overlaps
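For intuition, a sketch (not in the diff) of how fuse_score blends IoU similarity with detection confidence; detections only need a .score attribute here, and the import path is assumed:

import numpy as np
from types import SimpleNamespace
from trackers.bytetrack import matching  # assumed import path

cost = np.array([[0.2, 0.9]], dtype=np.float32)  # 1 - IoU: one track vs two detections
dets = [SimpleNamespace(score=0.9), SimpleNamespace(score=0.4)]
print(matching.fuse_score(cost, dets))  # -> approximately [[0.28 0.96]]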
2 feeder/trackers/deepocsort/__init__.py Normal file
@ -0,0 +1,2 @@
from . import args
from . import ocsort
110 feeder/trackers/deepocsort/args.py Normal file
@ -0,0 +1,110 @@
import argparse


def make_parser():
    parser = argparse.ArgumentParser("OC-SORT parameters")

    # distributed
    parser.add_argument("-b", "--batch-size", type=int, default=1, help="batch size")
    parser.add_argument("-d", "--devices", default=None, type=int, help="device for training")

    parser.add_argument("--local_rank", default=0, type=int, help="local rank for dist training")
    parser.add_argument("--num_machines", default=1, type=int, help="num of node for training")
    parser.add_argument("--machine_rank", default=0, type=int, help="node rank for multi-node training")

    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument(
        "--test",
        dest="test",
        default=False,
        action="store_true",
        help="Evaluating on test-dev set.",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    # det args
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
    parser.add_argument("--conf", default=0.1, type=float, help="test conf")
    parser.add_argument("--nms", default=0.7, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=[800, 1440], nargs="+", type=int, help="test img size")
    parser.add_argument("--seed", default=None, type=int, help="eval seed")

    # tracking args
    parser.add_argument("--track_thresh", type=float, default=0.6, help="detection confidence threshold")
    parser.add_argument(
        "--iou_thresh",
        type=float,
        default=0.3,
        help="the iou threshold in Sort for matching",
    )
    parser.add_argument("--min_hits", type=int, default=3, help="min hits to create track in SORT")
    parser.add_argument(
        "--inertia",
        type=float,
        default=0.2,
        help="the weight of VDC term in cost matrix",
    )
    parser.add_argument(
        "--deltat",
        type=int,
        default=3,
        help="time step difference to estimate direction",
    )
    parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
    parser.add_argument(
        "--match_thresh",
        type=float,
        default=0.9,
        help="matching threshold for tracking",
    )
    parser.add_argument(
        "--gt-type",
        type=str,
        default="_val_half",
        help="suffix to find the gt annotation",
    )
    parser.add_argument("--public", action="store_true", help="use public detection")
    parser.add_argument("--asso", default="iou", help="similarity function: iou/giou/diou/ciou/ctdis")

    # for kitti/bdd100k inference with public detections
    parser.add_argument(
        "--raw_results_path",
        type=str,
        default="exps/permatrack_kitti_test/",
        help="path to the raw tracking results from other tracks",
    )
    parser.add_argument("--out_path", type=str, help="path to save output results")
    parser.add_argument(
        "--hp",
        action="store_true",
        help="use head padding to add the missing objects during \
            initializing the tracks (offline).",
    )

    # for demo video
    parser.add_argument("--demo_type", default="image", help="demo type, eg. image, video and webcam")
    parser.add_argument("--path", default="./videos/demo.mp4", help="path to images or video")
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        help="whether to save the inference result of image/video",
    )
    parser.add_argument(
        "--device",
        default="gpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    return parser
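A short sketch (not part of the diff): building the OC-SORT argument namespace programmatically instead of from the shell; the import path is an assumption:

from trackers.deepocsort.args import make_parser  # assumed import path

args = make_parser().parse_args(
    ["--track_thresh", "0.6", "--asso", "giou", "--track_buffer", "30"])
print(args.track_thresh, args.asso, args.min_hits)  # 0.6 giou 3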
445
feeder/trackers/deepocsort/association.py
Normal file
445
feeder/trackers/deepocsort/association.py
Normal file
|
|
@ -0,0 +1,445 @@
|
|||
import os
|
||||
import pdb
|
||||
|
||||
import numpy as np
|
||||
from scipy.special import softmax
|
||||
|
||||
|
||||
def iou_batch(bboxes1, bboxes2):
|
||||
"""
|
||||
From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
|
||||
"""
|
||||
bboxes2 = np.expand_dims(bboxes2, 0)
|
||||
bboxes1 = np.expand_dims(bboxes1, 1)
|
||||
|
||||
xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
|
||||
yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
|
||||
xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
|
||||
yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
|
||||
w = np.maximum(0.0, xx2 - xx1)
|
||||
h = np.maximum(0.0, yy2 - yy1)
|
||||
wh = w * h
|
||||
o = wh / (
|
||||
(bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
|
||||
+ (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
|
||||
- wh
|
||||
)
|
||||
return o
|
||||
|
||||
|
||||
def giou_batch(bboxes1, bboxes2):
|
||||
"""
|
||||
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
|
||||
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
|
||||
:return:
|
||||
"""
|
||||
# for details should go to https://arxiv.org/pdf/1902.09630.pdf
|
||||
# ensure predict's bbox form
|
||||
bboxes2 = np.expand_dims(bboxes2, 0)
|
||||
bboxes1 = np.expand_dims(bboxes1, 1)
|
||||
|
||||
xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
|
||||
yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
|
||||
xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
|
||||
yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
|
||||
w = np.maximum(0.0, xx2 - xx1)
|
||||
h = np.maximum(0.0, yy2 - yy1)
|
||||
wh = w * h
|
||||
iou = wh / (
|
||||
(bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
|
||||
+ (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
|
||||
- wh
|
||||
)
|
||||
|
||||
xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
|
||||
yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
|
||||
xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
|
||||
yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
|
||||
wc = xxc2 - xxc1
|
||||
hc = yyc2 - yyc1
|
||||
assert (wc > 0).all() and (hc > 0).all()
|
||||
area_enclose = wc * hc
|
||||
giou = iou - (area_enclose - wh) / area_enclose
|
||||
giou = (giou + 1.0) / 2.0 # resize from (-1,1) to (0,1)
|
||||
return giou
|
||||
|
||||
|
||||
def diou_batch(bboxes1, bboxes2):
|
||||
"""
|
||||
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
|
||||
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
|
||||
:return:
|
||||
"""
|
||||
# for details should go to https://arxiv.org/pdf/1902.09630.pdf
|
||||
# ensure predict's bbox form
|
||||
bboxes2 = np.expand_dims(bboxes2, 0)
|
||||
bboxes1 = np.expand_dims(bboxes1, 1)
|
||||
|
||||
# calculate the intersection box
|
||||
xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
|
||||
yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
|
||||
xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
|
||||
yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
|
||||
w = np.maximum(0.0, xx2 - xx1)
|
||||
h = np.maximum(0.0, yy2 - yy1)
|
||||
wh = w * h
|
||||
iou = wh / (
|
||||
(bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
|
||||
+ (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
|
||||
- wh
|
||||
)
|
||||
|
||||
centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
|
||||
centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
|
||||
centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
|
||||
centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0
|
||||
|
||||
inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2
|
||||
|
||||
xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
|
||||
yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
|
||||
xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
|
||||
yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
|
||||
|
||||
outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2
|
||||
diou = iou - inner_diag / outer_diag
|
||||
|
||||
return (diou + 1) / 2.0 # resize from (-1,1) to (0,1)
|
||||
|
||||
|
||||
def ciou_batch(bboxes1, bboxes2):
|
||||
"""
|
||||
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
|
||||
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
|
||||
:return:
|
||||
"""
|
||||
# for details should go to https://arxiv.org/pdf/1902.09630.pdf
|
||||
# ensure predict's bbox form
|
||||
bboxes2 = np.expand_dims(bboxes2, 0)
|
||||
bboxes1 = np.expand_dims(bboxes1, 1)
|
||||
|
||||
# calculate the intersection box
|
||||
xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
|
||||
yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
|
||||
xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
|
||||
yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
|
||||
w = np.maximum(0.0, xx2 - xx1)
|
||||
h = np.maximum(0.0, yy2 - yy1)
|
||||
wh = w * h
|
||||
iou = wh / (
|
||||
(bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
|
||||
+ (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
|
||||
- wh
|
||||
)
|
||||
|
||||
centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
|
||||
centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
|
||||
centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
|
||||
centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0
|
||||
|
||||
inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2
|
||||
|
||||
xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
|
||||
yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
|
||||
xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
|
||||
yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
|
||||
|
||||
outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2
|
||||
|
||||
w1 = bboxes1[..., 2] - bboxes1[..., 0]
|
||||
h1 = bboxes1[..., 3] - bboxes1[..., 1]
|
||||
w2 = bboxes2[..., 2] - bboxes2[..., 0]
|
||||
h2 = bboxes2[..., 3] - bboxes2[..., 1]
|
||||
|
||||
# prevent dividing over zero. add one pixel shift
|
||||
h2 = h2 + 1.0
|
||||
h1 = h1 + 1.0
|
||||
arctan = np.arctan(w2 / h2) - np.arctan(w1 / h1)
|
||||
v = (4 / (np.pi**2)) * (arctan**2)
|
||||
S = 1 - iou
|
||||
alpha = v / (S + v)
|
||||
ciou = iou - inner_diag / outer_diag - alpha * v
|
||||
|
||||
return (ciou + 1) / 2.0 # resize from (-1,1) to (0,1)
|
||||
|
||||
|
||||
def ct_dist(bboxes1, bboxes2):
|
||||
"""
|
||||
Measure the center distance between two sets of bounding boxes,
|
||||
this is a coarse implementation, we don't recommend using it only
|
||||
for association, which can be unstable and sensitive to frame rate
|
||||
and object speed.
|
||||
"""
|
||||
bboxes2 = np.expand_dims(bboxes2, 0)
|
||||
bboxes1 = np.expand_dims(bboxes1, 1)
|
||||
|
||||
centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
|
||||
centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
|
||||
centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
|
||||
centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0
|
||||
|
||||
ct_dist2 = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2
|
||||
|
||||
ct_dist = np.sqrt(ct_dist2)
|
||||
|
||||
# The linear rescaling is a naive version and needs more study
|
||||
ct_dist = ct_dist / ct_dist.max()
|
||||
return ct_dist.max() - ct_dist # resize to (0,1)
|
||||
|
||||
|
||||
def speed_direction_batch(dets, tracks):
|
||||
tracks = tracks[..., np.newaxis]
|
||||
CX1, CY1 = (dets[:, 0] + dets[:, 2]) / 2.0, (dets[:, 1] + dets[:, 3]) / 2.0
|
||||
CX2, CY2 = (tracks[:, 0] + tracks[:, 2]) / 2.0, (tracks[:, 1] + tracks[:, 3]) / 2.0
|
||||
dx = CX1 - CX2
|
||||
dy = CY1 - CY2
|
||||
norm = np.sqrt(dx**2 + dy**2) + 1e-6
|
||||
dx = dx / norm
|
||||
dy = dy / norm
|
||||
return dy, dx # size: num_track x num_det
|
||||
|
||||
|
||||
def linear_assignment(cost_matrix):
|
||||
try:
|
||||
import lap
|
||||
|
||||
_, x, y = lap.lapjv(cost_matrix, extend_cost=True)
|
||||
return np.array([[y[i], i] for i in x if i >= 0]) #
|
||||
except ImportError:
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
x, y = linear_sum_assignment(cost_matrix)
|
||||
return np.array(list(zip(x, y)))
|
||||
|
||||
|
||||


def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
    """
    Assigns detections to tracked objects (both represented as bounding boxes).
    Returns 3 lists: matches, unmatched_detections and unmatched_trackers.
    """
    if len(trackers) == 0:
        return (
            np.empty((0, 2), dtype=int),
            np.arange(len(detections)),
            np.empty((0, 5), dtype=int),
        )

    iou_matrix = iou_batch(detections, trackers)

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(-iou_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = []
    for d, det in enumerate(detections):
        if d not in matched_indices[:, 0]:
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if t not in matched_indices[:, 1]:
            unmatched_trackers.append(t)

    # filter out matched pairs with low IoU
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if len(matches) == 0:
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
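
A quick sketch of the three outputs on made-up boxes: the overlapping pair matches, the far-away detection stays unmatched:

import numpy as np

dets = np.array([[10, 10, 50, 50, 0.9],
                 [200, 200, 240, 240, 0.8]])  # [x1, y1, x2, y2, score]
trks = np.array([[12, 12, 52, 52, 0.0]])
m, um_d, um_t = associate_detections_to_trackers(dets, trks, iou_threshold=0.3)
# m -> [[0 0]], um_d -> [1], um_t -> []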


def compute_aw_max_metric(emb_cost, w_association_emb, bottom=0.5):
    w_emb = np.full_like(emb_cost, w_association_emb)

    for idx in range(emb_cost.shape[0]):
        inds = np.argsort(-emb_cost[idx])
        # If there are fewer than two matches, just keep the original weight
        if len(inds) < 2:
            continue
        if emb_cost[idx, inds[0]] == 0:
            row_weight = 0
        else:
            row_weight = 1 - max((emb_cost[idx, inds[1]] / emb_cost[idx, inds[0]]) - bottom, 0) / (1 - bottom)
        w_emb[idx] *= row_weight

    for idj in range(emb_cost.shape[1]):
        inds = np.argsort(-emb_cost[:, idj])
        # If there are fewer than two matches, just keep the original weight
        if len(inds) < 2:
            continue
        if emb_cost[inds[0], idj] == 0:
            col_weight = 0
        else:
            col_weight = 1 - max((emb_cost[inds[1], idj] / emb_cost[inds[0], idj]) - bottom, 0) / (1 - bottom)
        w_emb[:, idj] *= col_weight

    return w_emb * emb_cost
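
A hedged numeric sketch of the adaptive weighting above: well-separated best and second-best similarities keep most of the weight, while near-identical ones collapse it:

import numpy as np

print(compute_aw_max_metric(np.array([[0.9, 0.5]]), w_association_emb=0.75))
# ~[[0.60, 0.33]]: distinct similarities keep the boost
print(compute_aw_max_metric(np.array([[0.9, 0.89]]), w_association_emb=0.75))
# ~[[0.015, 0.015]]: ambiguous appearance evidence is suppressed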


def associate(
    detections, trackers, iou_threshold, velocities, previous_obs, vdc_weight, emb_cost, w_assoc_emb, aw_off, aw_param
):
    if len(trackers) == 0:
        return (
            np.empty((0, 2), dtype=int),
            np.arange(len(detections)),
            np.empty((0, 5), dtype=int),
        )

    Y, X = speed_direction_batch(detections, previous_obs)
    inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
    inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
    inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
    diff_angle_cos = inertia_X * X + inertia_Y * Y
    diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
    diff_angle = np.arccos(diff_angle_cos)
    diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi

    valid_mask = np.ones(previous_obs.shape[0])
    valid_mask[np.where(previous_obs[:, 4] < 0)] = 0

    iou_matrix = iou_batch(detections, trackers)
    scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
    # iou_matrix = iou_matrix * scores  # a trick that sometimes works; we don't encourage it
    valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)

    angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
    angle_diff_cost = angle_diff_cost.T
    angle_diff_cost = angle_diff_cost * scores

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            if emb_cost is None:
                emb_cost = 0
            else:
                emb_cost = emb_cost.cpu().numpy()
                emb_cost[iou_matrix <= 0] = 0
                if not aw_off:
                    emb_cost = compute_aw_max_metric(emb_cost, w_assoc_emb, bottom=aw_param)
                else:
                    emb_cost *= w_assoc_emb

            final_cost = -(iou_matrix + angle_diff_cost + emb_cost)
            matched_indices = linear_assignment(final_cost)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = []
    for d, det in enumerate(detections):
        if d not in matched_indices[:, 0]:
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if t not in matched_indices[:, 1]:
            unmatched_trackers.append(t)

    # filter out matched pairs with low IoU
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if len(matches) == 0:
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
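
For intuition, a hedged sketch of the velocity-direction-consistency term used in both `associate` functions, for a single tracker/detection pair (the vectors are illustrative):

import numpy as np

inertia = np.array([0.0, 1.0])    # tracker's past motion direction (dy, dx): moving right
obs_dir = np.array([0.1, 0.995])  # direction from a previous observation to the new detection
cos_sim = np.clip(inertia @ obs_dir, -1, 1)
angle_cost = (np.pi / 2.0 - np.abs(np.arccos(cos_sim))) / np.pi
print(angle_cost)  # ~0.47 when directions agree; -0.5 when they oppose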


def associate_kitti(detections, trackers, det_cates, iou_threshold, velocities, previous_obs, vdc_weight):
    if len(trackers) == 0:
        return (
            np.empty((0, 2), dtype=int),
            np.arange(len(detections)),
            np.empty((0, 5), dtype=int),
        )

    """
    Cost from the velocity direction consistency
    """
    Y, X = speed_direction_batch(detections, previous_obs)
    inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
    inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
    inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
    diff_angle_cos = inertia_X * X + inertia_Y * Y
    diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
    diff_angle = np.arccos(diff_angle_cos)
    diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi

    valid_mask = np.ones(previous_obs.shape[0])
    valid_mask[np.where(previous_obs[:, 4] < 0)] = 0
    valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)

    scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
    angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
    angle_diff_cost = angle_diff_cost.T
    angle_diff_cost = angle_diff_cost * scores

    """
    Cost from IoU
    """
    iou_matrix = iou_batch(detections, trackers)

    """
    With multiple categories, generate the cost for category mismatch
    """
    num_dets = detections.shape[0]
    num_trk = trackers.shape[0]
    cate_matrix = np.zeros((num_dets, num_trk))
    for i in range(num_dets):
        for j in range(num_trk):
            if det_cates[i] != trackers[j, 4]:
                cate_matrix[i][j] = -1e6

    cost_matrix = -iou_matrix - angle_diff_cost - cate_matrix

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(cost_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = []
    for d, det in enumerate(detections):
        if d not in matched_indices[:, 0]:
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if t not in matched_indices[:, 1]:
            unmatched_trackers.append(t)

    # filter out matched pairs with low IoU
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if len(matches) == 0:
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
170 feeder/trackers/deepocsort/cmc.py Normal file

@@ -0,0 +1,170 @@
import pdb
import pickle
import os

import cv2
import numpy as np


class CMCComputer:
    def __init__(self, minimum_features=10, method="sparse"):
        assert method in ["file", "sparse", "sift"]

        os.makedirs("./cache", exist_ok=True)
        self.cache_path = "./cache/affine_ocsort.pkl"
        self.cache = {}
        if os.path.exists(self.cache_path):
            with open(self.cache_path, "rb") as fp:
                self.cache = pickle.load(fp)
        self.minimum_features = minimum_features
        self.prev_img = None
        self.prev_desc = None
        self.sparse_flow_param = dict(
            maxCorners=3000,
            qualityLevel=0.01,
            minDistance=1,
            blockSize=3,
            useHarrisDetector=False,
            k=0.04,
        )
        self.file_computed = {}

        self.comp_function = None
        if method == "sparse":
            self.comp_function = self._affine_sparse_flow
        elif method == "sift":
            self.comp_function = self._affine_sift
        # Same BoT-SORT CMC arrays
        elif method == "file":
            self.comp_function = self._affine_file
            self.file_affines = {}
            # Maps from tag name to file name
            self.file_names = {}

            # All the ablation file names
            for f_name in os.listdir("./cache/cmc_files/MOT17_ablation/"):
                # The tag that'll be passed into compute_affine based on image name
                tag = f_name.replace("GMC-", "").replace(".txt", "") + "-FRCNN"
                f_name = os.path.join("./cache/cmc_files/MOT17_ablation/", f_name)
                self.file_names[tag] = f_name
            for f_name in os.listdir("./cache/cmc_files/MOT20_ablation/"):
                tag = f_name.replace("GMC-", "").replace(".txt", "")
                f_name = os.path.join("./cache/cmc_files/MOT20_ablation/", f_name)
                self.file_names[tag] = f_name

            # All the test file names
            for f_name in os.listdir("./cache/cmc_files/MOTChallenge/"):
                tag = f_name.replace("GMC-", "").replace(".txt", "")
                if "MOT17" in tag:
                    tag = tag + "-FRCNN"
                # If it's an ablation one (not test) don't overwrite it
                if tag in self.file_names:
                    continue
                f_name = os.path.join("./cache/cmc_files/MOTChallenge/", f_name)
                self.file_names[tag] = f_name

    def compute_affine(self, img, bbox, tag):
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        if tag in self.cache:
            A = self.cache[tag]
            return A
        mask = np.ones_like(img, dtype=np.uint8)
        if bbox.shape[0] > 0:
            bbox = np.round(bbox).astype(np.int32)
            bbox[bbox < 0] = 0
            for bb in bbox:
                mask[bb[1] : bb[3], bb[0] : bb[2]] = 0

        A = self.comp_function(img, mask, tag)
        self.cache[tag] = A

        return A

    def _load_file(self, name):
        affines = []
        with open(self.file_names[name], "r") as fp:
            for line in fp:
                tokens = [float(f) for f in line.split("\t")[1:7]]
                A = np.eye(2, 3)
                A[0, 0] = tokens[0]
                A[0, 1] = tokens[1]
                A[0, 2] = tokens[2]
                A[1, 0] = tokens[3]
                A[1, 1] = tokens[4]
                A[1, 2] = tokens[5]
                affines.append(A)
        self.file_affines[name] = affines

    def _affine_file(self, frame, mask, tag):
        name, num = tag.split(":")
        if name not in self.file_affines:
            self._load_file(name)
        if name not in self.file_affines:
            raise RuntimeError("Error loading file affines for CMC.")

        return self.file_affines[name][int(num) - 1]

    def _affine_sift(self, frame, mask, tag):
        A = np.eye(2, 3)
        detector = cv2.SIFT_create()
        kp, desc = detector.detectAndCompute(frame, mask)
        if self.prev_desc is None:
            self.prev_desc = [kp, desc]
            return A
        if desc.shape[0] < self.minimum_features or self.prev_desc[1].shape[0] < self.minimum_features:
            return A

        bf = cv2.BFMatcher(cv2.NORM_L2)
        matches = bf.knnMatch(self.prev_desc[1], desc, k=2)
        good = []
        for m, n in matches:
            if m.distance < 0.7 * n.distance:
                good.append(m)

        if len(good) > self.minimum_features:
            src_pts = np.float32([self.prev_desc[0][m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
            dst_pts = np.float32([kp[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
            A, _ = cv2.estimateAffinePartial2D(src_pts, dst_pts, method=cv2.RANSAC)
        else:
            print("Warning: not enough matching points")
        if A is None:
            A = np.eye(2, 3)

        self.prev_desc = [kp, desc]
        return A

    def _affine_sparse_flow(self, frame, mask, tag):
        # Initialize
        A = np.eye(2, 3)

        # find the keypoints
        keypoints = cv2.goodFeaturesToTrack(frame, mask=mask, **self.sparse_flow_param)

        # Handle first frame
        if self.prev_img is None:
            self.prev_img = frame
            self.prev_desc = keypoints
            return A

        matched_kp, status, err = cv2.calcOpticalFlowPyrLK(self.prev_img, frame, self.prev_desc, None)
        matched_kp = matched_kp.reshape(-1, 2)
        # status flags are 0/1 uint8; cast to bool so they select rows instead of indexing by value
        status = status.reshape(-1).astype(bool)
        prev_points = self.prev_desc.reshape(-1, 2)
        prev_points = prev_points[status]
        curr_points = matched_kp[status]

        # Find rigid matrix
        if prev_points.shape[0] > self.minimum_features:
            A, _ = cv2.estimateAffinePartial2D(prev_points, curr_points, method=cv2.RANSAC)
        else:
            print("Warning: not enough matching points")
        if A is None:
            A = np.eye(2, 3)

        self.prev_img = frame
        self.prev_desc = keypoints
        return A

    def dump_cache(self):
        with open(self.cache_path, "wb") as fp:
            pickle.dump(self.cache, fp)
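
A brief sketch of how a 2x3 affine returned by `compute_affine` moves a box's two corner points; this mirrors what `apply_affine_correction` does in ocsort.py below (the numbers are made up):

import numpy as np

A = np.array([[1.0, 0.0, 5.0],
              [0.0, 1.0, -3.0]])          # pure translation: +5 px in x, -3 px in y
box = np.array([10.0, 20.0, 50.0, 60.0])  # x1, y1, x2, y2
pts = box.reshape(2, 2).T                 # columns are (x1, y1) and (x2, y2)
warped = A[:, :2] @ pts + A[:, 2:].reshape(2, 1)
print(warped.T.reshape(-1))               # [15. 17. 55. 57.]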
12 feeder/trackers/deepocsort/configs/deepocsort.yaml Normal file

@@ -0,0 +1,12 @@
# Trial number: 137
# HOTA, MOTA, IDF1: [55.567]
deepocsort:
  asso_func: giou
  conf_thres: 0.5122620708221085
  delta_t: 1
  det_thresh: 0
  inertia: 0.3941737016672115
  iou_thresh: 0.22136877277096445
  max_age: 50
  min_hits: 1
  use_byte: false
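
These values are consumed through the strongsort config helper in `create_tracker` further down; a minimal sketch of reading them directly with PyYAML instead:

import yaml

with open("feeder/trackers/deepocsort/configs/deepocsort.yaml") as f:
    cfg = yaml.safe_load(f)["deepocsort"]
print(cfg["asso_func"], cfg["max_age"])  # giou 50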
116 feeder/trackers/deepocsort/embedding.py Normal file

@@ -0,0 +1,116 @@
import pdb
from collections import OrderedDict
import os
import pickle

import torch
import cv2
import torchvision
import numpy as np


class EmbeddingComputer:
    def __init__(self, dataset):
        self.model = None
        self.dataset = dataset
        self.crop_size = (128, 384)
        os.makedirs("./cache/embeddings/", exist_ok=True)
        self.cache_path = "./cache/embeddings/{}_embedding.pkl"
        self.cache = {}
        self.cache_name = ""

    def load_cache(self, path):
        self.cache_name = path
        cache_path = self.cache_path.format(path)
        if os.path.exists(cache_path):
            with open(cache_path, "rb") as fp:
                self.cache = pickle.load(fp)

    def compute_embedding(self, img, bbox, tag, is_numpy=True):
        if self.cache_name != tag.split(":")[0]:
            self.load_cache(tag.split(":")[0])

        if tag in self.cache:
            embs = self.cache[tag]
            if embs.shape[0] != bbox.shape[0]:
                raise RuntimeError(
                    "ERROR: The number of cached embeddings doesn't match the "
                    "number of detections.\nWas the detector model changed? Delete cache if so."
                )
            return embs

        if self.model is None:
            self.initialize_model()

        # Make sure bbox is within image frame
        if is_numpy:
            h, w = img.shape[:2]
        else:
            h, w = img.shape[2:]
        results = np.round(bbox).astype(np.int32)
        results[:, 0] = results[:, 0].clip(0, w)
        results[:, 1] = results[:, 1].clip(0, h)
        results[:, 2] = results[:, 2].clip(0, w)
        results[:, 3] = results[:, 3].clip(0, h)

        # Generate all the crops
        crops = []
        for p in results:
            if is_numpy:
                crop = img[p[1] : p[3], p[0] : p[2]]
                crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
                crop = cv2.resize(crop, self.crop_size, interpolation=cv2.INTER_LINEAR)
                crop = torch.as_tensor(crop.astype("float32").transpose(2, 0, 1))
                crop = crop.unsqueeze(0)
            else:
                crop = img[:, :, p[1] : p[3], p[0] : p[2]]
                crop = torchvision.transforms.functional.resize(crop, self.crop_size)

            crops.append(crop)

        crops = torch.cat(crops, dim=0)

        # Create embeddings and l2 normalize them
        with torch.no_grad():
            crops = crops.cuda()
            crops = crops.half()
            embs = self.model(crops)
        embs = torch.nn.functional.normalize(embs)
        embs = embs.cpu().numpy()

        self.cache[tag] = embs
        return embs

    def initialize_model(self):
        """
        model = torchreid.models.build_model(name="osnet_ain_x1_0", num_classes=2510, loss="softmax", pretrained=False)
        sd = torch.load("external/weights/osnet_ain_ms_d_c.pth.tar")["state_dict"]
        new_state_dict = OrderedDict()
        for k, v in sd.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        # load params
        model.load_state_dict(new_state_dict)
        model.eval()
        model.cuda()
        """
        if self.dataset == "mot17":
            path = "external/weights/mot17_sbs_S50.pth"
        elif self.dataset == "mot20":
            path = "external/weights/mot20_sbs_S50.pth"
        elif self.dataset == "dance":
            path = None
        else:
            raise RuntimeError("Need the path for a new ReID model.")

        # NOTE: FastReID is not imported in this file; it is expected to come from the
        # upstream DeepOCSORT external FastReID adaptor, which is not vendored here.
        model = FastReID(path)
        model.eval()
        model.cuda()
        model.half()
        self.model = model

    def dump_cache(self):
        if self.cache_name:
            with open(self.cache_path.format(self.cache_name), "wb") as fp:
                pickle.dump(self.cache, fp)
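
Because the embeddings are L2-normalized, the appearance affinity used later (`dets_embs @ trk_embs.T`) is plain cosine similarity; a toy sketch:

import numpy as np

dets_embs = np.array([[1.0, 0.0], [0.0, 1.0]])  # 2 detections, unit vectors
trk_embs = np.array([[0.8, 0.6]])               # 1 track, unit vector
print(dets_embs @ trk_embs.T)                   # [[0.8], [0.6]]: cosine similarities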
1636 feeder/trackers/deepocsort/kalmanfilter.py Normal file
File diff suppressed because it is too large

670 feeder/trackers/deepocsort/ocsort.py Normal file

@@ -0,0 +1,670 @@
"""
|
||||
This script is adopted from the SORT script by Alex Bewley alex@bewley.ai
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import pdb
|
||||
import pickle
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
import numpy as np
|
||||
from .association import *
|
||||
from .embedding import EmbeddingComputer
|
||||
from .cmc import CMCComputer
|
||||
from reid_multibackend import ReIDDetectMultiBackend
|
||||
|
||||
|
||||
def k_previous_obs(observations, cur_age, k):
|
||||
if len(observations) == 0:
|
||||
return [-1, -1, -1, -1, -1]
|
||||
for i in range(k):
|
||||
dt = k - i
|
||||
if cur_age - dt in observations:
|
||||
return observations[cur_age - dt]
|
||||
max_age = max(observations.keys())
|
||||
return observations[max_age]
|
||||
|
||||
|
||||
def convert_bbox_to_z(bbox):
|
||||
"""
|
||||
Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
|
||||
[x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
|
||||
the aspect ratio
|
||||
"""
|
||||
w = bbox[2] - bbox[0]
|
||||
h = bbox[3] - bbox[1]
|
||||
x = bbox[0] + w / 2.0
|
||||
y = bbox[1] + h / 2.0
|
||||
s = w * h # scale is just area
|
||||
r = w / float(h + 1e-6)
|
||||
return np.array([x, y, s, r]).reshape((4, 1))
|
||||
|
||||
|
||||
def convert_bbox_to_z_new(bbox):
|
||||
w = bbox[2] - bbox[0]
|
||||
h = bbox[3] - bbox[1]
|
||||
x = bbox[0] + w / 2.0
|
||||
y = bbox[1] + h / 2.0
|
||||
return np.array([x, y, w, h]).reshape((4, 1))
|
||||
|
||||
|
||||
def convert_x_to_bbox_new(x):
|
||||
x, y, w, h = x.reshape(-1)[:4]
|
||||
return np.array([x - w / 2, y - h / 2, x + w / 2, y + h / 2]).reshape(1, 4)
|
||||
|
||||
|
||||
def convert_x_to_bbox(x, score=None):
|
||||
"""
|
||||
Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
|
||||
[x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
|
||||
"""
|
||||
w = np.sqrt(x[2] * x[3])
|
||||
h = x[2] / w
|
||||
if score == None:
|
||||
return np.array([x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0]).reshape((1, 4))
|
||||
else:
|
||||
return np.array([x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0, score]).reshape((1, 5))
|
||||
|
||||
|
||||
def speed_direction(bbox1, bbox2):
|
||||
cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
|
||||
cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
|
||||
speed = np.array([cy2 - cy1, cx2 - cx1])
|
||||
norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6
|
||||
return speed / norm
|
||||
|
||||
|
||||
def new_kf_process_noise(w, h, p=1 / 20, v=1 / 160):
|
||||
Q = np.diag(
|
||||
((p * w) ** 2, (p * h) ** 2, (p * w) ** 2, (p * h) ** 2, (v * w) ** 2, (v * h) ** 2, (v * w) ** 2, (v * h) ** 2)
|
||||
)
|
||||
return Q
|
||||
|
||||
|
||||
def new_kf_measurement_noise(w, h, m=1 / 20):
|
||||
w_var = (m * w) ** 2
|
||||
h_var = (m * h) ** 2
|
||||
R = np.diag((w_var, h_var, w_var, h_var))
|
||||
return R
|
||||
|
||||
|
||||
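
A quick round-trip sketch of the two state parameterizations above (the box values are arbitrary):

import numpy as np

box = np.array([10.0, 20.0, 50.0, 100.0])  # x1, y1, x2, y2
z = convert_bbox_to_z(box)                 # cx, cy, area, aspect: [30, 60, 3200, ~0.5]
print(convert_x_to_bbox(z).reshape(-1))    # back to ~[10, 20, 50, 100]
z_new = convert_bbox_to_z_new(box)         # cx, cy, w, h: [30, 60, 40, 80]
print(convert_x_to_bbox_new(z_new).reshape(-1))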


class KalmanBoxTracker(object):
    """
    This class represents the internal state of individual tracked objects observed as bbox.
    """

    count = 0

    def __init__(self, bbox, cls, delta_t=3, orig=False, emb=None, alpha=0, new_kf=False):
        """
        Initialises a tracker using an initial bounding box.
        """
        # define constant velocity model
        if not orig:
            from .kalmanfilter import KalmanFilterNew as KalmanFilter
        else:
            from filterpy.kalman import KalmanFilter
        self.cls = cls
        self.conf = bbox[-1]
        self.new_kf = new_kf
        if new_kf:
            self.kf = KalmanFilter(dim_x=8, dim_z=4)
            self.kf.F = np.array(
                [
                    # x y w h x' y' w' h'
                    [1, 0, 0, 0, 1, 0, 0, 0],
                    [0, 1, 0, 0, 0, 1, 0, 0],
                    [0, 0, 1, 0, 0, 0, 1, 0],
                    [0, 0, 0, 1, 0, 0, 0, 1],
                    [0, 0, 0, 0, 1, 0, 0, 0],
                    [0, 0, 0, 0, 0, 1, 0, 0],
                    [0, 0, 0, 0, 0, 0, 1, 0],
                    [0, 0, 0, 0, 0, 0, 0, 1],
                ]
            )
            self.kf.H = np.array(
                [
                    [1, 0, 0, 0, 0, 0, 0, 0],
                    [0, 1, 0, 0, 0, 0, 0, 0],
                    [0, 0, 1, 0, 0, 0, 0, 0],
                    [0, 0, 0, 1, 0, 0, 0, 0],
                ]
            )
            _, _, w, h = convert_bbox_to_z_new(bbox).reshape(-1)
            self.kf.P = new_kf_process_noise(w, h)
            self.kf.P[:4, :4] *= 4
            self.kf.P[4:, 4:] *= 100
            # Process and measurement uncertainty happen in functions
            self.bbox_to_z_func = convert_bbox_to_z_new
            self.x_to_bbox_func = convert_x_to_bbox_new
        else:
            self.kf = KalmanFilter(dim_x=7, dim_z=4)
            self.kf.F = np.array(
                [
                    # x y s r x' y' s'
                    [1, 0, 0, 0, 1, 0, 0],
                    [0, 1, 0, 0, 0, 1, 0],
                    [0, 0, 1, 0, 0, 0, 1],
                    [0, 0, 0, 1, 0, 0, 0],
                    [0, 0, 0, 0, 1, 0, 0],
                    [0, 0, 0, 0, 0, 1, 0],
                    [0, 0, 0, 0, 0, 0, 1],
                ]
            )
            self.kf.H = np.array(
                [
                    [1, 0, 0, 0, 0, 0, 0],
                    [0, 1, 0, 0, 0, 0, 0],
                    [0, 0, 1, 0, 0, 0, 0],
                    [0, 0, 0, 1, 0, 0, 0],
                ]
            )
            self.kf.R[2:, 2:] *= 10.0
            self.kf.P[4:, 4:] *= 1000.0  # give high uncertainty to the unobservable initial velocities
            self.kf.P *= 10.0
            self.kf.Q[-1, -1] *= 0.01
            self.kf.Q[4:, 4:] *= 0.01
            self.bbox_to_z_func = convert_bbox_to_z
            self.x_to_bbox_func = convert_x_to_bbox

        self.kf.x[:4] = self.bbox_to_z_func(bbox)

        self.time_since_update = 0
        self.id = KalmanBoxTracker.count
        KalmanBoxTracker.count += 1
        self.history = []
        self.hits = 0
        self.hit_streak = 0
        self.age = 0
        """
        NOTE: [-1,-1,-1,-1,-1] is a compromise placeholder for non-observation status, and the
        same is returned by k_previous_obs. It is ugly and I do not like it. But to support
        generating the observation array in a fast and unified way, as in
        k_observations = np.array([k_previous_obs(...)]), let's bear with it for now.
        """
        # Used for OCR
        self.last_observation = np.array([-1, -1, -1, -1, -1])  # placeholder
        # Used to output track after min_hits reached
        self.history_observations = []
        # Used for velocity
        self.observations = dict()
        self.velocity = None
        self.delta_t = delta_t

        self.emb = emb

        self.frozen = False

    def update(self, bbox, cls):
        """
        Updates the state vector with observed bbox.
        """
        if bbox is not None:
            self.frozen = False
            self.cls = cls
            if self.last_observation.sum() >= 0:  # a previous observation exists
                previous_box = None
                for dt in range(self.delta_t, 0, -1):
                    if self.age - dt in self.observations:
                        previous_box = self.observations[self.age - dt]
                        break
                if previous_box is None:
                    previous_box = self.last_observation
                """
                Estimate the track speed direction with observations \Delta t steps away
                """
                self.velocity = speed_direction(previous_box, bbox)
            """
            Insert new observations. This is an ugly way to maintain both self.observations
            and self.history_observations. Bear with it for the moment.
            """
            self.last_observation = bbox
            self.observations[self.age] = bbox
            self.history_observations.append(bbox)

            self.time_since_update = 0
            self.history = []
            self.hits += 1
            self.hit_streak += 1
            if self.new_kf:
                R = new_kf_measurement_noise(self.kf.x[2, 0], self.kf.x[3, 0])
                self.kf.update(self.bbox_to_z_func(bbox), R=R)
            else:
                self.kf.update(self.bbox_to_z_func(bbox))
        else:
            self.kf.update(bbox)
            self.frozen = True

    def update_emb(self, emb, alpha=0.9):
        self.emb = alpha * self.emb + (1 - alpha) * emb
        self.emb /= np.linalg.norm(self.emb)

    def get_emb(self):
        return self.emb.cpu()
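    # A hedged numeric sketch of update_emb's exponential moving average
    # (values illustrative): old = [1, 0], new = [0, 1], alpha = 0.9
    #   0.9 * old + 0.1 * new = [0.9, 0.1]; after renormalization ~[0.994, 0.110]
    # so the track embedding drifts only slowly toward new appearance evidence.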
    def apply_affine_correction(self, affine):
        m = affine[:, :2]
        t = affine[:, 2].reshape(2, 1)
        # For OCR
        if self.last_observation.sum() > 0:
            ps = self.last_observation[:4].reshape(2, 2).T
            ps = m @ ps + t
            self.last_observation[:4] = ps.T.reshape(-1)

        # Apply to each box in the range of velocity computation
        for dt in range(self.delta_t, -1, -1):
            if self.age - dt in self.observations:
                ps = self.observations[self.age - dt][:4].reshape(2, 2).T
                ps = m @ ps + t
                self.observations[self.age - dt][:4] = ps.T.reshape(-1)

        # Also need to change the kf state, but it might be frozen
        self.kf.apply_affine_correction(m, t, self.new_kf)

    def predict(self):
        """
        Advances the state vector and returns the predicted bounding box estimate.
        """
        # Don't allow negative bounding boxes
        if self.new_kf:
            if self.kf.x[2] + self.kf.x[6] <= 0:
                self.kf.x[6] = 0
            if self.kf.x[3] + self.kf.x[7] <= 0:
                self.kf.x[7] = 0

            # Stop velocity; it will be updated in the kf during OOS
            if self.frozen:
                self.kf.x[6] = self.kf.x[7] = 0
            Q = new_kf_process_noise(self.kf.x[2, 0], self.kf.x[3, 0])
        else:
            if (self.kf.x[6] + self.kf.x[2]) <= 0:
                self.kf.x[6] *= 0.0
            Q = None

        self.kf.predict(Q=Q)
        self.age += 1
        if self.time_since_update > 0:
            self.hit_streak = 0
        self.time_since_update += 1
        self.history.append(self.x_to_bbox_func(self.kf.x))
        return self.history[-1]

    def get_state(self):
        """
        Returns the current bounding box estimate.
        """
        return self.x_to_bbox_func(self.kf.x)

    def mahalanobis(self, bbox):
        """Should be run after a predict() call for accuracy."""
        return self.kf.md_for_measurement(self.bbox_to_z_func(bbox))


"""
We support multiple ways to calculate the association cost; by default
we use IoU. GIoU may perform better in some situations. Note that we do
not consistently normalize the cost to (0,1) across methods, which may
not be best practice.
"""
ASSO_FUNCS = {
    "iou": iou_batch,
    "giou": giou_batch,
    "ciou": ciou_batch,
    "diou": diou_batch,
    "ct_dist": ct_dist,
}


class OCSort(object):
    def __init__(
        self,
        model_weights,
        device,
        fp16,
        det_thresh,
        max_age=30,
        min_hits=3,
        iou_threshold=0.3,
        delta_t=3,
        asso_func="iou",
        inertia=0.2,
        w_association_emb=0.75,
        alpha_fixed_emb=0.95,
        aw_param=0.5,
        embedding_off=False,
        cmc_off=False,
        aw_off=False,
        new_kf_off=False,
        **kwargs
    ):
        """
        Sets key parameters for SORT
        """
        self.max_age = max_age
        self.min_hits = min_hits
        self.iou_threshold = iou_threshold
        self.trackers = []
        self.frame_count = 0
        self.det_thresh = det_thresh
        self.delta_t = delta_t
        self.asso_func = ASSO_FUNCS[asso_func]
        self.inertia = inertia
        self.w_association_emb = w_association_emb
        self.alpha_fixed_emb = alpha_fixed_emb
        self.aw_param = aw_param
        KalmanBoxTracker.count = 0

        self.embedder = ReIDDetectMultiBackend(weights=model_weights, device=device, fp16=fp16)
        self.cmc = CMCComputer()
        self.embedding_off = embedding_off
        self.cmc_off = cmc_off
        self.aw_off = aw_off
        self.new_kf_off = new_kf_off
    def update(self, dets, img_numpy, tag='blub'):
        """
        Params:
          dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
        Requires: this method must be called once for each frame even with empty detections (use np.empty((0, 5)) for frames without detections).
        Returns a similar array, where the last column is the object ID.
        NOTE: The number of objects returned may differ from the number of detections provided.
        """
        # NOTE: without this increment, the min_hits warm-up check below never advances
        self.frame_count += 1
        xyxys = dets[:, 0:4]
        scores = dets[:, 4]
        clss = dets[:, 5]

        classes = clss.numpy()
        xyxys = xyxys.numpy()
        scores = scores.numpy()

        dets = dets[:, 0:6].numpy()
        remain_inds = scores > self.det_thresh
        dets = dets[remain_inds]
        self.height, self.width = img_numpy.shape[:2]

        # Rescale
        # scale = min(img_tensor.shape[2] / img_numpy.shape[0], img_tensor.shape[3] / img_numpy.shape[1])
        # dets[:, :4] /= scale

        # Embedding
        if self.embedding_off or dets.shape[0] == 0:
            dets_embs = np.ones((dets.shape[0], 1))
        else:
            # (Ndets x X) [512, 1024, 2048]
            # dets_embs = self.embedder.compute_embedding(img_numpy, dets[:, :4], tag)
            dets_embs = self._get_features(dets[:, :4], img_numpy)

        # CMC
        if not self.cmc_off:
            transform = self.cmc.compute_affine(img_numpy, dets[:, :4], tag)
            for trk in self.trackers:
                trk.apply_affine_correction(transform)

        trust = (dets[:, 4] - self.det_thresh) / (1 - self.det_thresh)
        af = self.alpha_fixed_emb
        # From [self.alpha_fixed_emb, 1]; goes to 1 as the detector becomes less confident
        dets_alpha = af + (1 - af) * (1 - trust)

        # get predicted locations from existing trackers.
        trks = np.zeros((len(self.trackers), 5))
        trk_embs = []
        to_del = []
        ret = []
        for t, trk in enumerate(trks):
            pos = self.trackers[t].predict()[0]
            trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
            if np.any(np.isnan(pos)):
                to_del.append(t)
            else:
                trk_embs.append(self.trackers[t].get_emb())
        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))

        if len(trk_embs) > 0:
            trk_embs = np.vstack(trk_embs)
        else:
            trk_embs = np.array(trk_embs)

        for t in reversed(to_del):
            self.trackers.pop(t)

        velocities = np.array([trk.velocity if trk.velocity is not None else np.array((0, 0)) for trk in self.trackers])
        last_boxes = np.array([trk.last_observation for trk in self.trackers])
        k_observations = np.array([k_previous_obs(trk.observations, trk.age, self.delta_t) for trk in self.trackers])

        """
        First round of association
        """
        # (M detections X N tracks, final score)
        if self.embedding_off or dets.shape[0] == 0 or trk_embs.shape[0] == 0:
            stage1_emb_cost = None
        else:
            stage1_emb_cost = dets_embs @ trk_embs.T
        matched, unmatched_dets, unmatched_trks = associate(
            dets,
            trks,
            self.iou_threshold,
            velocities,
            k_observations,
            self.inertia,
            stage1_emb_cost,
            self.w_association_emb,
            self.aw_off,
            self.aw_param,
        )
        for m in matched:
            self.trackers[m[1]].update(dets[m[0], :5], dets[m[0], 5])
            self.trackers[m[1]].update_emb(dets_embs[m[0]], alpha=dets_alpha[m[0]])

        """
        Second round of association by OCR
        """
        if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
            left_dets = dets[unmatched_dets]
            left_dets_embs = dets_embs[unmatched_dets]
            left_trks = last_boxes[unmatched_trks]
            left_trks_embs = trk_embs[unmatched_trks]

            iou_left = self.asso_func(left_dets, left_trks)
            # TODO: it may be better without this
            emb_cost_left = left_dets_embs @ left_trks_embs.T
            if self.embedding_off:
                emb_cost_left = np.zeros_like(emb_cost_left)
            iou_left = np.array(iou_left)
            if iou_left.max() > self.iou_threshold:
                """
                NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
                get higher performance, especially on the MOT17/MOT20 datasets. But we keep
                it uniform here for simplicity
                """
                rematched_indices = linear_assignment(-iou_left)
                to_remove_det_indices = []
                to_remove_trk_indices = []
                for m in rematched_indices:
                    det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
                    if iou_left[m[0], m[1]] < self.iou_threshold:
                        continue
                    self.trackers[trk_ind].update(dets[det_ind, :5], dets[det_ind, 5])
                    self.trackers[trk_ind].update_emb(dets_embs[det_ind], alpha=dets_alpha[det_ind])
                    to_remove_det_indices.append(det_ind)
                    to_remove_trk_indices.append(trk_ind)
                unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices))
                unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))

        for m in unmatched_trks:
            self.trackers[m].update(None, None)

        # create and initialise new trackers for unmatched detections
        for i in unmatched_dets:
            trk = KalmanBoxTracker(
                dets[i, :5], dets[i, 5], delta_t=self.delta_t, emb=dets_embs[i], alpha=dets_alpha[i], new_kf=not self.new_kf_off
            )
            self.trackers.append(trk)
        i = len(self.trackers)
        for trk in reversed(self.trackers):
            if trk.last_observation.sum() < 0:
                d = trk.get_state()[0]
            else:
                """
                Using the recent observation or the Kalman filter prediction here is optional;
                we didn't notice a significant difference
                """
                d = trk.last_observation[:4]
            if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
                # +1 as MOT benchmark requires positive
                ret.append(np.concatenate((d, [trk.id + 1], [trk.cls], [trk.conf])).reshape(1, -1))
            i -= 1
            # remove dead tracklet
            if trk.time_since_update > self.max_age:
                self.trackers.pop(i)
        if len(ret) > 0:
            return np.concatenate(ret)
        return np.empty((0, 7))  # match the 7-column layout of the returned rows
    def _xywh_to_xyxy(self, bbox_xywh):
        x, y, w, h = bbox_xywh
        x1 = max(int(x - w / 2), 0)
        x2 = min(int(x + w / 2), self.width - 1)
        y1 = max(int(y - h / 2), 0)
        y2 = min(int(y + h / 2), self.height - 1)
        return x1, y1, x2, y2

    def _get_features(self, bbox_xyxy, ori_img):
        # NOTE: update() passes boxes in xyxy form, so clip and crop directly;
        # running xyxy boxes through _xywh_to_xyxy would mangle the crops.
        im_crops = []
        for box in bbox_xyxy:
            x1 = max(int(box[0]), 0)
            y1 = max(int(box[1]), 0)
            x2 = min(int(box[2]), self.width - 1)
            y2 = min(int(box[3]), self.height - 1)
            im = ori_img[y1:y2, x1:x2]
            im_crops.append(im)
        if im_crops:
            features = self.embedder(im_crops).cpu()
        else:
            features = np.array([])

        return features
    def update_public(self, dets, cates, scores):
        self.frame_count += 1

        det_scores = np.ones((dets.shape[0], 1))
        dets = np.concatenate((dets, det_scores), axis=1)

        remain_inds = scores > self.det_thresh

        cates = cates[remain_inds]
        dets = dets[remain_inds]

        trks = np.zeros((len(self.trackers), 5))
        to_del = []
        ret = []
        for t, trk in enumerate(trks):
            pos = self.trackers[t].predict()[0]
            cat = self.trackers[t].cate
            trk[:] = [pos[0], pos[1], pos[2], pos[3], cat]
            if np.any(np.isnan(pos)):
                to_del.append(t)
        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
        for t in reversed(to_del):
            self.trackers.pop(t)

        velocities = np.array([trk.velocity if trk.velocity is not None else np.array((0, 0)) for trk in self.trackers])
        last_boxes = np.array([trk.last_observation for trk in self.trackers])
        k_observations = np.array([k_previous_obs(trk.observations, trk.age, self.delta_t) for trk in self.trackers])

        matched, unmatched_dets, unmatched_trks = associate_kitti(
            dets,
            trks,
            cates,
            self.iou_threshold,
            velocities,
            k_observations,
            self.inertia,
        )

        for m in matched:
            # KalmanBoxTracker.update takes (bbox, cls) in this fork, so pass the category through
            self.trackers[m[1]].update(dets[m[0], :], cates[m[0]])

        if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
            """
            The re-association stage by OCR.
            NOTE: at this stage, adding other strategies might continue to improve
            performance, such as BYTE association from ByteTrack.
            """
            left_dets = dets[unmatched_dets]
            left_trks = last_boxes[unmatched_trks]
            left_dets_c = left_dets.copy()
            left_trks_c = left_trks.copy()

            iou_left = self.asso_func(left_dets_c, left_trks_c)
            iou_left = np.array(iou_left)
            det_cates_left = cates[unmatched_dets]
            trk_cates_left = trks[unmatched_trks][:, 4]
            num_dets = unmatched_dets.shape[0]
            num_trks = unmatched_trks.shape[0]
            cate_matrix = np.zeros((num_dets, num_trks))
            for i in range(num_dets):
                for j in range(num_trks):
                    if det_cates_left[i] != trk_cates_left[j]:
                        """
                        For some datasets, such as KITTI, there are different categories,
                        and we have to avoid associating them with each other.
                        """
                        cate_matrix[i][j] = -1e6
            iou_left = iou_left + cate_matrix
            if iou_left.max() > self.iou_threshold - 0.1:
                rematched_indices = linear_assignment(-iou_left)
                to_remove_det_indices = []
                to_remove_trk_indices = []
                for m in rematched_indices:
                    det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
                    if iou_left[m[0], m[1]] < self.iou_threshold - 0.1:
                        continue
                    self.trackers[trk_ind].update(dets[det_ind, :], cates[det_ind])
                    to_remove_det_indices.append(det_ind)
                    to_remove_trk_indices.append(trk_ind)
                unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices))
                unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))

        for i in unmatched_dets:
            # the category is required by the two-argument KalmanBoxTracker constructor
            trk = KalmanBoxTracker(dets[i, :], cates[i])
            trk.cate = cates[i]
            self.trackers.append(trk)
        i = len(self.trackers)

        for trk in reversed(self.trackers):
            if trk.last_observation.sum() > 0:
                d = trk.last_observation[:4]
            else:
                d = trk.get_state()[0]
            if trk.time_since_update < 1:
                if (self.frame_count <= self.min_hits) or (trk.hit_streak >= self.min_hits):
                    # id+1 as MOT benchmark requires positive
                    ret.append(np.concatenate((d, [trk.id + 1], [trk.cls], [trk.conf])).reshape(1, -1))
                if trk.hit_streak == self.min_hits:
                    # Head Padding (HP): recover the lost steps during track initialization
                    for prev_i in range(self.min_hits - 1):
                        prev_observation = trk.history_observations[-(prev_i + 2)]
                        ret.append(
                            (
                                np.concatenate(
                                    (
                                        prev_observation[:4],
                                        [trk.id + 1],
                                        [trk.cls],
                                        [trk.conf],
                                    )
                                )
                            ).reshape(1, -1)
                        )
            i -= 1
            if trk.time_since_update > self.max_age:
                self.trackers.pop(i)

        if len(ret) > 0:
            return np.concatenate(ret)
        return np.empty((0, 7))

    def dump_cache(self):
        self.cmc.dump_cache()
        self.embedder.dump_cache()
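
A hedged usage sketch of the per-frame loop for this tracker; the video path and `run_detector` are placeholders, and `dets` must be a torch tensor with columns [x1, y1, x2, y2, score, class]:

import cv2
import torch

tracker = OCSort(model_weights, device, fp16=False, det_thresh=0.5)  # weights/device as configured
cap = cv2.VideoCapture("video.mp4")
while True:
    ok, frame = cap.read()
    if not ok:
        break
    dets = run_detector(frame)             # hypothetical detector returning an (N, 6) torch tensor
    outputs = tracker.update(dets, frame)  # rows: x1, y1, x2, y2, track_id, class, conf
tracker.dump_cache()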
237 feeder/trackers/deepocsort/reid_multibackend.py Normal file

@@ -0,0 +1,237 @@
import torch.nn as nn
import torch
from pathlib import Path
import numpy as np
from itertools import islice
import cv2
import sys
import torchvision.transforms as T
from collections import OrderedDict, namedtuple
import gdown
from os.path import exists as file_exists


from yolov8.ultralytics.yolo.utils.checks import check_requirements, check_version
from yolov8.ultralytics.yolo.utils import LOGGER
from trackers.strongsort.deep.reid_model_factory import (show_downloadeable_models, get_model_url, get_model_name,
                                                         download_url, load_pretrained_weights)
from trackers.strongsort.deep.models import build_model


def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
    # Check file(s) for acceptable suffix
    if file and suffix:
        if isinstance(suffix, str):
            suffix = [suffix]
        for f in file if isinstance(file, (list, tuple)) else [file]:
            s = Path(f).suffix.lower()  # file suffix
            if len(s):
                assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"


class ReIDDetectMultiBackend(nn.Module):
    # ReID models MultiBackend class for python inference on various backends
    def __init__(self, weights='osnet_x0_25_msmt17.pt', device=torch.device('cpu'), fp16=False):
        super().__init__()

        w = weights[0] if isinstance(weights, list) else weights
        self.pt, self.jit, self.onnx, self.xml, self.engine, self.tflite = self.model_type(w)  # get backend
        self.fp16 = fp16
        self.fp16 &= self.pt or self.jit or self.engine  # FP16

        # Build transform functions
        self.device = device
        self.image_size = (256, 128)
        self.pixel_mean = [0.485, 0.456, 0.406]
        self.pixel_std = [0.229, 0.224, 0.225]
        self.transforms = []
        self.transforms += [T.Resize(self.image_size)]
        self.transforms += [T.ToTensor()]
        self.transforms += [T.Normalize(mean=self.pixel_mean, std=self.pixel_std)]
        self.preprocess = T.Compose(self.transforms)
        self.to_pil = T.ToPILImage()

        model_name = get_model_name(w)

        if w.suffix == '.pt':
            model_url = get_model_url(w)
            if not file_exists(w) and model_url is not None:
                gdown.download(model_url, str(w), quiet=False)
            elif file_exists(w):
                pass
            else:
                print(f'No URL associated to the chosen StrongSORT weights ({w}). Choose between:')
                show_downloadeable_models()
                exit()

        # Build model
        self.model = build_model(
            model_name,
            num_classes=1,
            pretrained=not (w and w.is_file()),
            use_gpu=device
        )

        if self.pt:  # PyTorch
            # populate model arch with weights
            if w and w.is_file() and w.suffix == '.pt':
                load_pretrained_weights(self.model, w)

            self.model.to(device).eval()
            self.model.half() if self.fp16 else self.model.float()
        elif self.jit:
            LOGGER.info(f'Loading {w} for TorchScript inference...')
            self.model = torch.jit.load(w)
            self.model.half() if self.fp16 else self.model.float()
        elif self.onnx:  # ONNX Runtime
            LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
            cuda = torch.cuda.is_available() and device.type != 'cpu'
            # check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
            import onnxruntime
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
            self.session = onnxruntime.InferenceSession(str(w), providers=providers)
        elif self.engine:  # TensorRT
            LOGGER.info(f'Loading {w} for TensorRT inference...')
            import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
            check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
            if device.type == 'cpu':
                device = torch.device('cuda:0')
            Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
            logger = trt.Logger(trt.Logger.INFO)
            with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                self.model_ = runtime.deserialize_cuda_engine(f.read())
            self.context = self.model_.create_execution_context()
            self.bindings = OrderedDict()
            self.fp16 = False  # default updated below
            dynamic = False
            for index in range(self.model_.num_bindings):
                name = self.model_.get_binding_name(index)
                dtype = trt.nptype(self.model_.get_binding_dtype(index))
                if self.model_.binding_is_input(index):
                    if -1 in tuple(self.model_.get_binding_shape(index)):  # dynamic
                        dynamic = True
                        self.context.set_binding_shape(index, tuple(self.model_.get_profile_shape(0, index)[2]))
                    if dtype == np.float16:
                        self.fp16 = True
                shape = tuple(self.context.get_binding_shape(index))
                im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
                self.bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
            self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items())
            # remember whether the engine has dynamic input shapes; forward() checks this flag
            self.dynamic = dynamic
            batch_size = self.bindings['images'].shape[0]  # if dynamic, this is instead max batch size
        elif self.xml:  # OpenVINO
            LOGGER.info(f'Loading {w} for OpenVINO inference...')
            check_requirements(('openvino',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
            from openvino.runtime import Core, Layout, get_batch
            ie = Core()
            if not Path(w).is_file():  # if not *.xml
                w = next(Path(w).glob('*.xml'))  # get *.xml file from *_openvino_model dir
            network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
            if network.get_parameters()[0].get_layout().empty:
                network.get_parameters()[0].set_layout(Layout("NCWH"))
            batch_dim = get_batch(network)
            if batch_dim.is_static:
                batch_size = batch_dim.get_length()
            self.executable_network = ie.compile_model(network, device_name="CPU")  # device_name="MYRIAD" for Intel NCS2
            self.output_layer = next(iter(self.executable_network.outputs))

        elif self.tflite:
            LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
            try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                from tflite_runtime.interpreter import Interpreter, load_delegate
            except ImportError:
                import tensorflow as tf
                Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
            # use whichever Interpreter was imported above (tflite_runtime or tensorflow)
            self.interpreter = Interpreter(model_path=w)
            self.interpreter.allocate_tensors()
            # Get input and output tensors.
            self.input_details = self.interpreter.get_input_details()
            self.output_details = self.interpreter.get_output_details()

            # Test model on random input data.
            input_data = np.array(np.random.random_sample((1, 256, 128, 3)), dtype=np.float32)
            self.interpreter.set_tensor(self.input_details[0]['index'], input_data)

            self.interpreter.invoke()

            # The function `get_tensor()` returns a copy of the tensor data.
            output_data = self.interpreter.get_tensor(self.output_details[0]['index'])
        else:
            print('This model framework is not supported yet!')
            exit()

    @staticmethod
    def model_type(p='path/to/model.pt'):
        # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
        from trackers.reid_export import export_formats
        sf = list(export_formats().Suffix)  # export suffixes
        check_suffix(p, sf)  # checks
        types = [s in Path(p).name for s in sf]
        return types

    def _preprocess(self, im_batch):

        images = []
        for element in im_batch:
            image = self.to_pil(element)
            image = self.preprocess(image)
            images.append(image)

        images = torch.stack(images, dim=0)
        images = images.to(self.device)

        return images

    def forward(self, im_batch):

        # preprocess batch
        im_batch = self._preprocess(im_batch)

        # batch to half
        if self.fp16 and im_batch.dtype != torch.float16:
            im_batch = im_batch.half()

        # batch processing
        features = []
        if self.pt:
            features = self.model(im_batch)
        elif self.jit:  # TorchScript
            features = self.model(im_batch)
        elif self.onnx:  # ONNX Runtime
            im_batch = im_batch.cpu().numpy()  # torch to numpy
            features = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im_batch})[0]
        elif self.engine:  # TensorRT
            if self.dynamic and im_batch.shape != self.bindings['images'].shape:
                i_in, i_out = (self.model_.get_binding_index(x) for x in ('images', 'output'))
                self.context.set_binding_shape(i_in, im_batch.shape)  # reshape if dynamic
                self.bindings['images'] = self.bindings['images']._replace(shape=im_batch.shape)
                self.bindings['output'].data.resize_(tuple(self.context.get_binding_shape(i_out)))
            s = self.bindings['images'].shape
            assert im_batch.shape == s, f"input size {im_batch.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
            self.binding_addrs['images'] = int(im_batch.data_ptr())
            self.context.execute_v2(list(self.binding_addrs.values()))
            features = self.bindings['output'].data
        elif self.xml:  # OpenVINO
            im_batch = im_batch.cpu().numpy()  # FP32
            features = self.executable_network([im_batch])[self.output_layer]
        else:
            print('Framework not supported at the moment, we are working on it...')
            exit()

        if isinstance(features, (list, tuple)):
            return self.from_numpy(features[0]) if len(features) == 1 else [self.from_numpy(x) for x in features]
        else:
            return self.from_numpy(features)

    def from_numpy(self, x):
        return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x

    def warmup(self, imgsz=[(256, 128, 3)]):
        # Warmup model by running inference once
        warmup_types = self.pt, self.jit, self.onnx, self.engine, self.tflite
        if any(warmup_types) and self.device.type != 'cpu':
            im = [np.empty(*imgsz).astype(np.uint8)]  # input
            for _ in range(2 if self.jit else 1):
                self.forward(im)  # warmup
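
A hedged sketch of calling this multi-backend wrapper on image crops (the weights path is a placeholder, and `crop1`/`crop2` stand for HxWx3 uint8 BGR arrays):

from pathlib import Path
import torch

model = ReIDDetectMultiBackend(weights=Path("osnet_x0_25_msmt17.pt"), device=torch.device("cpu"))
feats = model([crop1, crop2])                 # list of crops -> (2, D) feature tensor
feats = torch.nn.functional.normalize(feats)  # cosine-ready embeddings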
84 feeder/trackers/multi_tracker_zoo.py Normal file

@@ -0,0 +1,84 @@
from trackers.strongsort.utils.parser import get_config

def create_tracker(tracker_type, tracker_config, reid_weights, device, half):

    cfg = get_config()
    cfg.merge_from_file(tracker_config)

    if tracker_type == 'strongsort':
        from trackers.strongsort.strong_sort import StrongSORT
        strongsort = StrongSORT(
            reid_weights,
            device,
            half,
            max_dist=cfg.strongsort.max_dist,
            max_iou_dist=cfg.strongsort.max_iou_dist,
            max_age=cfg.strongsort.max_age,
            max_unmatched_preds=cfg.strongsort.max_unmatched_preds,
            n_init=cfg.strongsort.n_init,
            nn_budget=cfg.strongsort.nn_budget,
            mc_lambda=cfg.strongsort.mc_lambda,
            ema_alpha=cfg.strongsort.ema_alpha,
        )
        return strongsort

    elif tracker_type == 'ocsort':
        from trackers.ocsort.ocsort import OCSort
        ocsort = OCSort(
            det_thresh=cfg.ocsort.det_thresh,
            max_age=cfg.ocsort.max_age,
            min_hits=cfg.ocsort.min_hits,
            iou_threshold=cfg.ocsort.iou_thresh,
            delta_t=cfg.ocsort.delta_t,
            asso_func=cfg.ocsort.asso_func,
            inertia=cfg.ocsort.inertia,
            use_byte=cfg.ocsort.use_byte,
        )
        return ocsort

    elif tracker_type == 'bytetrack':
        from trackers.bytetrack.byte_tracker import BYTETracker
        bytetracker = BYTETracker(
            track_thresh=cfg.bytetrack.track_thresh,
            match_thresh=cfg.bytetrack.match_thresh,
            track_buffer=cfg.bytetrack.track_buffer,
            frame_rate=cfg.bytetrack.frame_rate
        )
        return bytetracker

    elif tracker_type == 'botsort':
        from trackers.botsort.bot_sort import BoTSORT
        botsort = BoTSORT(
            reid_weights,
            device,
            half,
            track_high_thresh=cfg.botsort.track_high_thresh,
            new_track_thresh=cfg.botsort.new_track_thresh,
            track_buffer=cfg.botsort.track_buffer,
            match_thresh=cfg.botsort.match_thresh,
            proximity_thresh=cfg.botsort.proximity_thresh,
            appearance_thresh=cfg.botsort.appearance_thresh,
            cmc_method=cfg.botsort.cmc_method,
            frame_rate=cfg.botsort.frame_rate,
            lambda_=cfg.botsort.lambda_
        )
        return botsort
    elif tracker_type == 'deepocsort':
        from trackers.deepocsort.ocsort import OCSort
        deepocsort = OCSort(
            reid_weights,
            device,
            half,
            det_thresh=cfg.deepocsort.det_thresh,
            max_age=cfg.deepocsort.max_age,
            min_hits=cfg.deepocsort.min_hits,
            iou_threshold=cfg.deepocsort.iou_thresh,
            delta_t=cfg.deepocsort.delta_t,
            asso_func=cfg.deepocsort.asso_func,
            inertia=cfg.deepocsort.inertia,
        )
        return deepocsort
    else:
        print('No such tracker')
        exit()
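
A hedged usage sketch of the factory above (the config path and ReID weights are placeholders):

from pathlib import Path
import torch

tracker = create_tracker(
    'deepocsort',
    'feeder/trackers/deepocsort/configs/deepocsort.yaml',
    reid_weights=Path('osnet_x0_25_msmt17.pt'),
    device=torch.device('cuda:0'),
    half=False,
)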
377 feeder/trackers/ocsort/association.py Normal file

@@ -0,0 +1,377 @@
import os
import numpy as np


def iou_batch(bboxes1, bboxes2):
    """
    From SORT: computes IoU between two sets of bboxes in the form [x1,y1,x2,y2]
    """
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
    yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
    xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
    yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
    w = np.maximum(0., xx2 - xx1)
    h = np.maximum(0., yy2 - yy1)
    wh = w * h
    o = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
              + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh)
    return o


def giou_batch(bboxes1, bboxes2):
    """
    :param bboxes1: predicted bboxes, shape (N,4) in (x1,y1,x2,y2) form
    :param bboxes2: ground-truth bboxes, shape (N,4) in (x1,y1,x2,y2) form
    :return: GIoU rescaled from (-1,1) to (0,1)
    """
    # for details see https://arxiv.org/pdf/1902.09630.pdf
    # ensure the predicted bbox form
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
    yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
    xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
    yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
    w = np.maximum(0., xx2 - xx1)
    h = np.maximum(0., yy2 - yy1)
    wh = w * h
    iou = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
                + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh)

    xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
    yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
    xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
    yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
    wc = xxc2 - xxc1
    hc = yyc2 - yyc1
    assert (wc > 0).all() and (hc > 0).all()
    area_enclose = wc * hc
    giou = iou - (area_enclose - wh) / area_enclose
    giou = (giou + 1.) / 2.0  # resize from (-1,1) to (0,1)
    return giou


def diou_batch(bboxes1, bboxes2):
    """
    :param bboxes1: predicted bboxes, shape (N,4) in (x1,y1,x2,y2) form
    :param bboxes2: ground-truth bboxes, shape (N,4) in (x1,y1,x2,y2) form
    :return: DIoU rescaled from (-1,1) to (0,1)
    """
    # for details see https://arxiv.org/pdf/1902.09630.pdf
    # ensure the predicted bbox form
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    # calculate the intersection box
    xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
    yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
    xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
    yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
    w = np.maximum(0., xx2 - xx1)
    h = np.maximum(0., yy2 - yy1)
    wh = w * h
    iou = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
                + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh)

    centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
    centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
    centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
    centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0

    inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2

    xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
    yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
    xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
    yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])

    outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2
    diou = iou - inner_diag / outer_diag

    return (diou + 1) / 2.0  # resize from (-1,1) to (0,1)


def ciou_batch(bboxes1, bboxes2):
    """
    :param bboxes1: predicted bboxes, shape (N,4) in (x1,y1,x2,y2) form
    :param bboxes2: ground-truth bboxes, shape (N,4) in (x1,y1,x2,y2) form
    :return: CIoU rescaled from (-1,1) to (0,1)
    """
    # for details see https://arxiv.org/pdf/1902.09630.pdf
    # ensure the predicted bbox form
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    # calculate the intersection box
    xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
    yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
    xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
    yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
    w = np.maximum(0., xx2 - xx1)
    h = np.maximum(0., yy2 - yy1)
    wh = w * h
    iou = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
                + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh)

    centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
    centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
    centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
    centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0

    inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2

    xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
    yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
    xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
    yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])

    outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2

    w1 = bboxes1[..., 2] - bboxes1[..., 0]
    h1 = bboxes1[..., 3] - bboxes1[..., 1]
    w2 = bboxes2[..., 2] - bboxes2[..., 0]
    h2 = bboxes2[..., 3] - bboxes2[..., 1]

    # prevent division by zero: add a one-pixel shift
    h2 = h2 + 1.
    h1 = h1 + 1.
    arctan = np.arctan(w2 / h2) - np.arctan(w1 / h1)
    v = (4 / (np.pi ** 2)) * (arctan ** 2)
    S = 1 - iou
    alpha = v / (S + v)
    ciou = iou - inner_diag / outer_diag - alpha * v

    return (ciou + 1) / 2.0  # resize from (-1,1) to (0,1)


def ct_dist(bboxes1, bboxes2):
    """
    Measure the center distance between two sets of bounding boxes.
    This is a coarse implementation; we do not recommend using it alone
    for association, as it can be unstable and sensitive to frame rate
    and object speed.
    """
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
    centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
    centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
    centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0

    ct_dist2 = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2

    ct_dist = np.sqrt(ct_dist2)

    # The linear rescaling is a naive version and needs more study
    ct_dist = ct_dist / ct_dist.max()
    return ct_dist.max() - ct_dist  # resize to (0,1)


def speed_direction_batch(dets, tracks):
    tracks = tracks[..., np.newaxis]
    CX1, CY1 = (dets[:, 0] + dets[:, 2]) / 2.0, (dets[:, 1] + dets[:, 3]) / 2.0
    CX2, CY2 = (tracks[:, 0] + tracks[:, 2]) / 2.0, (tracks[:, 1] + tracks[:, 3]) / 2.0
    dx = CX1 - CX2
    dy = CY1 - CY2
    norm = np.sqrt(dx**2 + dy**2) + 1e-6
    dx = dx / norm
    dy = dy / norm
    return dy, dx  # size: num_track x num_det


def linear_assignment(cost_matrix):
    try:
        import lap
        _, x, y = lap.lapjv(cost_matrix, extend_cost=True)
        return np.array([[y[i], i] for i in x if i >= 0])
    except ImportError:
        from scipy.optimize import linear_sum_assignment
        x, y = linear_sum_assignment(cost_matrix)
        return np.array(list(zip(x, y)))


def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
    """
    Assigns detections to tracked objects (both represented as bounding boxes).
    Returns 3 lists: matches, unmatched_detections and unmatched_trackers.
    """
    if len(trackers) == 0:
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    iou_matrix = iou_batch(detections, trackers)

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(-iou_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = []
    for d, det in enumerate(detections):
        if d not in matched_indices[:, 0]:
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if t not in matched_indices[:, 1]:
            unmatched_trackers.append(t)

    # filter out matches with low IoU
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if len(matches) == 0:
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)


def associate(detections, trackers, iou_threshold, velocities, previous_obs, vdc_weight):
    if len(trackers) == 0:
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    Y, X = speed_direction_batch(detections, previous_obs)
    inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
    inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
    inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
    diff_angle_cos = inertia_X * X + inertia_Y * Y
    diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
    diff_angle = np.arccos(diff_angle_cos)
    diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi

    valid_mask = np.ones(previous_obs.shape[0])
    valid_mask[np.where(previous_obs[:, 4] < 0)] = 0

    iou_matrix = iou_batch(detections, trackers)
    scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
    # iou_matrix = iou_matrix * scores  # a trick that sometimes works; we don't encourage it
    valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)

    angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
    angle_diff_cost = angle_diff_cost.T
    angle_diff_cost = angle_diff_cost * scores

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(-(iou_matrix + angle_diff_cost))
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = []
    for d, det in enumerate(detections):
        if d not in matched_indices[:, 0]:
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if t not in matched_indices[:, 1]:
            unmatched_trackers.append(t)

    # filter out matches with low IoU
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if len(matches) == 0:
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)


def associate_kitti(detections, trackers, det_cates, iou_threshold,
                    velocities, previous_obs, vdc_weight):
    if len(trackers) == 0:
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    """
    Cost from the velocity direction consistency
    """
    Y, X = speed_direction_batch(detections, previous_obs)
    inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
    inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
    inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
    diff_angle_cos = inertia_X * X + inertia_Y * Y
    diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
    diff_angle = np.arccos(diff_angle_cos)
    diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi

    valid_mask = np.ones(previous_obs.shape[0])
    valid_mask[np.where(previous_obs[:, 4] < 0)] = 0
    valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)

    scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
    angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
    angle_diff_cost = angle_diff_cost.T
    angle_diff_cost = angle_diff_cost * scores

    """
    Cost from IoU
    """
    iou_matrix = iou_batch(detections, trackers)

    """
    With multiple categories, generate the cost for category mismatch
    """
    num_dets = detections.shape[0]
    num_trk = trackers.shape[0]
    cate_matrix = np.zeros((num_dets, num_trk))
    for i in range(num_dets):
        for j in range(num_trk):
            if det_cates[i] != trackers[j, 4]:
                cate_matrix[i][j] = -1e6

    cost_matrix = -iou_matrix - angle_diff_cost - cate_matrix

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(cost_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = []
    for d, det in enumerate(detections):
        if d not in matched_indices[:, 0]:
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if t not in matched_indices[:, 1]:
            unmatched_trackers.append(t)

    # filter out matches with low IoU
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if len(matches) == 0:
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
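As a quick sanity check of the box convention these cost functions use, a small worked example (values computed by hand):

    import numpy as np

    dets = np.array([[0., 0., 10., 10.]])   # one detection
    trks = np.array([[5., 5., 15., 15.]])   # one track
    # intersection 5*5 = 25, union 100 + 100 - 25 = 175, so IoU = 25/175 ≈ 0.143
    print(iou_batch(dets, trks))            # [[0.14285714]]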
12 feeder/trackers/ocsort/configs/ocsort.yaml Normal file
@@ -0,0 +1,12 @@
# Trial number: 137
# HOTA, MOTA, IDF1: [55.567]
ocsort:
  asso_func: giou
  conf_thres: 0.5122620708221085
  delta_t: 1
  det_thresh: 0
  inertia: 0.3941737016672115
  iou_thresh: 0.22136877277096445
  max_age: 50
  min_hits: 1
  use_byte: false
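A hedged sketch of wiring these tuned hyperparameters into the tracker; the YAML path is the one added by this diff, and the keyword mapping follows the OCSort constructor shown below:

    import yaml
    from trackers.ocsort.ocsort import OCSort

    with open('feeder/trackers/ocsort/configs/ocsort.yaml') as f:
        params = yaml.safe_load(f)['ocsort']

    tracker = OCSort(
        det_thresh=params['det_thresh'],
        max_age=params['max_age'],
        min_hits=params['min_hits'],
        iou_threshold=params['iou_thresh'],
        delta_t=params['delta_t'],
        asso_func=params['asso_func'],
        inertia=params['inertia'],
        use_byte=params['use_byte'],
    )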
1581 feeder/trackers/ocsort/kalmanfilter.py Normal file
File diff suppressed because it is too large. Load diff

328 feeder/trackers/ocsort/ocsort.py Normal file
@@ -0,0 +1,328 @@
"""
|
||||
This script is adopted from the SORT script by Alex Bewley alex@bewley.ai
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from .association import *
|
||||
from ultralytics.yolo.utils.ops import xywh2xyxy
|
||||
|
||||
|
||||
def k_previous_obs(observations, cur_age, k):
|
||||
if len(observations) == 0:
|
||||
return [-1, -1, -1, -1, -1]
|
||||
for i in range(k):
|
||||
dt = k - i
|
||||
if cur_age - dt in observations:
|
||||
return observations[cur_age-dt]
|
||||
max_age = max(observations.keys())
|
||||
return observations[max_age]
|
||||
|
||||
|
||||
def convert_bbox_to_z(bbox):
|
||||
"""
|
||||
Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
|
||||
[x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
|
||||
the aspect ratio
|
||||
"""
|
||||
w = bbox[2] - bbox[0]
|
||||
h = bbox[3] - bbox[1]
|
||||
x = bbox[0] + w/2.
|
||||
y = bbox[1] + h/2.
|
||||
s = w * h # scale is just area
|
||||
r = w / float(h+1e-6)
|
||||
return np.array([x, y, s, r]).reshape((4, 1))
|
||||
|
||||
|
||||
def convert_x_to_bbox(x, score=None):
|
||||
"""
|
||||
Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
|
||||
[x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
|
||||
"""
|
||||
w = np.sqrt(x[2] * x[3])
|
||||
h = x[2] / w
|
||||
if(score == None):
|
||||
return np.array([x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2.]).reshape((1, 4))
|
||||
else:
|
||||
return np.array([x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2., score]).reshape((1, 5))
|
||||
|
||||
|
||||
def speed_direction(bbox1, bbox2):
|
||||
cx1, cy1 = (bbox1[0]+bbox1[2]) / 2.0, (bbox1[1]+bbox1[3])/2.0
|
||||
cx2, cy2 = (bbox2[0]+bbox2[2]) / 2.0, (bbox2[1]+bbox2[3])/2.0
|
||||
speed = np.array([cy2-cy1, cx2-cx1])
|
||||
norm = np.sqrt((cy2-cy1)**2 + (cx2-cx1)**2) + 1e-6
|
||||
return speed / norm
|
||||
|
||||
|
||||
class KalmanBoxTracker(object):
|
||||
"""
|
||||
This class represents the internal state of individual tracked objects observed as bbox.
|
||||
"""
|
||||
count = 0
|
||||
|
||||
def __init__(self, bbox, cls, delta_t=3, orig=False):
|
||||
"""
|
||||
Initialises a tracker using initial bounding box.
|
||||
|
||||
"""
|
||||
# define constant velocity model
|
||||
if not orig:
|
||||
from .kalmanfilter import KalmanFilterNew as KalmanFilter
|
||||
self.kf = KalmanFilter(dim_x=7, dim_z=4)
|
||||
else:
|
||||
from filterpy.kalman import KalmanFilter
|
||||
self.kf = KalmanFilter(dim_x=7, dim_z=4)
|
||||
self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0], [0, 0, 1, 0, 0, 0, 1], [
|
||||
0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 1]])
|
||||
self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0],
|
||||
[0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]])
|
||||
|
||||
self.kf.R[2:, 2:] *= 10.
|
||||
self.kf.P[4:, 4:] *= 1000. # give high uncertainty to the unobservable initial velocities
|
||||
self.kf.P *= 10.
|
||||
self.kf.Q[-1, -1] *= 0.01
|
||||
self.kf.Q[4:, 4:] *= 0.01
|
||||
|
||||
self.kf.x[:4] = convert_bbox_to_z(bbox)
|
||||
self.time_since_update = 0
|
||||
self.id = KalmanBoxTracker.count
|
||||
KalmanBoxTracker.count += 1
|
||||
self.history = []
|
||||
self.hits = 0
|
||||
self.hit_streak = 0
|
||||
self.age = 0
|
||||
self.conf = bbox[-1]
|
||||
self.cls = cls
|
||||
"""
|
||||
NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of
|
||||
function k_previous_obs. It is ugly and I do not like it. But to support generate observation array in a
|
||||
fast and unified way, which you would see below k_observations = np.array([k_previous_obs(...]]), let's bear it for now.
|
||||
"""
|
||||
self.last_observation = np.array([-1, -1, -1, -1, -1]) # placeholder
|
||||
self.observations = dict()
|
||||
self.history_observations = []
|
||||
self.velocity = None
|
||||
self.delta_t = delta_t
|
||||
|
||||
def update(self, bbox, cls):
|
||||
"""
|
||||
Updates the state vector with observed bbox.
|
||||
"""
|
||||
|
||||
if bbox is not None:
|
||||
self.conf = bbox[-1]
|
||||
self.cls = cls
|
||||
if self.last_observation.sum() >= 0: # no previous observation
|
||||
previous_box = None
|
||||
for i in range(self.delta_t):
|
||||
dt = self.delta_t - i
|
||||
if self.age - dt in self.observations:
|
||||
previous_box = self.observations[self.age-dt]
|
||||
break
|
||||
if previous_box is None:
|
||||
previous_box = self.last_observation
|
||||
"""
|
||||
Estimate the track speed direction with observations \Delta t steps away
|
||||
"""
|
||||
self.velocity = speed_direction(previous_box, bbox)
|
||||
|
||||
"""
|
||||
Insert new observations. This is a ugly way to maintain both self.observations
|
||||
and self.history_observations. Bear it for the moment.
|
||||
"""
|
||||
self.last_observation = bbox
|
||||
self.observations[self.age] = bbox
|
||||
self.history_observations.append(bbox)
|
||||
|
||||
self.time_since_update = 0
|
||||
self.history = []
|
||||
self.hits += 1
|
||||
self.hit_streak += 1
|
||||
self.kf.update(convert_bbox_to_z(bbox))
|
||||
else:
|
||||
self.kf.update(bbox)
|
||||
|
||||
def predict(self):
|
||||
"""
|
||||
Advances the state vector and returns the predicted bounding box estimate.
|
||||
"""
|
||||
if((self.kf.x[6]+self.kf.x[2]) <= 0):
|
||||
self.kf.x[6] *= 0.0
|
||||
|
||||
self.kf.predict()
|
||||
self.age += 1
|
||||
if(self.time_since_update > 0):
|
||||
self.hit_streak = 0
|
||||
self.time_since_update += 1
|
||||
self.history.append(convert_x_to_bbox(self.kf.x))
|
||||
return self.history[-1]
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
Returns the current bounding box estimate.
|
||||
"""
|
||||
return convert_x_to_bbox(self.kf.x)
|
||||
|
||||
|
||||
"""
|
||||
We support multiple ways for association cost calculation, by default
|
||||
we use IoU. GIoU may have better performance in some situations. We note
|
||||
that we hardly normalize the cost by all methods to (0,1) which may not be
|
||||
the best practice.
|
||||
"""
|
||||
ASSO_FUNCS = { "iou": iou_batch,
|
||||
"giou": giou_batch,
|
||||
"ciou": ciou_batch,
|
||||
"diou": diou_batch,
|
||||
"ct_dist": ct_dist}
|
||||
|
||||
|
||||
class OCSort(object):
|
||||
def __init__(self, det_thresh, max_age=30, min_hits=3,
|
||||
iou_threshold=0.3, delta_t=3, asso_func="iou", inertia=0.2, use_byte=False):
|
||||
"""
|
||||
Sets key parameters for SORT
|
||||
"""
|
||||
self.max_age = max_age
|
||||
self.min_hits = min_hits
|
||||
self.iou_threshold = iou_threshold
|
||||
self.trackers = []
|
||||
self.frame_count = 0
|
||||
self.det_thresh = det_thresh
|
||||
self.delta_t = delta_t
|
||||
self.asso_func = ASSO_FUNCS[asso_func]
|
||||
self.inertia = inertia
|
||||
self.use_byte = use_byte
|
||||
KalmanBoxTracker.count = 0
|
||||
|
||||
def update(self, dets, _):
|
||||
"""
|
||||
Params:
|
||||
dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
|
||||
Requires: this method must be called once for each frame even with empty detections (use np.empty((0, 5)) for frames without detections).
|
||||
Returns the a similar array, where the last column is the object ID.
|
||||
NOTE: The number of objects returned may differ from the number of detections provided.
|
||||
"""
|
||||
|
||||
self.frame_count += 1
|
||||
|
||||
xyxys = dets[:, 0:4]
|
||||
confs = dets[:, 4]
|
||||
clss = dets[:, 5]
|
||||
|
||||
classes = clss.numpy()
|
||||
xyxys = xyxys.numpy()
|
||||
confs = confs.numpy()
|
||||
|
||||
output_results = np.column_stack((xyxys, confs, classes))
|
||||
|
||||
inds_low = confs > 0.1
|
||||
inds_high = confs < self.det_thresh
|
||||
inds_second = np.logical_and(inds_low, inds_high) # self.det_thresh > score > 0.1, for second matching
|
||||
dets_second = output_results[inds_second] # detections for second matching
|
||||
remain_inds = confs > self.det_thresh
|
||||
dets = output_results[remain_inds]
|
||||
|
||||
# get predicted locations from existing trackers.
|
||||
trks = np.zeros((len(self.trackers), 5))
|
||||
to_del = []
|
||||
ret = []
|
||||
for t, trk in enumerate(trks):
|
||||
pos = self.trackers[t].predict()[0]
|
||||
trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
|
||||
if np.any(np.isnan(pos)):
|
||||
to_del.append(t)
|
||||
trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
|
||||
for t in reversed(to_del):
|
||||
self.trackers.pop(t)
|
||||
|
||||
velocities = np.array(
|
||||
[trk.velocity if trk.velocity is not None else np.array((0, 0)) for trk in self.trackers])
|
||||
last_boxes = np.array([trk.last_observation for trk in self.trackers])
|
||||
k_observations = np.array(
|
||||
[k_previous_obs(trk.observations, trk.age, self.delta_t) for trk in self.trackers])
|
||||
|
||||
"""
|
||||
First round of association
|
||||
"""
|
||||
matched, unmatched_dets, unmatched_trks = associate(
|
||||
dets, trks, self.iou_threshold, velocities, k_observations, self.inertia)
|
||||
for m in matched:
|
||||
self.trackers[m[1]].update(dets[m[0], :5], dets[m[0], 5])
|
||||
|
||||
"""
|
||||
Second round of associaton by OCR
|
||||
"""
|
||||
# BYTE association
|
||||
if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[0] > 0:
|
||||
u_trks = trks[unmatched_trks]
|
||||
iou_left = self.asso_func(dets_second, u_trks) # iou between low score detections and unmatched tracks
|
||||
iou_left = np.array(iou_left)
|
||||
if iou_left.max() > self.iou_threshold:
|
||||
"""
|
||||
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
|
||||
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
|
||||
uniform here for simplicity
|
||||
"""
|
||||
matched_indices = linear_assignment(-iou_left)
|
||||
to_remove_trk_indices = []
|
||||
for m in matched_indices:
|
||||
det_ind, trk_ind = m[0], unmatched_trks[m[1]]
|
||||
if iou_left[m[0], m[1]] < self.iou_threshold:
|
||||
continue
|
||||
self.trackers[trk_ind].update(dets_second[det_ind, :5], dets_second[det_ind, 5])
|
||||
to_remove_trk_indices.append(trk_ind)
|
||||
unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))
|
||||
|
||||
if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
|
||||
left_dets = dets[unmatched_dets]
|
||||
left_trks = last_boxes[unmatched_trks]
|
||||
iou_left = self.asso_func(left_dets, left_trks)
|
||||
iou_left = np.array(iou_left)
|
||||
if iou_left.max() > self.iou_threshold:
|
||||
"""
|
||||
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
|
||||
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
|
||||
uniform here for simplicity
|
||||
"""
|
||||
rematched_indices = linear_assignment(-iou_left)
|
||||
to_remove_det_indices = []
|
||||
to_remove_trk_indices = []
|
||||
for m in rematched_indices:
|
||||
det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
|
||||
if iou_left[m[0], m[1]] < self.iou_threshold:
|
||||
continue
|
||||
self.trackers[trk_ind].update(dets[det_ind, :5], dets[det_ind, 5])
|
||||
to_remove_det_indices.append(det_ind)
|
||||
to_remove_trk_indices.append(trk_ind)
|
||||
unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices))
|
||||
unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))
|
||||
|
||||
for m in unmatched_trks:
|
||||
self.trackers[m].update(None, None)
|
||||
|
||||
# create and initialise new trackers for unmatched detections
|
||||
for i in unmatched_dets:
|
||||
trk = KalmanBoxTracker(dets[i, :5], dets[i, 5], delta_t=self.delta_t)
|
||||
self.trackers.append(trk)
|
||||
i = len(self.trackers)
|
||||
for trk in reversed(self.trackers):
|
||||
if trk.last_observation.sum() < 0:
|
||||
d = trk.get_state()[0]
|
||||
else:
|
||||
"""
|
||||
this is optional to use the recent observation or the kalman filter prediction,
|
||||
we didn't notice significant difference here
|
||||
"""
|
||||
d = trk.last_observation[:4]
|
||||
if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
|
||||
# +1 as MOT benchmark requires positive
|
||||
ret.append(np.concatenate((d, [trk.id+1], [trk.cls], [trk.conf])).reshape(1, -1))
|
||||
i -= 1
|
||||
# remove dead tracklet
|
||||
if(trk.time_since_update > self.max_age):
|
||||
self.trackers.pop(i)
|
||||
if(len(ret) > 0):
|
||||
return np.concatenate(ret)
|
||||
return np.empty((0, 5))
|
||||
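A minimal per-frame driver for this tracker. Since update() calls .numpy() on its input, detections are passed as an (N, 6) torch tensor of [x1, y1, x2, y2, conf, cls] rows; `detection_stream` is a placeholder for whatever produces them:

    import torch
    from trackers.ocsort.ocsort import OCSort

    tracker = OCSort(det_thresh=0.3)
    for frame_dets in detection_stream:            # placeholder iterable of (N, 6) tensors
        tracks = tracker.update(frame_dets, None)  # rows: [x1, y1, x2, y2, id, cls, conf]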
313 feeder/trackers/reid_export.py Normal file
@@ -0,0 +1,313 @@
import argparse

import os
# limit the number of cpus used by high performance libraries
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

import sys
import numpy as np
from pathlib import Path
import torch
import time
import platform
import pandas as pd
import subprocess
import torch.backends.cudnn as cudnn
from torch.utils.mobile_optimizer import optimize_for_mobile

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0].parents[0]  # yolov5 strongsort root directory
WEIGHTS = ROOT / 'weights'


if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if str(ROOT / 'yolov5') not in sys.path:
    sys.path.append(str(ROOT / 'yolov5'))  # add yolov5 ROOT to PATH

ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

import logging
from ultralytics.yolo.utils.torch_utils import select_device
from ultralytics.yolo.utils import LOGGER, colorstr, ops
from ultralytics.yolo.utils.checks import check_requirements, check_version
from trackers.strongsort.deep.models import build_model
from trackers.strongsort.deep.reid_model_factory import get_model_name, load_pretrained_weights


def file_size(path):
    # Return file/dir size (MB)
    path = Path(path)
    if path.is_file():
        return path.stat().st_size / 1E6
    elif path.is_dir():
        return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
    else:
        return 0.0


def export_formats():
    # YOLOv5 export formats
    x = [
        ['PyTorch', '-', '.pt', True, True],
        ['TorchScript', 'torchscript', '.torchscript', True, True],
        ['ONNX', 'onnx', '.onnx', True, True],
        ['OpenVINO', 'openvino', '_openvino_model', True, False],
        ['TensorRT', 'engine', '.engine', False, True],
        ['TensorFlow Lite', 'tflite', '.tflite', True, False],
    ]
    return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])


def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
    # YOLOv5 TorchScript model export
    try:
        LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = file.with_suffix('.torchscript')

        ts = torch.jit.trace(model, im, strict=False)
        if optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
            optimize_for_mobile(ts)._save_for_lite_interpreter(str(f))
        else:
            ts.save(str(f))

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'{prefix} export failure: {e}')


def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
    # ONNX export
    try:
        check_requirements(('onnx',))
        import onnx

        f = file.with_suffix('.onnx')
        LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')

        if dynamic:
            dynamic = {'images': {0: 'batch'}}  # shape(1,3,640,640)
            dynamic['output'] = {0: 'batch'}  # shape(1,25200,85)

        torch.onnx.export(
            model.cpu() if dynamic else model,  # --dynamic only compatible with cpu
            im.cpu() if dynamic else im,
            f,
            verbose=False,
            opset_version=opset,
            do_constant_folding=True,
            input_names=['images'],
            output_names=['output'],
            dynamic_axes=dynamic or None
        )
        # Checks
        model_onnx = onnx.load(f)  # load onnx model
        onnx.checker.check_model(model_onnx)  # check onnx model
        onnx.save(model_onnx, f)

        # Simplify
        if simplify:
            try:
                cuda = torch.cuda.is_available()
                check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
                import onnxsim

                LOGGER.info(f'simplifying with onnx-simplifier {onnxsim.__version__}...')
                model_onnx, check = onnxsim.simplify(model_onnx)
                assert check, 'assert check failed'
                onnx.save(model_onnx, f)
            except Exception as e:
                LOGGER.info(f'simplifier failure: {e}')
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'export failure: {e}')


def export_openvino(file, half, prefix=colorstr('OpenVINO:')):
    # YOLOv5 OpenVINO export
    check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
    import openvino.inference_engine as ie
    try:
        LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
        f = str(file).replace('.pt', f'_openvino_model{os.sep}')

        cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
        subprocess.check_output(cmd.split())  # export
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'export failure: {e}')


def export_tflite(file, half, prefix=colorstr('TFLite:')):
    # TFLite export via openvino2tensorflow
    try:
        check_requirements(('openvino2tensorflow', 'tensorflow', 'tensorflow_datasets'))
        import openvino.inference_engine as ie
        LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
        output = Path(str(file).replace(f'_openvino_model{os.sep}', f'_tflite_model{os.sep}'))
        modelxml = list(Path(file).glob('*.xml'))[0]
        cmd = f"openvino2tensorflow \
            --model_path {modelxml} \
            --model_output_path {output} \
            --output_pb \
            --output_saved_model \
            --output_no_quant_float32_tflite \
            --output_dynamic_range_quant_tflite"
        subprocess.check_output(cmd.split())  # export

        LOGGER.info(f'{prefix} export success, results saved in {output} ({file_size(output):.1f} MB)')
        return output
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')


def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
    # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
    try:
        assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
        try:
            import tensorrt as trt
        except Exception:
            if platform.system() == 'Linux':
                check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
            import tensorrt as trt

        if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
            grid = model.model[-1].anchor_grid
            model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
            export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
            model.model[-1].anchor_grid = grid
        else:  # TensorRT >= 8
            check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
            export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
        onnx = file.with_suffix('.onnx')

        LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
        assert onnx.exists(), f'failed to export ONNX file: {onnx}'
        f = file.with_suffix('.engine')  # TensorRT engine file
        logger = trt.Logger(trt.Logger.INFO)
        if verbose:
            logger.min_severity = trt.Logger.Severity.VERBOSE

        builder = trt.Builder(logger)
        config = builder.create_builder_config()
        config.max_workspace_size = workspace * 1 << 30
        # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice

        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        network = builder.create_network(flag)
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(str(onnx)):
            raise RuntimeError(f'failed to load ONNX file: {onnx}')

        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]
        LOGGER.info(f'{prefix} Network Description:')
        for inp in inputs:
            LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
        for out in outputs:
            LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

        if dynamic:
            if im.shape[0] <= 1:
                LOGGER.warning(f"{prefix}WARNING: --dynamic model requires maximum --batch-size argument")
            profile = builder.create_optimization_profile()
            for inp in inputs:
                profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
            config.add_optimization_profile(profile)

        LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}')
        if builder.platform_has_fast_fp16 and half:
            config.set_flag(trt.BuilderFlag.FP16)
        with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
            t.write(engine.serialize())
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="ReID export")
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[256, 128], help='image (h, w)')
    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
    parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
    parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
    parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
    parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
    parser.add_argument('--weights', nargs='+', type=str, default=WEIGHTS / 'osnet_x0_25_msmt17.pt', help='model.pt path(s)')
    parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
    parser.add_argument('--include',
                        nargs='+',
                        default=['torchscript'],
                        help='torchscript, onnx, openvino, engine')
    args = parser.parse_args()

    t = time.time()

    include = [x.lower() for x in args.include]  # to lowercase
    fmts = tuple(export_formats()['Argument'][1:])  # --include arguments
    flags = [x in include for x in fmts]
    assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
    jit, onnx, openvino, engine, tflite = flags  # export booleans

    args.device = select_device(args.device)
    if args.half:
        assert args.device.type != 'cpu', '--half only compatible with GPU export, i.e. use --device 0'
        assert not args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'

    if isinstance(args.weights, list):
        args.weights = Path(args.weights[0])

    model = build_model(
        get_model_name(args.weights),
        num_classes=1,
        pretrained=not (args.weights and args.weights.is_file() and args.weights.suffix == '.pt'),
        use_gpu=args.device
    ).to(args.device)
    load_pretrained_weights(model, args.weights)
    model.eval()

    if args.optimize:
        assert args.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'

    im = torch.zeros(args.batch_size, 3, args.imgsz[0], args.imgsz[1]).to(args.device)  # image size (batch, 3, 256, 128) BCHW
    for _ in range(2):
        y = model(im)  # dry runs
    if args.half:
        im, model = im.half(), model.half()  # to FP16
    shape = tuple((y[0] if isinstance(y, tuple) else y).shape)  # model output shape
    LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {args.weights} with output shape {shape} ({file_size(args.weights):.1f} MB)")

    # Exports
    f = [''] * len(fmts)  # exported filenames
    if jit:
        f[0] = export_torchscript(model, im, args.weights, args.optimize)
    if engine:  # TensorRT export runs export_onnx internally
        f[1] = export_engine(model, im, args.weights, args.half, args.dynamic, args.simplify, args.workspace, args.verbose)
    if onnx:  # OpenVINO requires ONNX
        f[2] = export_onnx(model, im, args.weights, args.opset, args.dynamic, args.simplify)
    if openvino:
        f[3] = export_openvino(args.weights, args.half)
    if tflite:
        export_tflite(f[3], False)

    # Finish
    f = [str(x) for x in f if x]  # filter out '' and None
    if any(f):
        LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
                    f"\nResults saved to {colorstr('bold', args.weights.parent.resolve())}"
                    f"\nVisualize: https://netron.app")
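A typical invocation of this script, with illustrative paths, would look like:

    python trackers/reid_export.py --weights weights/osnet_x0_25_msmt17.pt --include torchscript onnx --device cpu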
0 feeder/trackers/strongsort/__init__.py Normal file

11 feeder/trackers/strongsort/configs/strongsort.yaml Normal file
@@ -0,0 +1,11 @@
strongsort:
  ecc: true
  ema_alpha: 0.8962157769329083
  max_age: 40
  max_dist: 0.1594374041012136
  max_iou_dist: 0.5431835667667874
  max_unmatched_preds: 0
  mc_lambda: 0.995
  n_init: 3
  nn_budget: 100
  conf_thres: 0.5122620708221085
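A sketch of how these values map onto the StrongSORT constructor used by the factory earlier in this diff; the weight path and device are placeholders:

    from pathlib import Path
    import torch
    import yaml
    from trackers.strongsort.strong_sort import StrongSORT

    with open('feeder/trackers/strongsort/configs/strongsort.yaml') as fh:
        p = yaml.safe_load(fh)['strongsort']

    tracker = StrongSORT(
        Path('weights/osnet_x1_0_msmt17.pth'),  # placeholder ReID weights
        torch.device('cpu'),
        False,                                  # half precision off on CPU
        max_dist=p['max_dist'],
        max_iou_dist=p['max_iou_dist'],
        max_age=p['max_age'],
        max_unmatched_preds=p['max_unmatched_preds'],
        n_init=p['n_init'],
        nn_budget=p['nn_budget'],
        mc_lambda=p['mc_lambda'],
        ema_alpha=p['ema_alpha'],
    )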
0 feeder/trackers/strongsort/deep/checkpoint/.gitkeep Normal file

BIN feeder/trackers/strongsort/deep/checkpoint/osnet_x1_0_msmt17.pth Normal file
Binary file not shown.

122 feeder/trackers/strongsort/deep/models/__init__.py Normal file
@@ -0,0 +1,122 @@
from __future__ import absolute_import
import torch

from .pcb import *
from .mlfn import *
from .hacnn import *
from .osnet import *
from .senet import *
from .mudeep import *
from .nasnet import *
from .resnet import *
from .densenet import *
from .xception import *
from .osnet_ain import *
from .resnetmid import *
from .shufflenet import *
from .squeezenet import *
from .inceptionv4 import *
from .mobilenetv2 import *
from .resnet_ibn_a import *
from .resnet_ibn_b import *
from .shufflenetv2 import *
from .inceptionresnetv2 import *

__model_factory = {
    # image classification models
    'resnet18': resnet18,
    'resnet34': resnet34,
    'resnet50': resnet50,
    'resnet101': resnet101,
    'resnet152': resnet152,
    'resnext50_32x4d': resnext50_32x4d,
    'resnext101_32x8d': resnext101_32x8d,
    'resnet50_fc512': resnet50_fc512,
    'se_resnet50': se_resnet50,
    'se_resnet50_fc512': se_resnet50_fc512,
    'se_resnet101': se_resnet101,
    'se_resnext50_32x4d': se_resnext50_32x4d,
    'se_resnext101_32x4d': se_resnext101_32x4d,
    'densenet121': densenet121,
    'densenet169': densenet169,
    'densenet201': densenet201,
    'densenet161': densenet161,
    'densenet121_fc512': densenet121_fc512,
    'inceptionresnetv2': inceptionresnetv2,
    'inceptionv4': inceptionv4,
    'xception': xception,
    'resnet50_ibn_a': resnet50_ibn_a,
    'resnet50_ibn_b': resnet50_ibn_b,
    # lightweight models
    'nasnsetmobile': nasnetamobile,
    'mobilenetv2_x1_0': mobilenetv2_x1_0,
    'mobilenetv2_x1_4': mobilenetv2_x1_4,
    'shufflenet': shufflenet,
    'squeezenet1_0': squeezenet1_0,
    'squeezenet1_0_fc512': squeezenet1_0_fc512,
    'squeezenet1_1': squeezenet1_1,
    'shufflenet_v2_x0_5': shufflenet_v2_x0_5,
    'shufflenet_v2_x1_0': shufflenet_v2_x1_0,
    'shufflenet_v2_x1_5': shufflenet_v2_x1_5,
    'shufflenet_v2_x2_0': shufflenet_v2_x2_0,
    # reid-specific models
    'mudeep': MuDeep,
    'resnet50mid': resnet50mid,
    'hacnn': HACNN,
    'pcb_p6': pcb_p6,
    'pcb_p4': pcb_p4,
    'mlfn': mlfn,
    'osnet_x1_0': osnet_x1_0,
    'osnet_x0_75': osnet_x0_75,
    'osnet_x0_5': osnet_x0_5,
    'osnet_x0_25': osnet_x0_25,
    'osnet_ibn_x1_0': osnet_ibn_x1_0,
    'osnet_ain_x1_0': osnet_ain_x1_0,
    'osnet_ain_x0_75': osnet_ain_x0_75,
    'osnet_ain_x0_5': osnet_ain_x0_5,
    'osnet_ain_x0_25': osnet_ain_x0_25
}


def show_avai_models():
    """Displays available models.

    Examples::
        >>> from torchreid import models
        >>> models.show_avai_models()
    """
    print(list(__model_factory.keys()))


def build_model(
    name, num_classes, loss='softmax', pretrained=True, use_gpu=True
):
    """A function wrapper for building a model.

    Args:
        name (str): model name.
        num_classes (int): number of training identities.
        loss (str, optional): loss function to optimize the model. Currently
            supports "softmax" and "triplet". Default is "softmax".
        pretrained (bool, optional): whether to load ImageNet-pretrained weights.
            Default is True.
        use_gpu (bool, optional): whether to use gpu. Default is True.

    Returns:
        nn.Module

    Examples::
        >>> from torchreid import models
        >>> model = models.build_model('resnet50', 751, loss='softmax')
    """
    avai_models = list(__model_factory.keys())
    if name not in avai_models:
        raise KeyError(
            'Unknown model: {}. Must be one of {}'.format(name, avai_models)
        )
    return __model_factory[name](
        num_classes=num_classes,
        loss=loss,
        pretrained=pretrained,
        use_gpu=use_gpu
    )
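The pattern used by reid_export.py above is the intended way to consume this factory; a short sketch, with the weights path as a placeholder:

    # Build a ReID backbone from the factory and load fine-tuned weights,
    # mirroring reid_export.py.
    from pathlib import Path
    from trackers.strongsort.deep.models import build_model
    from trackers.strongsort.deep.reid_model_factory import get_model_name, load_pretrained_weights

    weights_path = Path('weights/osnet_x0_25_msmt17.pt')  # placeholder
    model = build_model(get_model_name(weights_path), num_classes=1, pretrained=False)
    load_pretrained_weights(model, weights_path)
    model.eval()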
380 feeder/trackers/strongsort/deep/models/densenet.py Normal file
@@ -0,0 +1,380 @@
"""
|
||||
Code source: https://github.com/pytorch/vision
|
||||
"""
|
||||
from __future__ import division, absolute_import
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.utils import model_zoo
|
||||
|
||||
__all__ = [
|
||||
'densenet121', 'densenet169', 'densenet201', 'densenet161',
|
||||
'densenet121_fc512'
|
||||
]
|
||||
|
||||
model_urls = {
|
||||
'densenet121':
|
||||
'https://download.pytorch.org/models/densenet121-a639ec97.pth',
|
||||
'densenet169':
|
||||
'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
|
||||
'densenet201':
|
||||
'https://download.pytorch.org/models/densenet201-c1103571.pth',
|
||||
'densenet161':
|
||||
'https://download.pytorch.org/models/densenet161-8d451a50.pth',
|
||||
}
|
||||
|
||||
|
||||
class _DenseLayer(nn.Sequential):
|
||||
|
||||
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
|
||||
super(_DenseLayer, self).__init__()
|
||||
self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
|
||||
self.add_module('relu1', nn.ReLU(inplace=True)),
|
||||
self.add_module(
|
||||
'conv1',
|
||||
nn.Conv2d(
|
||||
num_input_features,
|
||||
bn_size * growth_rate,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=False
|
||||
)
|
||||
),
|
||||
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
|
||||
self.add_module('relu2', nn.ReLU(inplace=True)),
|
||||
self.add_module(
|
||||
'conv2',
|
||||
nn.Conv2d(
|
||||
bn_size * growth_rate,
|
||||
growth_rate,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False
|
||||
)
|
||||
),
|
||||
self.drop_rate = drop_rate
|
||||
|
||||
def forward(self, x):
|
||||
new_features = super(_DenseLayer, self).forward(x)
|
||||
if self.drop_rate > 0:
|
||||
new_features = F.dropout(
|
||||
new_features, p=self.drop_rate, training=self.training
|
||||
)
|
||||
return torch.cat([x, new_features], 1)
|
||||
|
||||
|
||||
class _DenseBlock(nn.Sequential):
|
||||
|
||||
def __init__(
|
||||
self, num_layers, num_input_features, bn_size, growth_rate, drop_rate
|
||||
):
|
||||
super(_DenseBlock, self).__init__()
|
||||
for i in range(num_layers):
|
||||
layer = _DenseLayer(
|
||||
num_input_features + i*growth_rate, growth_rate, bn_size,
|
||||
drop_rate
|
||||
)
|
||||
self.add_module('denselayer%d' % (i+1), layer)
|
||||
|
||||
|
||||
class _Transition(nn.Sequential):
|
||||
|
||||
def __init__(self, num_input_features, num_output_features):
|
||||
super(_Transition, self).__init__()
|
||||
self.add_module('norm', nn.BatchNorm2d(num_input_features))
|
||||
self.add_module('relu', nn.ReLU(inplace=True))
|
||||
self.add_module(
|
||||
'conv',
|
||||
nn.Conv2d(
|
||||
num_input_features,
|
||||
num_output_features,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=False
|
||||
)
|
||||
)
|
||||
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
|
||||
|
||||
|
||||
class DenseNet(nn.Module):
|
||||
"""Densely connected network.
|
||||
|
||||
Reference:
|
||||
Huang et al. Densely Connected Convolutional Networks. CVPR 2017.
|
||||
|
||||
Public keys:
|
||||
- ``densenet121``: DenseNet121.
|
||||
- ``densenet169``: DenseNet169.
|
||||
- ``densenet201``: DenseNet201.
|
||||
- ``densenet161``: DenseNet161.
|
||||
- ``densenet121_fc512``: DenseNet121 + FC.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
loss,
|
||||
growth_rate=32,
|
||||
block_config=(6, 12, 24, 16),
|
||||
num_init_features=64,
|
||||
bn_size=4,
|
||||
drop_rate=0,
|
||||
fc_dims=None,
|
||||
dropout_p=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
super(DenseNet, self).__init__()
|
||||
self.loss = loss
|
||||
|
||||
# First convolution
|
||||
self.features = nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
(
|
||||
'conv0',
|
||||
nn.Conv2d(
|
||||
3,
|
||||
num_init_features,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
bias=False
|
||||
)
|
||||
),
|
||||
('norm0', nn.BatchNorm2d(num_init_features)),
|
||||
('relu0', nn.ReLU(inplace=True)),
|
||||
(
|
||||
'pool0',
|
||||
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Each denseblock
|
||||
num_features = num_init_features
|
||||
for i, num_layers in enumerate(block_config):
|
||||
block = _DenseBlock(
|
||||
num_layers=num_layers,
|
||||
num_input_features=num_features,
|
||||
bn_size=bn_size,
|
||||
growth_rate=growth_rate,
|
||||
drop_rate=drop_rate
|
||||
)
|
||||
self.features.add_module('denseblock%d' % (i+1), block)
|
||||
num_features = num_features + num_layers*growth_rate
|
||||
if i != len(block_config) - 1:
|
||||
trans = _Transition(
|
||||
num_input_features=num_features,
|
||||
num_output_features=num_features // 2
|
||||
)
|
||||
self.features.add_module('transition%d' % (i+1), trans)
|
||||
num_features = num_features // 2
|
||||
|
||||
# Final batch norm
|
||||
self.features.add_module('norm5', nn.BatchNorm2d(num_features))
|
||||
|
||||
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
self.feature_dim = num_features
|
||||
self.fc = self._construct_fc_layer(fc_dims, num_features, dropout_p)
|
||||
|
||||
# Linear layer
|
||||
self.classifier = nn.Linear(self.feature_dim, num_classes)
|
||||
|
||||
self._init_params()
|
||||
|
||||
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
|
||||
"""Constructs fully connected layer.
|
||||
|
||||
Args:
|
||||
fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
|
||||
input_dim (int): input dimension
|
||||
dropout_p (float): dropout probability, if None, dropout is unused
|
||||
"""
|
||||
if fc_dims is None:
|
||||
self.feature_dim = input_dim
|
||||
return None
|
||||
|
||||
assert isinstance(
|
||||
fc_dims, (list, tuple)
|
||||
), 'fc_dims must be either list or tuple, but got {}'.format(
|
||||
type(fc_dims)
|
||||
)
|
||||
|
||||
layers = []
|
||||
for dim in fc_dims:
|
||||
layers.append(nn.Linear(input_dim, dim))
|
||||
layers.append(nn.BatchNorm1d(dim))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
if dropout_p is not None:
|
||||
layers.append(nn.Dropout(p=dropout_p))
|
||||
input_dim = dim
|
||||
|
||||
self.feature_dim = fc_dims[-1]
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu'
|
||||
)
|
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        f = self.features(x)
        f = F.relu(f, inplace=True)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if self.fc is not None:
            v = self.fc(v)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)

    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
    )
    for key in list(pretrain_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            pretrain_dict[new_key] = pretrain_dict[key]
            del pretrain_dict[key]

    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


"""
Dense network configurations:
--
densenet121: num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)
densenet169: num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32)
densenet201: num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32)
densenet161: num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24)
"""


def densenet121(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = DenseNet(
        num_classes=num_classes,
        loss=loss,
        num_init_features=64,
        growth_rate=32,
        block_config=(6, 12, 24, 16),
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['densenet121'])
    return model


def densenet169(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = DenseNet(
        num_classes=num_classes,
        loss=loss,
        num_init_features=64,
        growth_rate=32,
        block_config=(6, 12, 32, 32),
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['densenet169'])
    return model


def densenet201(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = DenseNet(
        num_classes=num_classes,
        loss=loss,
        num_init_features=64,
        growth_rate=32,
        block_config=(6, 12, 48, 32),
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['densenet201'])
    return model


def densenet161(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = DenseNet(
        num_classes=num_classes,
        loss=loss,
        num_init_features=96,
        growth_rate=48,
        block_config=(6, 12, 36, 24),
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['densenet161'])
    return model


def densenet121_fc512(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = DenseNet(
        num_classes=num_classes,
        loss=loss,
        num_init_features=64,
        growth_rate=32,
        block_config=(6, 12, 24, 16),
        fc_dims=[512],
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['densenet121'])
    return model
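Note (not part of the diff): these factory functions follow the torchreid convention of returning pooled features in eval mode and logits (or a logits/features pair for the triplet loss) in train mode. A minimal smoke-test sketch, assuming the module is importable under the path this diff adds:

    import torch
    from feeder.trackers.strongsort.deep.models.densenet import densenet121

    model = densenet121(num_classes=751, pretrained=False)
    model.eval()  # eval mode: forward() returns the pooled feature vector
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 256, 128))  # (1, 1024) for densenet121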
414 feeder/trackers/strongsort/deep/models/hacnn.py Normal file
@ -0,0 +1,414 @@
from __future__ import division, absolute_import
import torch
from torch import nn
from torch.nn import functional as F

__all__ = ['HACNN']


class ConvBlock(nn.Module):
    """Basic convolutional block.

    convolution + batch normalization + relu.

    Args:
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): kernel size.
        s (int or tuple): stride.
        p (int or tuple): padding.
    """

    def __init__(self, in_c, out_c, k, s=1, p=0):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))


class InceptionA(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(InceptionA, self).__init__()
        mid_channels = out_channels // 4

        self.stream1 = nn.Sequential(
            ConvBlock(in_channels, mid_channels, 1),
            ConvBlock(mid_channels, mid_channels, 3, p=1),
        )
        self.stream2 = nn.Sequential(
            ConvBlock(in_channels, mid_channels, 1),
            ConvBlock(mid_channels, mid_channels, 3, p=1),
        )
        self.stream3 = nn.Sequential(
            ConvBlock(in_channels, mid_channels, 1),
            ConvBlock(mid_channels, mid_channels, 3, p=1),
        )
        self.stream4 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1),
            ConvBlock(in_channels, mid_channels, 1),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        s4 = self.stream4(x)
        y = torch.cat([s1, s2, s3, s4], dim=1)
        return y


class InceptionB(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(InceptionB, self).__init__()
        mid_channels = out_channels // 4

        self.stream1 = nn.Sequential(
            ConvBlock(in_channels, mid_channels, 1),
            ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
        )
        self.stream2 = nn.Sequential(
            ConvBlock(in_channels, mid_channels, 1),
            ConvBlock(mid_channels, mid_channels, 3, p=1),
            ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
        )
        self.stream3 = nn.Sequential(
            nn.MaxPool2d(3, stride=2, padding=1),
            ConvBlock(in_channels, mid_channels * 2, 1),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        y = torch.cat([s1, s2, s3], dim=1)
        return y


class SpatialAttn(nn.Module):
    """Spatial Attention (Sec. 3.1.I.1)"""

    def __init__(self):
        super(SpatialAttn, self).__init__()
        self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
        self.conv2 = ConvBlock(1, 1, 1)

    def forward(self, x):
        # global cross-channel averaging
        x = x.mean(1, keepdim=True)
        # 3-by-3 conv
        x = self.conv1(x)
        # bilinear resizing
        x = F.upsample(
            x, (x.size(2) * 2, x.size(3) * 2),
            mode='bilinear',
            align_corners=True
        )
        # scaling conv
        x = self.conv2(x)
        return x


class ChannelAttn(nn.Module):
    """Channel Attention (Sec. 3.1.I.2)"""

    def __init__(self, in_channels, reduction_rate=16):
        super(ChannelAttn, self).__init__()
        assert in_channels % reduction_rate == 0
        self.conv1 = ConvBlock(in_channels, in_channels // reduction_rate, 1)
        self.conv2 = ConvBlock(in_channels // reduction_rate, in_channels, 1)

    def forward(self, x):
        # squeeze operation (global average pooling)
        x = F.avg_pool2d(x, x.size()[2:])
        # excitation operation (2 conv layers)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class SoftAttn(nn.Module):
    """Soft Attention (Sec. 3.1.I)

    Aim: Spatial Attention + Channel Attention

    Output: attention maps with shape identical to input.
    """

    def __init__(self, in_channels):
        super(SoftAttn, self).__init__()
        self.spatial_attn = SpatialAttn()
        self.channel_attn = ChannelAttn(in_channels)
        self.conv = ConvBlock(in_channels, in_channels, 1)

    def forward(self, x):
        y_spatial = self.spatial_attn(x)
        y_channel = self.channel_attn(x)
        y = y_spatial * y_channel
        y = torch.sigmoid(self.conv(y))
        return y


class HardAttn(nn.Module):
    """Hard Attention (Sec. 3.1.II)"""

    def __init__(self, in_channels):
        super(HardAttn, self).__init__()
        self.fc = nn.Linear(in_channels, 4 * 2)
        self.init_params()

    def init_params(self):
        self.fc.weight.data.zero_()
        self.fc.bias.data.copy_(
            torch.tensor(
                [0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float
            )
        )

    def forward(self, x):
        # squeeze operation (global average pooling)
        x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
        # predict transformation parameters
        theta = torch.tanh(self.fc(x))
        theta = theta.view(-1, 4, 2)
        return theta


class HarmAttn(nn.Module):
    """Harmonious Attention (Sec. 3.1)"""

    def __init__(self, in_channels):
        super(HarmAttn, self).__init__()
        self.soft_attn = SoftAttn(in_channels)
        self.hard_attn = HardAttn(in_channels)

    def forward(self, x):
        y_soft_attn = self.soft_attn(x)
        theta = self.hard_attn(x)
        return y_soft_attn, theta


class HACNN(nn.Module):
    """Harmonious Attention Convolutional Neural Network.

    Reference:
        Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.

    Public keys:
        - ``hacnn``: HACNN.
    """

    # Args:
    #    num_classes (int): number of classes to predict
    #    nchannels (list): number of channels AFTER concatenation
    #    feat_dim (int): feature dimension for a single stream
    #    learn_region (bool): whether to learn region features (i.e. local branch)

    def __init__(
        self,
        num_classes,
        loss='softmax',
        nchannels=[128, 256, 384],
        feat_dim=512,
        learn_region=True,
        use_gpu=True,
        **kwargs
    ):
        super(HACNN, self).__init__()
        self.loss = loss
        self.learn_region = learn_region
        self.use_gpu = use_gpu

        self.conv = ConvBlock(3, 32, 3, s=2, p=1)

        # Construct Inception + HarmAttn blocks
        # ============== Block 1 ==============
        self.inception1 = nn.Sequential(
            InceptionA(32, nchannels[0]),
            InceptionB(nchannels[0], nchannels[0]),
        )
        self.ha1 = HarmAttn(nchannels[0])

        # ============== Block 2 ==============
        self.inception2 = nn.Sequential(
            InceptionA(nchannels[0], nchannels[1]),
            InceptionB(nchannels[1], nchannels[1]),
        )
        self.ha2 = HarmAttn(nchannels[1])

        # ============== Block 3 ==============
        self.inception3 = nn.Sequential(
            InceptionA(nchannels[1], nchannels[2]),
            InceptionB(nchannels[2], nchannels[2]),
        )
        self.ha3 = HarmAttn(nchannels[2])

        self.fc_global = nn.Sequential(
            nn.Linear(nchannels[2], feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(),
        )
        self.classifier_global = nn.Linear(feat_dim, num_classes)

        if self.learn_region:
            self.init_scale_factors()
            self.local_conv1 = InceptionB(32, nchannels[0])
            self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
            self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
            self.fc_local = nn.Sequential(
                nn.Linear(nchannels[2] * 4, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.ReLU(),
            )
            self.classifier_local = nn.Linear(feat_dim, num_classes)
            self.feat_dim = feat_dim * 2
        else:
            self.feat_dim = feat_dim

    def init_scale_factors(self):
        # initialize scale factors (s_w, s_h) for four regions
        self.scale_factors = []
        self.scale_factors.append(
            torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
        )
        self.scale_factors.append(
            torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
        )
        self.scale_factors.append(
            torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
        )
        self.scale_factors.append(
            torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
        )

    def stn(self, x, theta):
        """Performs spatial transform

        x: (batch, channel, height, width)
        theta: (batch, 2, 3)
        """
        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        return x

    def transform_theta(self, theta_i, region_idx):
        """Transforms theta to include (s_w, s_h), resulting in (batch, 2, 3)"""
        scale_factors = self.scale_factors[region_idx]
        theta = torch.zeros(theta_i.size(0), 2, 3)
        theta[:, :, :2] = scale_factors
        theta[:, :, -1] = theta_i
        if self.use_gpu:
            theta = theta.cuda()
        return theta

    def forward(self, x):
        assert x.size(2) == 160 and x.size(3) == 64, \
            'Input size does not match, expected (160, 64) but got ({}, {})'.format(x.size(2), x.size(3))
        x = self.conv(x)

        # ============== Block 1 ==============
        # global branch
        x1 = self.inception1(x)
        x1_attn, x1_theta = self.ha1(x1)
        x1_out = x1 * x1_attn
        # local branch
        if self.learn_region:
            x1_local_list = []
            for region_idx in range(4):
                x1_theta_i = x1_theta[:, region_idx, :]
                x1_theta_i = self.transform_theta(x1_theta_i, region_idx)
                x1_trans_i = self.stn(x, x1_theta_i)
                x1_trans_i = F.upsample(
                    x1_trans_i, (24, 28), mode='bilinear', align_corners=True
                )
                x1_local_i = self.local_conv1(x1_trans_i)
                x1_local_list.append(x1_local_i)

        # ============== Block 2 ==============
        # global branch
        x2 = self.inception2(x1_out)
        x2_attn, x2_theta = self.ha2(x2)
        x2_out = x2 * x2_attn
        # local branch
        if self.learn_region:
            x2_local_list = []
            for region_idx in range(4):
                x2_theta_i = x2_theta[:, region_idx, :]
                x2_theta_i = self.transform_theta(x2_theta_i, region_idx)
                x2_trans_i = self.stn(x1_out, x2_theta_i)
                x2_trans_i = F.upsample(
                    x2_trans_i, (12, 14), mode='bilinear', align_corners=True
                )
                x2_local_i = x2_trans_i + x1_local_list[region_idx]
                x2_local_i = self.local_conv2(x2_local_i)
                x2_local_list.append(x2_local_i)

        # ============== Block 3 ==============
        # global branch
        x3 = self.inception3(x2_out)
        x3_attn, x3_theta = self.ha3(x3)
        x3_out = x3 * x3_attn
        # local branch
        if self.learn_region:
            x3_local_list = []
            for region_idx in range(4):
                x3_theta_i = x3_theta[:, region_idx, :]
                x3_theta_i = self.transform_theta(x3_theta_i, region_idx)
                x3_trans_i = self.stn(x2_out, x3_theta_i)
                x3_trans_i = F.upsample(
                    x3_trans_i, (6, 7), mode='bilinear', align_corners=True
                )
                x3_local_i = x3_trans_i + x2_local_list[region_idx]
                x3_local_i = self.local_conv3(x3_local_i)
                x3_local_list.append(x3_local_i)

        # ============== Feature generation ==============
        # global branch
        x_global = F.avg_pool2d(x3_out,
                                x3_out.size()[2:]
                                ).view(x3_out.size(0), x3_out.size(1))
        x_global = self.fc_global(x_global)
        # local branch
        if self.learn_region:
            x_local_list = []
            for region_idx in range(4):
                x_local_i = x3_local_list[region_idx]
                x_local_i = F.avg_pool2d(x_local_i,
                                         x_local_i.size()[2:]
                                         ).view(x_local_i.size(0), -1)
                x_local_list.append(x_local_i)
            x_local = torch.cat(x_local_list, 1)
            x_local = self.fc_local(x_local)

        if not self.training:
            # l2 normalization before concatenation
            if self.learn_region:
                x_global = x_global / x_global.norm(p=2, dim=1, keepdim=True)
                x_local = x_local / x_local.norm(p=2, dim=1, keepdim=True)
                return torch.cat([x_global, x_local], 1)
            else:
                return x_global

        prelogits_global = self.classifier_global(x_global)
        if self.learn_region:
            prelogits_local = self.classifier_local(x_local)

        if self.loss == 'softmax':
            if self.learn_region:
                return (prelogits_global, prelogits_local)
            else:
                return prelogits_global

        elif self.loss == 'triplet':
            if self.learn_region:
                return (prelogits_global, prelogits_local), (x_global, x_local)
            else:
                return prelogits_global, x_global

        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))
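Note (outside the diff): HACNN hard-codes its input resolution via the assert in forward(), and transform_theta() calls .cuda() whenever use_gpu is set, so a CPU-only smoke test needs use_gpu=False. A minimal sketch under those assumptions:

    import torch
    from feeder.trackers.strongsort.deep.models.hacnn import HACNN

    model = HACNN(num_classes=751, use_gpu=False)
    model.eval()
    with torch.no_grad():
        v = model(torch.randn(2, 3, 160, 64))  # (2, 1024): l2-normalized global + local features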
361 feeder/trackers/strongsort/deep/models/inceptionresnetv2.py Normal file
@ -0,0 +1,361 @@
"""
Code imported from https://github.com/Cadene/pretrained-models.pytorch
"""
from __future__ import division, absolute_import
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

__all__ = ['inceptionresnetv2']

pretrained_settings = {
    'inceptionresnetv2': {
        'imagenet': {
            'url':
            'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000
        },
        'imagenet+background': {
            'url':
            'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1001
        }
    }
}


class BasicConv2d(nn.Module):

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False
        ) # verify bias false
        self.bn = nn.BatchNorm2d(
            out_planes,
            eps=0.001, # value found in tensorflow
            momentum=0.1, # default pytorch value
            affine=True
        )
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Mixed_5b(nn.Module):

    def __init__(self):
        super(Mixed_5b, self).__init__()

        self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(192, 48, kernel_size=1, stride=1),
            BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(192, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
        )

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(192, 64, kernel_size=1, stride=1)
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Block35(nn.Module):

    def __init__(self, scale=1.0):
        super(Block35, self).__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(320, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(320, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
            BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
        )

        self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        out = self.relu(out)
        return out


class Mixed_6a(nn.Module):

    def __init__(self):
        super(Mixed_6a, self).__init__()

        self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)

        self.branch1 = nn.Sequential(
            BasicConv2d(320, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2)
        )

        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        return out


class Block17(nn.Module):

    def __init__(self, scale=1.0):
        super(Block17, self).__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(1088, 128, kernel_size=1, stride=1),
            BasicConv2d(
                128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)
            ),
            BasicConv2d(
                160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)
            )
        )

        self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        out = self.relu(out)
        return out


class Mixed_7a(nn.Module):

    def __init__(self):
        super(Mixed_7a, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2)
        )

        self.branch1 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=2)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
            BasicConv2d(288, 320, kernel_size=3, stride=2)
        )

        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Block8(nn.Module):

    def __init__(self, scale=1.0, noReLU=False):
        super(Block8, self).__init__()

        self.scale = scale
        self.noReLU = noReLU

        self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(2080, 192, kernel_size=1, stride=1),
            BasicConv2d(
                192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)
            ),
            BasicConv2d(
                224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)
            )
        )

        self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
        if not self.noReLU:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        if not self.noReLU:
            out = self.relu(out)
        return out


# ----------------
# Model Definition
# ----------------
class InceptionResNetV2(nn.Module):
    """Inception-ResNet-V2.

    Reference:
        Szegedy et al. Inception-v4, Inception-ResNet and the Impact of Residual
        Connections on Learning. AAAI 2017.

    Public keys:
        - ``inceptionresnetv2``: Inception-ResNet-V2.
    """

    def __init__(self, num_classes, loss='softmax', **kwargs):
        super(InceptionResNetV2, self).__init__()
        self.loss = loss

        # Modules
        self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(
            32, 64, kernel_size=3, stride=1, padding=1
        )
        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
        self.maxpool_5a = nn.MaxPool2d(3, stride=2)
        self.mixed_5b = Mixed_5b()
        self.repeat = nn.Sequential(
            Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17),
            Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17),
            Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17),
            Block35(scale=0.17)
        )
        self.mixed_6a = Mixed_6a()
        self.repeat_1 = nn.Sequential(
            Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
            Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
            Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
            Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
            Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
            Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
            Block17(scale=0.10), Block17(scale=0.10)
        )
        self.mixed_7a = Mixed_7a()
        self.repeat_2 = nn.Sequential(
            Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20),
            Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20),
            Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20)
        )

        self.block8 = Block8(noReLU=True)
        self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(1536, num_classes)

    def load_imagenet_weights(self):
        settings = pretrained_settings['inceptionresnetv2']['imagenet']
        pretrain_dict = model_zoo.load_url(settings['url'])
        model_dict = self.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        self.load_state_dict(model_dict)

    def featuremaps(self, x):
        x = self.conv2d_1a(x)
        x = self.conv2d_2a(x)
        x = self.conv2d_2b(x)
        x = self.maxpool_3a(x)
        x = self.conv2d_3b(x)
        x = self.conv2d_4a(x)
        x = self.maxpool_5a(x)
        x = self.mixed_5b(x)
        x = self.repeat(x)
        x = self.mixed_6a(x)
        x = self.repeat_1(x)
        x = self.mixed_7a(x)
        x = self.repeat_2(x)
        x = self.block8(x)
        x = self.conv2d_7b(x)
        return x

    def forward(self, x):
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))


def inceptionresnetv2(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = InceptionResNetV2(num_classes=num_classes, loss=loss, **kwargs)
    if pretrained:
        model.load_imagenet_weights()
    return model
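Note (outside the diff): unlike most models in this directory, pretrained loading here goes through the model's own load_imagenet_weights() rather than a module-level init_pretrained_weights(); mismatched layers such as the new classifier head are simply left at their random initialization. A minimal sketch:

    import torch
    from feeder.trackers.strongsort.deep.models.inceptionresnetv2 import inceptionresnetv2

    model = inceptionresnetv2(num_classes=751, pretrained=False)  # pretrained=True downloads from data.lip6.fr
    model.train()
    logits = model(torch.randn(4, 3, 299, 299))  # 'softmax' loss: logits only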
381 feeder/trackers/strongsort/deep/models/inceptionv4.py Normal file
@ -0,0 +1,381 @@
from __future__ import division, absolute_import
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

__all__ = ['inceptionv4']
"""
Code imported from https://github.com/Cadene/pretrained-models.pytorch
"""

pretrained_settings = {
    'inceptionv4': {
        'imagenet': {
            'url':
            'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000
        },
        'imagenet+background': {
            'url':
            'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1001
        }
    }
}


class BasicConv2d(nn.Module):

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False
        ) # verify bias false
        self.bn = nn.BatchNorm2d(
            out_planes,
            eps=0.001, # value found in tensorflow
            momentum=0.1, # default pytorch value
            affine=True
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Mixed_3a(nn.Module):

    def __init__(self):
        super(Mixed_3a, self).__init__()
        self.maxpool = nn.MaxPool2d(3, stride=2)
        self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)

    def forward(self, x):
        x0 = self.maxpool(x)
        x1 = self.conv(x)
        out = torch.cat((x0, x1), 1)
        return out


class Mixed_4a(nn.Module):

    def __init__(self):
        super(Mixed_4a, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1)
        )

        self.branch1 = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(64, 96, kernel_size=(3, 3), stride=1)
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        return out


class Mixed_5a(nn.Module):

    def __init__(self):
        super(Mixed_5a, self).__init__()
        self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
        self.maxpool = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.conv(x)
        x1 = self.maxpool(x)
        out = torch.cat((x0, x1), 1)
        return out


class Inception_A(nn.Module):

    def __init__(self):
        super(Inception_A, self).__init__()
        self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(384, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(384, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
        )

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(384, 96, kernel_size=1, stride=1)
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Reduction_A(nn.Module):

    def __init__(self):
        super(Reduction_A, self).__init__()
        self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)

        self.branch1 = nn.Sequential(
            BasicConv2d(384, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1),
            BasicConv2d(224, 256, kernel_size=3, stride=2)
        )

        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        return out


class Inception_B(nn.Module):

    def __init__(self):
        super(Inception_B, self).__init__()
        self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(
                192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)
            ),
            BasicConv2d(
                224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0)
            )
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(
                192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)
            ),
            BasicConv2d(
                192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)
            ),
            BasicConv2d(
                224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)
            ),
            BasicConv2d(
                224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)
            )
        )

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(1024, 128, kernel_size=1, stride=1)
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Reduction_B(nn.Module):

    def __init__(self):
        super(Reduction_B, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=3, stride=2)
        )

        self.branch1 = nn.Sequential(
            BasicConv2d(1024, 256, kernel_size=1, stride=1),
            BasicConv2d(
                256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)
            ),
            BasicConv2d(
                256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)
            ), BasicConv2d(320, 320, kernel_size=3, stride=2)
        )

        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        return out


class Inception_C(nn.Module):

    def __init__(self):
        super(Inception_C, self).__init__()

        self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)

        self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
        self.branch1_1a = BasicConv2d(
            384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)
        )
        self.branch1_1b = BasicConv2d(
            384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)
        )

        self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
        self.branch2_1 = BasicConv2d(
            384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0)
        )
        self.branch2_2 = BasicConv2d(
            448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1)
        )
        self.branch2_3a = BasicConv2d(
            512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)
        )
        self.branch2_3b = BasicConv2d(
            512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)
        )

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(1536, 256, kernel_size=1, stride=1)
        )

    def forward(self, x):
        x0 = self.branch0(x)

        x1_0 = self.branch1_0(x)
        x1_1a = self.branch1_1a(x1_0)
        x1_1b = self.branch1_1b(x1_0)
        x1 = torch.cat((x1_1a, x1_1b), 1)

        x2_0 = self.branch2_0(x)
        x2_1 = self.branch2_1(x2_0)
        x2_2 = self.branch2_2(x2_1)
        x2_3a = self.branch2_3a(x2_2)
        x2_3b = self.branch2_3b(x2_2)
        x2 = torch.cat((x2_3a, x2_3b), 1)

        x3 = self.branch3(x)

        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class InceptionV4(nn.Module):
    """Inception-v4.

    Reference:
        Szegedy et al. Inception-v4, Inception-ResNet and the Impact of Residual
        Connections on Learning. AAAI 2017.

    Public keys:
        - ``inceptionv4``: InceptionV4.
    """

    def __init__(self, num_classes, loss, **kwargs):
        super(InceptionV4, self).__init__()
        self.loss = loss

        self.features = nn.Sequential(
            BasicConv2d(3, 32, kernel_size=3, stride=2),
            BasicConv2d(32, 32, kernel_size=3, stride=1),
            BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
            Mixed_3a(),
            Mixed_4a(),
            Mixed_5a(),
            Inception_A(),
            Inception_A(),
            Inception_A(),
            Inception_A(),
            Reduction_A(), # Mixed_6a
            Inception_B(),
            Inception_B(),
            Inception_B(),
            Inception_B(),
            Inception_B(),
            Inception_B(),
            Inception_B(),
            Reduction_B(), # Mixed_7a
            Inception_C(),
            Inception_C(),
            Inception_C()
        )
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(1536, num_classes)

    def forward(self, x):
        f = self.features(x)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def inceptionv4(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = InceptionV4(num_classes, loss, **kwargs)
    if pretrained:
        model_url = pretrained_settings['inceptionv4']['imagenet']['url']
        init_pretrained_weights(model, model_url)
    return model
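Note (outside the diff): the forward() contract is shared across these model files: pooled features in eval mode, logits for loss='softmax', and a (logits, features) pair for loss='triplet'. A sketch of the triplet variant:

    import torch
    from feeder.trackers.strongsort.deep.models.inceptionv4 import inceptionv4

    model = inceptionv4(num_classes=751, loss='triplet', pretrained=False)
    model.train()
    y, v = model(torch.randn(4, 3, 299, 299))  # logits (4, 751) and features (4, 1536)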
269 feeder/trackers/strongsort/deep/models/mlfn.py Normal file
@ -0,0 +1,269 @@
from __future__ import division, absolute_import
import torch
import torch.utils.model_zoo as model_zoo
from torch import nn
from torch.nn import functional as F

__all__ = ['mlfn']

model_urls = {
    # training epoch = 5, top1 = 51.6
    'imagenet':
    'https://mega.nz/#!YHxAhaxC!yu9E6zWl0x5zscSouTdbZu8gdFFytDdl-RAdD2DEfpk',
}


class MLFNBlock(nn.Module):

    def __init__(
        self, in_channels, out_channels, stride, fsm_channels, groups=32
    ):
        super(MLFNBlock, self).__init__()
        self.groups = groups
        mid_channels = out_channels // 2

        # Factor Modules
        self.fm_conv1 = nn.Conv2d(in_channels, mid_channels, 1, bias=False)
        self.fm_bn1 = nn.BatchNorm2d(mid_channels)
        self.fm_conv2 = nn.Conv2d(
            mid_channels,
            mid_channels,
            3,
            stride=stride,
            padding=1,
            bias=False,
            groups=self.groups
        )
        self.fm_bn2 = nn.BatchNorm2d(mid_channels)
        self.fm_conv3 = nn.Conv2d(mid_channels, out_channels, 1, bias=False)
        self.fm_bn3 = nn.BatchNorm2d(out_channels)

        # Factor Selection Module
        self.fsm = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, fsm_channels[0], 1),
            nn.BatchNorm2d(fsm_channels[0]),
            nn.ReLU(inplace=True),
            nn.Conv2d(fsm_channels[0], fsm_channels[1], 1),
            nn.BatchNorm2d(fsm_channels[1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(fsm_channels[1], self.groups, 1),
            nn.BatchNorm2d(self.groups),
            nn.Sigmoid(),
        )

        self.downsample = None
        if in_channels != out_channels or stride > 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, 1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        residual = x
        s = self.fsm(x)

        # reduce dimension
        x = self.fm_conv1(x)
        x = self.fm_bn1(x)
        x = F.relu(x, inplace=True)

        # group convolution
        x = self.fm_conv2(x)
        x = self.fm_bn2(x)
        x = F.relu(x, inplace=True)

        # factor selection
        b, c = x.size(0), x.size(1)
        n = c // self.groups
        ss = s.repeat(1, n, 1, 1) # from (b, g, 1, 1) to (b, g*n=c, 1, 1)
        ss = ss.view(b, n, self.groups, 1, 1)
        ss = ss.permute(0, 2, 1, 3, 4).contiguous()
        ss = ss.view(b, c, 1, 1)
        x = ss * x

        # recover dimension
        x = self.fm_conv3(x)
        x = self.fm_bn3(x)
        x = F.relu(x, inplace=True)

        if self.downsample is not None:
            residual = self.downsample(residual)

        return F.relu(residual + x, inplace=True), s


class MLFN(nn.Module):
    """Multi-Level Factorisation Net.

    Reference:
        Chang et al. Multi-Level Factorisation Net for
        Person Re-Identification. CVPR 2018.

    Public keys:
        - ``mlfn``: MLFN (Multi-Level Factorisation Net).
    """

    def __init__(
        self,
        num_classes,
        loss='softmax',
        groups=32,
        channels=[64, 256, 512, 1024, 2048],
        embed_dim=1024,
        **kwargs
    ):
        super(MLFN, self).__init__()
        self.loss = loss
        self.groups = groups

        # first convolutional layer
        self.conv1 = nn.Conv2d(3, channels[0], 7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(channels[0])
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

        # main body
        self.feature = nn.ModuleList(
            [
                # layer 1-3
                MLFNBlock(channels[0], channels[1], 1, [128, 64], self.groups),
                MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups),
                MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups),
                # layer 4-7
                MLFNBlock(
                    channels[1], channels[2], 2, [256, 128], self.groups
                ),
                MLFNBlock(
                    channels[2], channels[2], 1, [256, 128], self.groups
                ),
                MLFNBlock(
                    channels[2], channels[2], 1, [256, 128], self.groups
                ),
                MLFNBlock(
                    channels[2], channels[2], 1, [256, 128], self.groups
                ),
                # layer 8-13
                MLFNBlock(
                    channels[2], channels[3], 2, [512, 128], self.groups
                ),
                MLFNBlock(
                    channels[3], channels[3], 1, [512, 128], self.groups
                ),
                MLFNBlock(
                    channels[3], channels[3], 1, [512, 128], self.groups
                ),
                MLFNBlock(
                    channels[3], channels[3], 1, [512, 128], self.groups
                ),
                MLFNBlock(
                    channels[3], channels[3], 1, [512, 128], self.groups
                ),
                MLFNBlock(
                    channels[3], channels[3], 1, [512, 128], self.groups
                ),
                # layer 14-16
                MLFNBlock(
                    channels[3], channels[4], 2, [512, 128], self.groups
                ),
                MLFNBlock(
                    channels[4], channels[4], 1, [512, 128], self.groups
                ),
                MLFNBlock(
                    channels[4], channels[4], 1, [512, 128], self.groups
                ),
            ]
        )
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)

        # projection functions
        self.fc_x = nn.Sequential(
            nn.Conv2d(channels[4], embed_dim, 1, bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
        )
        self.fc_s = nn.Sequential(
            nn.Conv2d(self.groups * 16, embed_dim, 1, bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
        )

        self.classifier = nn.Linear(embed_dim, num_classes)

        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu'
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)
        x = self.maxpool(x)

        s_hat = []
        for block in self.feature:
            x, s = block(x)
            s_hat.append(s)
        s_hat = torch.cat(s_hat, 1)

        x = self.global_avgpool(x)
        x = self.fc_x(x)
        s_hat = self.fc_s(s_hat)

        v = (x+s_hat) * 0.5
        v = v.view(v.size(0), -1)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def mlfn(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = MLFN(num_classes, loss, **kwargs)
    if pretrained:
        # init_pretrained_weights(model, model_urls['imagenet'])
        import warnings
        warnings.warn(
            'The imagenet pretrained weights need to be manually downloaded from {}'
            .format(model_urls['imagenet'])
        )
    return model
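Note (outside the diff): mlfn() only warns that the ImageNet weights must be fetched manually from the mega.nz URL above. A sketch of loading them afterwards, assuming the checkpoint is a plain state dict (the local filename is hypothetical):

    import torch
    from feeder.trackers.strongsort.deep.models.mlfn import mlfn

    model = mlfn(num_classes=751, pretrained=False)
    state = torch.load('mlfn_imagenet.pth', map_location='cpu')  # hypothetical local path
    model.load_state_dict(state, strict=False)  # strict=False: the classifier head differs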
274 feeder/trackers/strongsort/deep/models/mobilenetv2.py Normal file
@ -0,0 +1,274 @@
from __future__ import division, absolute_import
import torch.utils.model_zoo as model_zoo
from torch import nn
from torch.nn import functional as F

__all__ = ['mobilenetv2_x1_0', 'mobilenetv2_x1_4']

model_urls = {
    # 1.0: top-1 71.3
    'mobilenetv2_x1_0':
    'https://mega.nz/#!NKp2wAIA!1NH1pbNzY_M2hVk_hdsxNM1NUOWvvGPHhaNr-fASF6c',
    # 1.4: top-1 73.9
    'mobilenetv2_x1_4':
    'https://mega.nz/#!RGhgEIwS!xN2s2ZdyqI6vQ3EwgmRXLEW3khr9tpXg96G9SUJugGk',
}


class ConvBlock(nn.Module):
    """Basic convolutional block.

    convolution (bias discarded) + batch normalization + relu6.

    Args:
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): kernel size.
        s (int or tuple): stride.
        p (int or tuple): padding.
        g (int): number of blocked connections from input channels
            to output channels (default: 1).
    """

    def __init__(self, in_c, out_c, k, s=1, p=0, g=1):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(
            in_c, out_c, k, stride=s, padding=p, bias=False, groups=g
        )
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        return F.relu6(self.bn(self.conv(x)))


class Bottleneck(nn.Module):

    def __init__(self, in_channels, out_channels, expansion_factor, stride=1):
        super(Bottleneck, self).__init__()
        mid_channels = in_channels * expansion_factor
        self.use_residual = stride == 1 and in_channels == out_channels
        self.conv1 = ConvBlock(in_channels, mid_channels, 1)
        self.dwconv2 = ConvBlock(
            mid_channels, mid_channels, 3, stride, 1, g=mid_channels
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        m = self.conv1(x)
        m = self.dwconv2(m)
        m = self.conv3(m)
        if self.use_residual:
            return x + m
        else:
            return m


class MobileNetV2(nn.Module):
    """MobileNetV2.

    Reference:
        Sandler et al. MobileNetV2: Inverted Residuals and
        Linear Bottlenecks. CVPR 2018.

    Public keys:
        - ``mobilenetv2_x1_0``: MobileNetV2 x1.0.
        - ``mobilenetv2_x1_4``: MobileNetV2 x1.4.
    """

    def __init__(
        self,
        num_classes,
        width_mult=1,
        loss='softmax',
        fc_dims=None,
        dropout_p=None,
        **kwargs
    ):
        super(MobileNetV2, self).__init__()
        self.loss = loss
        self.in_channels = int(32 * width_mult)
        self.feature_dim = int(1280 * width_mult) if width_mult > 1 else 1280

        # construct layers
        self.conv1 = ConvBlock(3, self.in_channels, 3, s=2, p=1)
        self.conv2 = self._make_layer(
            Bottleneck, 1, int(16 * width_mult), 1, 1
        )
        self.conv3 = self._make_layer(
            Bottleneck, 6, int(24 * width_mult), 2, 2
        )
        self.conv4 = self._make_layer(
            Bottleneck, 6, int(32 * width_mult), 3, 2
        )
        self.conv5 = self._make_layer(
            Bottleneck, 6, int(64 * width_mult), 4, 2
        )
        self.conv6 = self._make_layer(
            Bottleneck, 6, int(96 * width_mult), 3, 1
        )
        self.conv7 = self._make_layer(
            Bottleneck, 6, int(160 * width_mult), 3, 2
        )
        self.conv8 = self._make_layer(
            Bottleneck, 6, int(320 * width_mult), 1, 1
        )
        self.conv9 = ConvBlock(self.in_channels, self.feature_dim, 1)

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = self._construct_fc_layer(
            fc_dims, self.feature_dim, dropout_p
        )
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        self._init_params()

    def _make_layer(self, block, t, c, n, s):
        # t: expansion factor
        # c: output channels
        # n: number of blocks
        # s: stride for first layer
        layers = []
        layers.append(block(self.in_channels, c, t, s))
        self.in_channels = c
        for i in range(1, n):
            layers.append(block(self.in_channels, c, t))
        return nn.Sequential(*layers)

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """Constructs fully connected layer.

        Args:
            fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
            input_dim (int): input dimension
            dropout_p (float): dropout probability, if None, dropout is unused
        """
        if fc_dims is None:
            self.feature_dim = input_dim
            return None

        assert isinstance(
            fc_dims, (list, tuple)
        ), 'fc_dims must be either list or tuple, but got {}'.format(
            type(fc_dims)
        )

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def _init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu'
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.conv8(x)
        x = self.conv9(x)
        return x

    def forward(self, x):
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if self.fc is not None:
            v = self.fc(v)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def mobilenetv2_x1_0(num_classes, loss, pretrained=True, **kwargs):
    model = MobileNetV2(
        num_classes,
        loss=loss,
        width_mult=1,
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        # init_pretrained_weights(model, model_urls['mobilenetv2_x1_0'])
        import warnings
        warnings.warn(
            'The imagenet pretrained weights need to be manually downloaded from {}'
            .format(model_urls['mobilenetv2_x1_0'])
        )
    return model


def mobilenetv2_x1_4(num_classes, loss, pretrained=True, **kwargs):
    model = MobileNetV2(
        num_classes,
        loss=loss,
        width_mult=1.4,
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        # init_pretrained_weights(model, model_urls['mobilenetv2_x1_4'])
        import warnings
        warnings.warn(
            'The imagenet pretrained weights need to be manually downloaded from {}'
            .format(model_urls['mobilenetv2_x1_4'])
        )
    return model
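Note (outside the diff): width_mult scales every stage's channel count through the int(c * width_mult) calls above, including the final embedding when width_mult > 1. A quick sketch of the effect:

    from feeder.trackers.strongsort.deep.models.mobilenetv2 import mobilenetv2_x1_0, mobilenetv2_x1_4

    m10 = mobilenetv2_x1_0(num_classes=751, loss='softmax', pretrained=False)
    m14 = mobilenetv2_x1_4(num_classes=751, loss='softmax', pretrained=False)
    print(m10.feature_dim, m14.feature_dim)  # 1280 vs 1792 = int(1280 * 1.4)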
206
feeder/trackers/strongsort/deep/models/mudeep.py
Normal file
206
feeder/trackers/strongsort/deep/models/mudeep.py
Normal file
|
|
@ -0,0 +1,206 @@
from __future__ import division, absolute_import
import torch
from torch import nn
from torch.nn import functional as F

__all__ = ['MuDeep']


class ConvBlock(nn.Module):
    """Basic convolutional block.

    convolution + batch normalization + relu.

    Args:
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): kernel size.
        s (int or tuple): stride.
        p (int or tuple): padding.
    """

    def __init__(self, in_c, out_c, k, s, p):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))


class ConvLayers(nn.Module):
    """Preprocessing layers."""

    def __init__(self):
        super(ConvLayers, self).__init__()
        self.conv1 = ConvBlock(3, 48, k=3, s=1, p=1)
        self.conv2 = ConvBlock(48, 96, k=3, s=1, p=1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.maxpool(x)
        return x


class MultiScaleA(nn.Module):
    """Multi-scale stream layer A (Sec.3.1)"""

    def __init__(self):
        super(MultiScaleA, self).__init__()
        self.stream1 = nn.Sequential(
            ConvBlock(96, 96, k=1, s=1, p=0),
            ConvBlock(96, 24, k=3, s=1, p=1),
        )
        self.stream2 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(96, 24, k=1, s=1, p=0),
        )
        self.stream3 = ConvBlock(96, 24, k=1, s=1, p=0)
        self.stream4 = nn.Sequential(
            ConvBlock(96, 16, k=1, s=1, p=0),
            ConvBlock(16, 24, k=3, s=1, p=1),
            ConvBlock(24, 24, k=3, s=1, p=1),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        s4 = self.stream4(x)
        y = torch.cat([s1, s2, s3, s4], dim=1)
        return y


class Reduction(nn.Module):
    """Reduction layer (Sec.3.1)"""

    def __init__(self):
        super(Reduction, self).__init__()
        self.stream1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.stream2 = ConvBlock(96, 96, k=3, s=2, p=1)
        self.stream3 = nn.Sequential(
            ConvBlock(96, 48, k=1, s=1, p=0),
            ConvBlock(48, 56, k=3, s=1, p=1),
            ConvBlock(56, 64, k=3, s=2, p=1),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        y = torch.cat([s1, s2, s3], dim=1)
        return y


class MultiScaleB(nn.Module):
    """Multi-scale stream layer B (Sec.3.1)"""

    def __init__(self):
        super(MultiScaleB, self).__init__()
        self.stream1 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(256, 256, k=1, s=1, p=0),
        )
        self.stream2 = nn.Sequential(
            ConvBlock(256, 64, k=1, s=1, p=0),
            ConvBlock(64, 128, k=(1, 3), s=1, p=(0, 1)),
            ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)),
        )
        self.stream3 = ConvBlock(256, 256, k=1, s=1, p=0)
        self.stream4 = nn.Sequential(
            ConvBlock(256, 64, k=1, s=1, p=0),
            ConvBlock(64, 64, k=(1, 3), s=1, p=(0, 1)),
            ConvBlock(64, 128, k=(3, 1), s=1, p=(1, 0)),
            ConvBlock(128, 128, k=(1, 3), s=1, p=(0, 1)),
            ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        s4 = self.stream4(x)
        return s1, s2, s3, s4


class Fusion(nn.Module):
    """Saliency-based learning fusion layer (Sec.3.2)"""

    def __init__(self):
        super(Fusion, self).__init__()
        self.a1 = nn.Parameter(torch.rand(1, 256, 1, 1))
        self.a2 = nn.Parameter(torch.rand(1, 256, 1, 1))
        self.a3 = nn.Parameter(torch.rand(1, 256, 1, 1))
        self.a4 = nn.Parameter(torch.rand(1, 256, 1, 1))

        # We add an average pooling layer to reduce the spatial dimension
        # of feature maps, which differs from the original paper.
        self.avgpool = nn.AvgPool2d(kernel_size=4, stride=4, padding=0)

    def forward(self, x1, x2, x3, x4):
        s1 = self.a1.expand_as(x1) * x1
        s2 = self.a2.expand_as(x2) * x2
        s3 = self.a3.expand_as(x3) * x3
        s4 = self.a4.expand_as(x4) * x4
        y = self.avgpool(s1 + s2 + s3 + s4)
        return y


class MuDeep(nn.Module):
    """Multiscale deep neural network.

    Reference:
        Qian et al. Multi-scale Deep Learning Architectures
        for Person Re-identification. ICCV 2017.

    Public keys:
        - ``mudeep``: Multiscale deep neural network.
    """

    def __init__(self, num_classes, loss='softmax', **kwargs):
        super(MuDeep, self).__init__()
        self.loss = loss

        self.block1 = ConvLayers()
        self.block2 = MultiScaleA()
        self.block3 = Reduction()
        self.block4 = MultiScaleB()
        self.block5 = Fusion()

        # Due to this fully connected layer, input image has to be fixed
        # in shape, i.e. (3, 256, 128), such that the last convolutional feature
        # maps are of shape (256, 16, 8). If input shape is changed,
        # the input dimension of this layer has to be changed accordingly.
        self.fc = nn.Sequential(
            nn.Linear(256 * 16 * 8, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(4096, num_classes)
        self.feat_dim = 4096

    def featuremaps(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(*x)
        return x

    def forward(self, x):
        x = self.featuremaps(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        y = self.classifier(x)

        if not self.training:
            return x

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, x
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))
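As the comment above the fc layer notes, MuDeep only accepts (3, 256, 128) inputs, because the flattened feature feeding nn.Linear(256 * 16 * 8, 4096) is shape-locked. A quick shape check, as a sketch under that assumption:

# Hedged sketch: verify MuDeep's fixed input contract (3, 256, 128).
# In eval mode forward returns the 4096-dim fc feature, per the code above.
import torch

net = MuDeep(num_classes=751).eval()
with torch.no_grad():
    feats = net(torch.randn(2, 3, 256, 128))
print(feats.shape)  # expected: torch.Size([2, 4096])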
1131 feeder/trackers/strongsort/deep/models/nasnet.py Normal file
File diff suppressed because it is too large
598 feeder/trackers/strongsort/deep/models/osnet.py Normal file
@@ -0,0 +1,598 @@
from __future__ import division, absolute_import
import warnings
import torch
from torch import nn
from torch.nn import functional as F

__all__ = [
    'osnet_x1_0', 'osnet_x0_75', 'osnet_x0_5', 'osnet_x0_25', 'osnet_ibn_x1_0'
]

pretrained_urls = {
    'osnet_x1_0':
    'https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY',
    'osnet_x0_75':
    'https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq',
    'osnet_x0_5':
    'https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i',
    'osnet_x0_25':
    'https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs',
    'osnet_ibn_x1_0':
    'https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l'
}


##########
# Basic layers
##########
class ConvLayer(nn.Module):
    """Convolution layer (conv + bn + relu)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        groups=1,
        IN=False
    ):
        super(ConvLayer, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            groups=groups
        )
        if IN:
            self.bn = nn.InstanceNorm2d(out_channels, affine=True)
        else:
            self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Conv1x1(nn.Module):
    """1x1 convolution + bn + relu."""

    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super(Conv1x1, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            1,
            stride=stride,
            padding=0,
            bias=False,
            groups=groups
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Conv1x1Linear(nn.Module):
    """1x1 convolution + bn (w/o non-linearity)."""

    def __init__(self, in_channels, out_channels, stride=1):
        super(Conv1x1Linear, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, 1, stride=stride, padding=0, bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class Conv3x3(nn.Module):
    """3x3 convolution + bn + relu."""

    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super(Conv3x3, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            3,
            stride=stride,
            padding=1,
            bias=False,
            groups=groups
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class LightConv3x3(nn.Module):
    """Lightweight 3x3 convolution.

    1x1 (linear) + dw 3x3 (nonlinear).
    """

    def __init__(self, in_channels, out_channels):
        super(LightConv3x3, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, 1, stride=1, padding=0, bias=False
        )
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            3,
            stride=1,
            padding=1,
            bias=False,
            groups=out_channels
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


##########
# Building blocks for omni-scale feature learning
##########
class ChannelGate(nn.Module):
    """A mini-network that generates channel-wise gates conditioned on input tensor."""

    def __init__(
        self,
        in_channels,
        num_gates=None,
        return_gates=False,
        gate_activation='sigmoid',
        reduction=16,
        layer_norm=False
    ):
        super(ChannelGate, self).__init__()
        if num_gates is None:
            num_gates = in_channels
        self.return_gates = return_gates
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(
            in_channels,
            in_channels // reduction,
            kernel_size=1,
            bias=True,
            padding=0
        )
        self.norm1 = None
        if layer_norm:
            self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(
            in_channels // reduction,
            num_gates,
            kernel_size=1,
            bias=True,
            padding=0
        )
        if gate_activation == 'sigmoid':
            self.gate_activation = nn.Sigmoid()
        elif gate_activation == 'relu':
            self.gate_activation = nn.ReLU(inplace=True)
        elif gate_activation == 'linear':
            self.gate_activation = None
        else:
            raise RuntimeError(
                "Unknown gate activation: {}".format(gate_activation)
            )

    def forward(self, x):
        input = x
        x = self.global_avgpool(x)
        x = self.fc1(x)
        if self.norm1 is not None:
            x = self.norm1(x)
        x = self.relu(x)
        x = self.fc2(x)
        if self.gate_activation is not None:
            x = self.gate_activation(x)
        if self.return_gates:
            return x
        return input * x


class OSBlock(nn.Module):
    """Omni-scale feature learning block."""

    def __init__(
        self,
        in_channels,
        out_channels,
        IN=False,
        bottleneck_reduction=4,
        **kwargs
    ):
        super(OSBlock, self).__init__()
        mid_channels = out_channels // bottleneck_reduction
        self.conv1 = Conv1x1(in_channels, mid_channels)
        self.conv2a = LightConv3x3(mid_channels, mid_channels)
        self.conv2b = nn.Sequential(
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
        )
        self.conv2c = nn.Sequential(
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
        )
        self.conv2d = nn.Sequential(
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
        )
        self.gate = ChannelGate(mid_channels)
        self.conv3 = Conv1x1Linear(mid_channels, out_channels)
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = Conv1x1Linear(in_channels, out_channels)
        self.IN = None
        if IN:
            self.IN = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x):
        identity = x
        x1 = self.conv1(x)
        x2a = self.conv2a(x1)
        x2b = self.conv2b(x1)
        x2c = self.conv2c(x1)
        x2d = self.conv2d(x1)
        x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
        x3 = self.conv3(x2)
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = x3 + identity
        if self.IN is not None:
            out = self.IN(out)
        return F.relu(out)


##########
# Network architecture
##########
class OSNet(nn.Module):
    """Omni-Scale Network.

    Reference:
        - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
        - Zhou et al. Learning Generalisable Omni-Scale Representations
          for Person Re-Identification. TPAMI, 2021.
    """

    def __init__(
        self,
        num_classes,
        blocks,
        layers,
        channels,
        feature_dim=512,
        loss='softmax',
        IN=False,
        **kwargs
    ):
        super(OSNet, self).__init__()
        num_blocks = len(blocks)
        assert num_blocks == len(layers)
        assert num_blocks == len(channels) - 1
        self.loss = loss
        self.feature_dim = feature_dim

        # convolutional backbone
        self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.conv2 = self._make_layer(
            blocks[0],
            layers[0],
            channels[0],
            channels[1],
            reduce_spatial_size=True,
            IN=IN
        )
        self.conv3 = self._make_layer(
            blocks[1],
            layers[1],
            channels[1],
            channels[2],
            reduce_spatial_size=True
        )
        self.conv4 = self._make_layer(
            blocks[2],
            layers[2],
            channels[2],
            channels[3],
            reduce_spatial_size=False
        )
        self.conv5 = Conv1x1(channels[3], channels[3])
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        # fully connected layer
        self.fc = self._construct_fc_layer(
            self.feature_dim, channels[3], dropout_p=None
        )
        # identity classification layer
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        self._init_params()

    def _make_layer(
        self,
        block,
        layer,
        in_channels,
        out_channels,
        reduce_spatial_size,
        IN=False
    ):
        layers = []

        layers.append(block(in_channels, out_channels, IN=IN))
        for i in range(1, layer):
            layers.append(block(out_channels, out_channels, IN=IN))

        if reduce_spatial_size:
            layers.append(
                nn.Sequential(
                    Conv1x1(out_channels, out_channels),
                    nn.AvgPool2d(2, stride=2)
                )
            )

        return nn.Sequential(*layers)

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        if fc_dims is None or fc_dims < 0:
            self.feature_dim = input_dim
            return None

        if isinstance(fc_dims, int):
            fc_dims = [fc_dims]

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def _init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu'
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x

    def forward(self, x, return_featuremaps=False):
        x = self.featuremaps(x)
        if return_featuremaps:
            return x
        v = self.global_avgpool(x)
        v = v.view(v.size(0), -1)
        if self.fc is not None:
            v = self.fc(v)
        if not self.training:
            return v
        y = self.classifier(v)
        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


def init_pretrained_weights(model, key=''):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    import os
    import errno
    import gdown
    from collections import OrderedDict

    def _get_torch_home():
        ENV_TORCH_HOME = 'TORCH_HOME'
        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
        DEFAULT_CACHE_DIR = '~/.cache'
        torch_home = os.path.expanduser(
            os.getenv(
                ENV_TORCH_HOME,
                os.path.join(
                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
                )
            )
        )
        return torch_home

    torch_home = _get_torch_home()
    model_dir = os.path.join(torch_home, 'checkpoints')
    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise
    filename = key + '_imagenet.pth'
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        gdown.download(pretrained_urls[key], cached_file, quiet=False)

    state_dict = torch.load(cached_file)
    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        if k.startswith('module.'):
            k = k[7:]  # discard module.

        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if len(matched_layers) == 0:
        warnings.warn(
            'The pretrained weights from "{}" cannot be loaded, '
            'please check the key names manually '
            '(** ignored and continue **)'.format(cached_file)
        )
    else:
        print(
            'Successfully loaded imagenet pretrained weights from "{}"'.
            format(cached_file)
        )
        if len(discarded_layers) > 0:
            print(
                '** The following layers are discarded '
                'due to unmatched keys or layer size: {}'.
                format(discarded_layers)
            )


##########
# Instantiation
##########
def osnet_x1_0(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
    # standard size (width x1.0)
    model = OSNet(
        num_classes,
        blocks=[OSBlock, OSBlock, OSBlock],
        layers=[2, 2, 2],
        channels=[64, 256, 384, 512],
        loss=loss,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, key='osnet_x1_0')
    return model


def osnet_x0_75(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
    # medium size (width x0.75)
    model = OSNet(
        num_classes,
        blocks=[OSBlock, OSBlock, OSBlock],
        layers=[2, 2, 2],
        channels=[48, 192, 288, 384],
        loss=loss,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, key='osnet_x0_75')
    return model


def osnet_x0_5(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
    # tiny size (width x0.5)
    model = OSNet(
        num_classes,
        blocks=[OSBlock, OSBlock, OSBlock],
        layers=[2, 2, 2],
        channels=[32, 128, 192, 256],
        loss=loss,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, key='osnet_x0_5')
    return model


def osnet_x0_25(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
    # very tiny size (width x0.25)
    model = OSNet(
        num_classes,
        blocks=[OSBlock, OSBlock, OSBlock],
        layers=[2, 2, 2],
        channels=[16, 64, 96, 128],
        loss=loss,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, key='osnet_x0_25')
    return model


def osnet_ibn_x1_0(
    num_classes=1000, pretrained=True, loss='softmax', **kwargs
):
    # standard size (width x1.0) + IBN layer
    # Ref: Pan et al. Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net. ECCV, 2018.
    model = OSNet(
        num_classes,
        blocks=[OSBlock, OSBlock, OSBlock],
        layers=[2, 2, 2],
        channels=[64, 256, 384, 512],
        loss=loss,
        IN=True,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, key='osnet_ibn_x1_0')
    return model
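In eval mode OSNet's forward skips the classifier and returns the feature_dim embedding (512 for the x1.0 width), which is what an appearance-based tracker like StrongSORT consumes. A minimal usage sketch, with pretrained=False so no Google Drive fetch via gdown is triggered:

# Hedged sketch: extract appearance embeddings with osnet_x1_0.
# pretrained=False avoids the gdown download in init_pretrained_weights.
import torch

net = osnet_x1_0(num_classes=1000, pretrained=False, loss='softmax').eval()
crops = torch.randn(4, 3, 256, 128)  # four person crops, ReID-style sizing
with torch.no_grad():
    emb = net(crops)
print(emb.shape)  # torch.Size([4, 512]) -- self.feature_dim after the fc head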
609 feeder/trackers/strongsort/deep/models/osnet_ain.py Normal file
@@ -0,0 +1,609 @@
from __future__ import division, absolute_import
import warnings
import torch
from torch import nn
from torch.nn import functional as F

__all__ = [
    'osnet_ain_x1_0', 'osnet_ain_x0_75', 'osnet_ain_x0_5', 'osnet_ain_x0_25'
]

pretrained_urls = {
    'osnet_ain_x1_0':
    'https://drive.google.com/uc?id=1-CaioD9NaqbHK_kzSMW8VE4_3KcsRjEo',
    'osnet_ain_x0_75':
    'https://drive.google.com/uc?id=1apy0hpsMypqstfencdH-jKIUEFOW4xoM',
    'osnet_ain_x0_5':
    'https://drive.google.com/uc?id=1KusKvEYyKGDTUBVRxRiz55G31wkihB6l',
    'osnet_ain_x0_25':
    'https://drive.google.com/uc?id=1SxQt2AvmEcgWNhaRb2xC4rP6ZwVDP0Wt'
}


##########
# Basic layers
##########
class ConvLayer(nn.Module):
    """Convolution layer (conv + bn + relu)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        groups=1,
        IN=False
    ):
        super(ConvLayer, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            groups=groups
        )
        if IN:
            self.bn = nn.InstanceNorm2d(out_channels, affine=True)
        else:
            self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)


class Conv1x1(nn.Module):
    """1x1 convolution + bn + relu."""

    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super(Conv1x1, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            1,
            stride=stride,
            padding=0,
            bias=False,
            groups=groups
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)


class Conv1x1Linear(nn.Module):
    """1x1 convolution + bn (w/o non-linearity)."""

    def __init__(self, in_channels, out_channels, stride=1, bn=True):
        super(Conv1x1Linear, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, 1, stride=stride, padding=0, bias=False
        )
        self.bn = None
        if bn:
            self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        return x


class Conv3x3(nn.Module):
    """3x3 convolution + bn + relu."""

    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super(Conv3x3, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            3,
            stride=stride,
            padding=1,
            bias=False,
            groups=groups
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)


class LightConv3x3(nn.Module):
    """Lightweight 3x3 convolution.

    1x1 (linear) + dw 3x3 (nonlinear).
    """

    def __init__(self, in_channels, out_channels):
        super(LightConv3x3, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, 1, stride=1, padding=0, bias=False
        )
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            3,
            stride=1,
            padding=1,
            bias=False,
            groups=out_channels
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn(x)
        return self.relu(x)


class LightConvStream(nn.Module):
    """Lightweight convolution stream."""

    def __init__(self, in_channels, out_channels, depth):
        super(LightConvStream, self).__init__()
        assert depth >= 1, 'depth must be equal to or larger than 1, but got {}'.format(
            depth
        )
        layers = []
        layers += [LightConv3x3(in_channels, out_channels)]
        for i in range(depth - 1):
            layers += [LightConv3x3(out_channels, out_channels)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


##########
# Building blocks for omni-scale feature learning
##########
class ChannelGate(nn.Module):
    """A mini-network that generates channel-wise gates conditioned on input tensor."""

    def __init__(
        self,
        in_channels,
        num_gates=None,
        return_gates=False,
        gate_activation='sigmoid',
        reduction=16,
        layer_norm=False
    ):
        super(ChannelGate, self).__init__()
        if num_gates is None:
            num_gates = in_channels
        self.return_gates = return_gates
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(
            in_channels,
            in_channels // reduction,
            kernel_size=1,
            bias=True,
            padding=0
        )
        self.norm1 = None
        if layer_norm:
            self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv2d(
            in_channels // reduction,
            num_gates,
            kernel_size=1,
            bias=True,
            padding=0
        )
        if gate_activation == 'sigmoid':
            self.gate_activation = nn.Sigmoid()
        elif gate_activation == 'relu':
            self.gate_activation = nn.ReLU()
        elif gate_activation == 'linear':
            self.gate_activation = None
        else:
            raise RuntimeError(
                "Unknown gate activation: {}".format(gate_activation)
            )

    def forward(self, x):
        input = x
        x = self.global_avgpool(x)
        x = self.fc1(x)
        if self.norm1 is not None:
            x = self.norm1(x)
        x = self.relu(x)
        x = self.fc2(x)
        if self.gate_activation is not None:
            x = self.gate_activation(x)
        if self.return_gates:
            return x
        return input * x


class OSBlock(nn.Module):
    """Omni-scale feature learning block."""

    def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
        super(OSBlock, self).__init__()
        assert T >= 1
        assert out_channels >= reduction and out_channels % reduction == 0
        mid_channels = out_channels // reduction

        self.conv1 = Conv1x1(in_channels, mid_channels)
        self.conv2 = nn.ModuleList()
        for t in range(1, T + 1):
            self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
        self.gate = ChannelGate(mid_channels)
        self.conv3 = Conv1x1Linear(mid_channels, out_channels)
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = Conv1x1Linear(in_channels, out_channels)

    def forward(self, x):
        identity = x
        x1 = self.conv1(x)
        x2 = 0
        for conv2_t in self.conv2:
            x2_t = conv2_t(x1)
            x2 = x2 + self.gate(x2_t)
        x3 = self.conv3(x2)
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = x3 + identity
        return F.relu(out)


class OSBlockINin(nn.Module):
    """Omni-scale feature learning block with instance normalization."""

    def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
        super(OSBlockINin, self).__init__()
        assert T >= 1
        assert out_channels >= reduction and out_channels % reduction == 0
        mid_channels = out_channels // reduction

        self.conv1 = Conv1x1(in_channels, mid_channels)
        self.conv2 = nn.ModuleList()
        for t in range(1, T + 1):
            self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
        self.gate = ChannelGate(mid_channels)
        self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False)
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = Conv1x1Linear(in_channels, out_channels)
        self.IN = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x):
        identity = x
        x1 = self.conv1(x)
        x2 = 0
        for conv2_t in self.conv2:
            x2_t = conv2_t(x1)
            x2 = x2 + self.gate(x2_t)
        x3 = self.conv3(x2)
        x3 = self.IN(x3)  # IN inside residual
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = x3 + identity
        return F.relu(out)


##########
# Network architecture
##########
class OSNet(nn.Module):
    """Omni-Scale Network.

    Reference:
        - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
        - Zhou et al. Learning Generalisable Omni-Scale Representations
          for Person Re-Identification. TPAMI, 2021.
    """

    def __init__(
        self,
        num_classes,
        blocks,
        layers,
        channels,
        feature_dim=512,
        loss='softmax',
        conv1_IN=False,
        **kwargs
    ):
        super(OSNet, self).__init__()
        num_blocks = len(blocks)
        assert num_blocks == len(layers)
        assert num_blocks == len(channels) - 1
        self.loss = loss
        self.feature_dim = feature_dim

        # convolutional backbone
        self.conv1 = ConvLayer(
            3, channels[0], 7, stride=2, padding=3, IN=conv1_IN
        )
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.conv2 = self._make_layer(
            blocks[0], layers[0], channels[0], channels[1]
        )
        self.pool2 = nn.Sequential(
            Conv1x1(channels[1], channels[1]), nn.AvgPool2d(2, stride=2)
        )
        self.conv3 = self._make_layer(
            blocks[1], layers[1], channels[1], channels[2]
        )
        self.pool3 = nn.Sequential(
            Conv1x1(channels[2], channels[2]), nn.AvgPool2d(2, stride=2)
        )
        self.conv4 = self._make_layer(
            blocks[2], layers[2], channels[2], channels[3]
        )
        self.conv5 = Conv1x1(channels[3], channels[3])
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        # fully connected layer
        self.fc = self._construct_fc_layer(
            self.feature_dim, channels[3], dropout_p=None
        )
        # identity classification layer
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        self._init_params()

    def _make_layer(self, blocks, layer, in_channels, out_channels):
        layers = []
        layers += [blocks[0](in_channels, out_channels)]
        for i in range(1, len(blocks)):
            layers += [blocks[i](out_channels, out_channels)]
        return nn.Sequential(*layers)

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        if fc_dims is None or fc_dims < 0:
            self.feature_dim = input_dim
            return None

        if isinstance(fc_dims, int):
            fc_dims = [fc_dims]

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU())
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def _init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu'
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.InstanceNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.pool3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x

    def forward(self, x, return_featuremaps=False):
        x = self.featuremaps(x)
        if return_featuremaps:
            return x
        v = self.global_avgpool(x)
        v = v.view(v.size(0), -1)
        if self.fc is not None:
            v = self.fc(v)
        if not self.training:
            return v
        y = self.classifier(v)
        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


def init_pretrained_weights(model, key=''):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    import os
    import errno
    import gdown
    from collections import OrderedDict

    def _get_torch_home():
        ENV_TORCH_HOME = 'TORCH_HOME'
        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
        DEFAULT_CACHE_DIR = '~/.cache'
        torch_home = os.path.expanduser(
            os.getenv(
                ENV_TORCH_HOME,
                os.path.join(
                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
                )
            )
        )
        return torch_home

    torch_home = _get_torch_home()
    model_dir = os.path.join(torch_home, 'checkpoints')
    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise
    filename = key + '_imagenet.pth'
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        gdown.download(pretrained_urls[key], cached_file, quiet=False)

    state_dict = torch.load(cached_file)
    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        if k.startswith('module.'):
            k = k[7:]  # discard module.

        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if len(matched_layers) == 0:
        warnings.warn(
            'The pretrained weights from "{}" cannot be loaded, '
            'please check the key names manually '
            '(** ignored and continue **)'.format(cached_file)
        )
    else:
        print(
            'Successfully loaded imagenet pretrained weights from "{}"'.
            format(cached_file)
        )
        if len(discarded_layers) > 0:
            print(
                '** The following layers are discarded '
                'due to unmatched keys or layer size: {}'.
                format(discarded_layers)
            )


##########
# Instantiation
##########
def osnet_ain_x1_0(
    num_classes=1000, pretrained=True, loss='softmax', **kwargs
):
    model = OSNet(
        num_classes,
        blocks=[
            [OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin],
            [OSBlockINin, OSBlock]
        ],
        layers=[2, 2, 2],
        channels=[64, 256, 384, 512],
        loss=loss,
        conv1_IN=True,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, key='osnet_ain_x1_0')
    return model


def osnet_ain_x0_75(
    num_classes=1000, pretrained=True, loss='softmax', **kwargs
):
    model = OSNet(
        num_classes,
        blocks=[
            [OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin],
            [OSBlockINin, OSBlock]
        ],
        layers=[2, 2, 2],
        channels=[48, 192, 288, 384],
        loss=loss,
        conv1_IN=True,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, key='osnet_ain_x0_75')
    return model


def osnet_ain_x0_5(
    num_classes=1000, pretrained=True, loss='softmax', **kwargs
):
    model = OSNet(
        num_classes,
        blocks=[
            [OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin],
            [OSBlockINin, OSBlock]
        ],
        layers=[2, 2, 2],
        channels=[32, 128, 192, 256],
        loss=loss,
        conv1_IN=True,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, key='osnet_ain_x0_5')
    return model


def osnet_ain_x0_25(
    num_classes=1000, pretrained=True, loss='softmax', **kwargs
):
    model = OSNet(
        num_classes,
        blocks=[
            [OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin],
            [OSBlockINin, OSBlock]
        ],
        layers=[2, 2, 2],
        channels=[16, 64, 96, 128],
        loss=loss,
        conv1_IN=True,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, key='osnet_ain_x0_25')
    return model
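Unlike osnet.py, `_make_layer` in this AIN variant indexes into `blocks`, so each stage receives a list of block classes rather than a single class; the factories above mix OSBlockINin and OSBlock per stage. A sketch showing the same constructor with a custom mix; the arrangement below is illustrative only, not one of the published configurations:

# Hedged sketch: OSNet (AIN variant) with a hypothetical per-stage block mix.
# Each inner list supplies one block class per layer of that stage.
net = OSNet(
    num_classes=1000,
    blocks=[
        [OSBlockINin, OSBlock],  # stage 1: instance norm in the first block only
        [OSBlock, OSBlock],      # stage 2: plain omni-scale blocks
        [OSBlockINin, OSBlock],  # stage 3
    ],
    layers=[2, 2, 2],
    channels=[64, 256, 384, 512],
    conv1_IN=True,
)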
314 feeder/trackers/strongsort/deep/models/pcb.py Normal file
@@ -0,0 +1,314 @@
from __future__ import division, absolute_import
import torch.utils.model_zoo as model_zoo
from torch import nn
from torch.nn import functional as F

__all__ = ['pcb_p6', 'pcb_p4']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False
    )


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class DimReduceLayer(nn.Module):

    def __init__(self, in_channels, out_channels, nonlinear):
        super(DimReduceLayer, self).__init__()
        layers = []
        layers.append(
            nn.Conv2d(
                in_channels, out_channels, 1, stride=1, padding=0, bias=False
            )
        )
        layers.append(nn.BatchNorm2d(out_channels))

        if nonlinear == 'relu':
            layers.append(nn.ReLU(inplace=True))
        elif nonlinear == 'leakyrelu':
            layers.append(nn.LeakyReLU(0.1))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class PCB(nn.Module):
    """Part-based Convolutional Baseline.

    Reference:
        Sun et al. Beyond Part Models: Person Retrieval with Refined
        Part Pooling (and A Strong Convolutional Baseline). ECCV 2018.

    Public keys:
        - ``pcb_p4``: PCB with 4-part strips.
        - ``pcb_p6``: PCB with 6-part strips.
    """

    def __init__(
        self,
        num_classes,
        loss,
        block,
        layers,
        parts=6,
        reduced_dim=256,
        nonlinear='relu',
        **kwargs
    ):
        self.inplanes = 64
        super(PCB, self).__init__()
        self.loss = loss
        self.parts = parts
        self.feature_dim = 512 * block.expansion

        # backbone network
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1)

        # pcb layers
        self.parts_avgpool = nn.AdaptiveAvgPool2d((self.parts, 1))
        self.dropout = nn.Dropout(p=0.5)
        self.conv5 = DimReduceLayer(
            512 * block.expansion, reduced_dim, nonlinear=nonlinear
        )
        self.feature_dim = reduced_dim
        self.classifier = nn.ModuleList(
            [
                nn.Linear(self.feature_dim, num_classes)
                for _ in range(self.parts)
            ]
        )

        self._init_params()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu'
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward(self, x):
        f = self.featuremaps(x)
        v_g = self.parts_avgpool(f)

        if not self.training:
            v_g = F.normalize(v_g, p=2, dim=1)
            return v_g.view(v_g.size(0), -1)

        v_g = self.dropout(v_g)
        v_h = self.conv5(v_g)

        y = []
        for i in range(self.parts):
            v_h_i = v_h[:, :, i, :]
            v_h_i = v_h_i.view(v_h_i.size(0), -1)
            y_i = self.classifier[i](v_h_i)
            y.append(y_i)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            v_g = F.normalize(v_g, p=2, dim=1)
            return y, v_g.view(v_g.size(0), -1)
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def pcb_p6(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = PCB(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        last_stride=1,
        parts=6,
        reduced_dim=256,
        nonlinear='relu',
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['resnet50'])
    return model


def pcb_p4(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = PCB(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        last_stride=1,
        parts=4,
        reduced_dim=256,
        nonlinear='relu',
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['resnet50'])
    return model
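During training PCB returns one logits tensor per horizontal strip, while in eval mode it returns a flattened, L2-normalized part descriptor. A shape sketch under the pcb_p6 configuration, with pretrained=False so the model-zoo download is skipped:

# Hedged sketch: per-part outputs of PCB in train vs. eval mode.
import torch

net = pcb_p6(num_classes=751, pretrained=False)
x = torch.randn(2, 3, 384, 128)  # taller inputs suit six horizontal strips

net.train()
logits = net(x)
print(len(logits), logits[0].shape)  # 6 torch.Size([2, 751])

net.eval()
with torch.no_grad():
    desc = net(x)
print(desc.shape)  # torch.Size([2, 12288]) -- 2048 channels x 6 parts, flattened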
530 feeder/trackers/strongsort/deep/models/resnet.py Normal file
@@ -0,0 +1,530 @@
"""
Code source: https://github.com/pytorch/vision
"""
from __future__ import division, absolute_import
import torch.utils.model_zoo as model_zoo
from torch import nn

__all__ = [
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'resnext50_32x4d', 'resnext101_32x8d', 'resnet50_fc512'
]

model_urls = {
    'resnet18':
    'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34':
    'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50':
    'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101':
    'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152':
    'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d':
    'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d':
    'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False
    )


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None
    ):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64'
            )
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock"
            )
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None
    ):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width/64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """Residual network.

    Reference:
        - He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
        - Xie et al. Aggregated Residual Transformations for Deep Neural Networks. CVPR 2017.

    Public keys:
        - ``resnet18``: ResNet18.
        - ``resnet34``: ResNet34.
        - ``resnet50``: ResNet50.
        - ``resnet101``: ResNet101.
        - ``resnet152``: ResNet152.
        - ``resnext50_32x4d``: ResNeXt50.
        - ``resnext101_32x8d``: ResNeXt101.
        - ``resnet50_fc512``: ResNet50 + FC.
    """

    def __init__(
        self,
        num_classes,
        loss,
        block,
        layers,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm_layer=None,
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        **kwargs
    ):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.loss = loss
        self.feature_dim = 512 * block.expansion
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".
                format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0]
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1]
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=last_stride,
            dilate=replace_stride_with_dilation[2]
        )
        self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = self._construct_fc_layer(
            fc_dims, 512 * block.expansion, dropout_p
        )
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        self._init_params()

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups,
                self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer
                )
            )

        return nn.Sequential(*layers)

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """Constructs fully connected layer

        Args:
            fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
            input_dim (int): input dimension
            dropout_p (float): dropout probability, if None, dropout is unused
        """
        if fc_dims is None:
            self.feature_dim = input_dim
            return None

        assert isinstance(
            fc_dims, (list, tuple)
        ), 'fc_dims must be either list or tuple, but got {}'.format(
            type(fc_dims)
        )

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def _init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu'
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward(self, x):
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if self.fc is not None:
            v = self.fc(v)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


"""ResNet"""


def resnet18(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ResNet(
        num_classes=num_classes,
        loss=loss,
        block=BasicBlock,
        layers=[2, 2, 2, 2],
        last_stride=2,
        fc_dims=None,
        dropout_p=None,
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, model_urls['resnet18'])
|
||||
return model
|
||||
|
||||
|
||||
def resnet34(num_classes, loss='softmax', pretrained=True, **kwargs):
|
||||
model = ResNet(
|
||||
num_classes=num_classes,
|
||||
loss=loss,
|
||||
block=BasicBlock,
|
||||
layers=[3, 4, 6, 3],
|
||||
last_stride=2,
|
||||
fc_dims=None,
|
||||
dropout_p=None,
|
||||
**kwargs
|
||||
)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls['resnet34'])
|
||||
return model
|
||||
|
||||
|
||||
def resnet50(num_classes, loss='softmax', pretrained=True, **kwargs):
|
||||
model = ResNet(
|
||||
num_classes=num_classes,
|
||||
loss=loss,
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 6, 3],
|
||||
last_stride=2,
|
||||
fc_dims=None,
|
||||
dropout_p=None,
|
||||
**kwargs
|
||||
)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls['resnet50'])
|
||||
return model
|
||||
|
||||
|
||||
def resnet101(num_classes, loss='softmax', pretrained=True, **kwargs):
|
||||
model = ResNet(
|
||||
num_classes=num_classes,
|
||||
loss=loss,
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 23, 3],
|
||||
last_stride=2,
|
||||
fc_dims=None,
|
||||
dropout_p=None,
|
||||
**kwargs
|
||||
)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls['resnet101'])
|
||||
return model
|
||||
|
||||
|
||||
def resnet152(num_classes, loss='softmax', pretrained=True, **kwargs):
|
||||
model = ResNet(
|
||||
num_classes=num_classes,
|
||||
loss=loss,
|
||||
block=Bottleneck,
|
||||
layers=[3, 8, 36, 3],
|
||||
last_stride=2,
|
||||
fc_dims=None,
|
||||
dropout_p=None,
|
||||
**kwargs
|
||||
)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls['resnet152'])
|
||||
return model
|
||||
|
||||
|
||||
"""ResNeXt"""
|
||||
|
||||
|
||||
def resnext50_32x4d(num_classes, loss='softmax', pretrained=True, **kwargs):
|
||||
model = ResNet(
|
||||
num_classes=num_classes,
|
||||
loss=loss,
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 6, 3],
|
||||
last_stride=2,
|
||||
fc_dims=None,
|
||||
dropout_p=None,
|
||||
groups=32,
|
||||
width_per_group=4,
|
||||
**kwargs
|
||||
)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls['resnext50_32x4d'])
|
||||
return model
|
||||
|
||||
|
||||
def resnext101_32x8d(num_classes, loss='softmax', pretrained=True, **kwargs):
|
||||
model = ResNet(
|
||||
num_classes=num_classes,
|
||||
loss=loss,
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 23, 3],
|
||||
last_stride=2,
|
||||
fc_dims=None,
|
||||
dropout_p=None,
|
||||
groups=32,
|
||||
width_per_group=8,
|
||||
**kwargs
|
||||
)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls['resnext101_32x8d'])
|
||||
return model
|
||||
|
||||
|
||||
"""
|
||||
ResNet + FC
|
||||
"""
|
||||
|
||||
|
||||
def resnet50_fc512(num_classes, loss='softmax', pretrained=True, **kwargs):
|
||||
model = ResNet(
|
||||
num_classes=num_classes,
|
||||
loss=loss,
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 6, 3],
|
||||
last_stride=1,
|
||||
fc_dims=[512],
|
||||
dropout_p=None,
|
||||
**kwargs
|
||||
)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls['resnet50'])
|
||||
return model
|
||||
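Taken together, every factory above routes through the same ResNet constructor, and forward() is mode-dependent: in training it returns classifier logits (or logits plus the embedding for the triplet loss), while in eval mode it returns the pooled feature vector that serves as the re-ID embedding. A minimal usage sketch under assumed values (num_classes=751 and the 256x128 crop size are illustrative choices, not taken from this diff):

import torch

# Hedged sketch: exercise the resnet50_fc512 variant defined above.
# num_classes=751 and the 256x128 input size are assumptions for illustration.
model = resnet50_fc512(num_classes=751, loss='softmax', pretrained=False)

model.eval()  # eval mode: forward() returns the fc-projected feature vector
with torch.no_grad():
    crops = torch.randn(8, 3, 256, 128)  # a batch of person crops
    feats = model(crops)                 # shape (8, 512), since fc_dims=[512]

model.train()  # train mode: forward() returns logits from the classifier head
logits = model(crops)                    # shape (8, 751)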
289 feeder/trackers/strongsort/deep/models/resnet_ibn_a.py Normal file

@@ -0,0 +1,289 @@
"""
|
||||
Credit to https://github.com/XingangPan/IBN-Net.
|
||||
"""
|
||||
from __future__ import division, absolute_import
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
|
||||
__all__ = ['resnet50_ibn_a']
|
||||
|
||||
model_urls = {
|
||||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
}
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"3x3 convolution with padding"
|
||||
return nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False
|
||||
)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class IBN(nn.Module):
|
||||
|
||||
def __init__(self, planes):
|
||||
super(IBN, self).__init__()
|
||||
half1 = int(planes / 2)
|
||||
self.half = half1
|
||||
half2 = planes - half1
|
||||
self.IN = nn.InstanceNorm2d(half1, affine=True)
|
||||
self.BN = nn.BatchNorm2d(half2)
|
||||
|
||||
def forward(self, x):
|
||||
split = torch.split(x, self.half, 1)
|
||||
out1 = self.IN(split[0].contiguous())
|
||||
out2 = self.BN(split[1].contiguous())
|
||||
out = torch.cat((out1, out2), 1)
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||
if ibn:
|
||||
self.bn1 = IBN(planes)
|
||||
else:
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False
|
||||
)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(
|
||||
planes, planes * self.expansion, kernel_size=1, bias=False
|
||||
)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
"""Residual network + IBN layer.
|
||||
|
||||
Reference:
|
||||
- He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
|
||||
- Pan et al. Two at Once: Enhancing Learning and Generalization
|
||||
Capacities via IBN-Net. ECCV 2018.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block,
|
||||
layers,
|
||||
num_classes=1000,
|
||||
loss='softmax',
|
||||
fc_dims=None,
|
||||
dropout_p=None,
|
||||
**kwargs
|
||||
):
|
||||
scale = 64
|
||||
self.inplanes = scale
|
||||
super(ResNet, self).__init__()
|
||||
self.loss = loss
|
||||
self.feature_dim = scale * 8 * block.expansion
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
3, scale, kernel_size=7, stride=2, padding=3, bias=False
|
||||
)
|
||||
self.bn1 = nn.BatchNorm2d(scale)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, scale, layers[0])
|
||||
self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=2)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = self._construct_fc_layer(
|
||||
fc_dims, scale * 8 * block.expansion, dropout_p
|
||||
)
|
||||
self.classifier = nn.Linear(self.feature_dim, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.InstanceNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
self.inplanes,
|
||||
planes * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False
|
||||
),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
ibn = True
|
||||
if planes == 512:
|
||||
ibn = False
|
||||
layers.append(block(self.inplanes, planes, ibn, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, ibn))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
|
||||
"""Constructs fully connected layer
|
||||
|
||||
Args:
|
||||
fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
|
||||
input_dim (int): input dimension
|
||||
dropout_p (float): dropout probability, if None, dropout is unused
|
||||
"""
|
||||
if fc_dims is None:
|
||||
self.feature_dim = input_dim
|
||||
return None
|
||||
|
||||
assert isinstance(
|
||||
fc_dims, (list, tuple)
|
||||
), 'fc_dims must be either list or tuple, but got {}'.format(
|
||||
type(fc_dims)
|
||||
)
|
||||
|
||||
layers = []
|
||||
for dim in fc_dims:
|
||||
layers.append(nn.Linear(input_dim, dim))
|
||||
layers.append(nn.BatchNorm1d(dim))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
if dropout_p is not None:
|
||||
layers.append(nn.Dropout(p=dropout_p))
|
||||
input_dim = dim
|
||||
|
||||
self.feature_dim = fc_dims[-1]
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def featuremaps(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
f = self.featuremaps(x)
|
||||
v = self.avgpool(f)
|
||||
v = v.view(v.size(0), -1)
|
||||
if self.fc is not None:
|
||||
v = self.fc(v)
|
||||
if not self.training:
|
||||
return v
|
||||
y = self.classifier(v)
|
||||
if self.loss == 'softmax':
|
||||
return y
|
||||
elif self.loss == 'triplet':
|
||||
return y, v
|
||||
else:
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
||||
|
||||
|
||||
def init_pretrained_weights(model, model_url):
|
||||
"""Initializes model with pretrained weights.
|
||||
|
||||
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
||||
"""
|
||||
pretrain_dict = model_zoo.load_url(model_url)
|
||||
model_dict = model.state_dict()
|
||||
pretrain_dict = {
|
||||
k: v
|
||||
for k, v in pretrain_dict.items()
|
||||
if k in model_dict and model_dict[k].size() == v.size()
|
||||
}
|
||||
model_dict.update(pretrain_dict)
|
||||
model.load_state_dict(model_dict)
|
||||
|
||||
|
||||
def resnet50_ibn_a(num_classes, loss='softmax', pretrained=False, **kwargs):
|
||||
model = ResNet(
|
||||
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, loss=loss, **kwargs
|
||||
)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls['resnet50'])
|
||||
return model
|
||||
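The distinctive piece in this file is the IBN module: it splits a bottleneck's first normalization across channels, applying InstanceNorm2d to one half and BatchNorm2d to the other, and _make_layer disables it for the 512-plane stage. A small sketch of the split-and-concat behavior (the channel count and spatial size are illustrative):

import torch

# Hedged sketch: IBN normalizes the first half of the channels with
# InstanceNorm2d and the second half with BatchNorm2d, then concatenates,
# so the output shape matches the input shape.
ibn = IBN(planes=64)            # 32 channels go to IN, 32 to BN
x = torch.randn(2, 64, 32, 32)
out = ibn(x)
assert out.shape == x.shape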
274 feeder/trackers/strongsort/deep/models/resnet_ibn_b.py Normal file

@@ -0,0 +1,274 @@
"""
|
||||
Credit to https://github.com/XingangPan/IBN-Net.
|
||||
"""
|
||||
from __future__ import division, absolute_import
|
||||
import math
|
||||
import torch.nn as nn
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
|
||||
__all__ = ['resnet50_ibn_b']
|
||||
|
||||
model_urls = {
|
||||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
}
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"3x3 convolution with padding"
|
||||
return nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False
|
||||
)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, IN=False):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False
|
||||
)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(
|
||||
planes, planes * self.expansion, kernel_size=1, bias=False
|
||||
)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.IN = None
|
||||
if IN:
|
||||
self.IN = nn.InstanceNorm2d(planes * 4, affine=True)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
if self.IN is not None:
|
||||
out = self.IN(out)
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
"""Residual network + IBN layer.
|
||||
|
||||
Reference:
|
||||
- He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
|
||||
- Pan et al. Two at Once: Enhancing Learning and Generalization
|
||||
Capacities via IBN-Net. ECCV 2018.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block,
|
||||
layers,
|
||||
num_classes=1000,
|
||||
loss='softmax',
|
||||
fc_dims=None,
|
||||
dropout_p=None,
|
||||
**kwargs
|
||||
):
|
||||
scale = 64
|
||||
self.inplanes = scale
|
||||
super(ResNet, self).__init__()
|
||||
self.loss = loss
|
||||
self.feature_dim = scale * 8 * block.expansion
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
3, scale, kernel_size=7, stride=2, padding=3, bias=False
|
||||
)
|
||||
self.bn1 = nn.InstanceNorm2d(scale, affine=True)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(
|
||||
block, scale, layers[0], stride=1, IN=True
|
||||
)
|
||||
self.layer2 = self._make_layer(
|
||||
block, scale * 2, layers[1], stride=2, IN=True
|
||||
)
|
||||
self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=2)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = self._construct_fc_layer(
|
||||
fc_dims, scale * 8 * block.expansion, dropout_p
|
||||
)
|
||||
self.classifier = nn.Linear(self.feature_dim, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.InstanceNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, IN=False):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
self.inplanes,
|
||||
planes * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False
|
||||
),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks - 1):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
layers.append(block(self.inplanes, planes, IN=IN))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
|
||||
"""Constructs fully connected layer
|
||||
|
||||
Args:
|
||||
fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
|
||||
input_dim (int): input dimension
|
||||
dropout_p (float): dropout probability, if None, dropout is unused
|
||||
"""
|
||||
if fc_dims is None:
|
||||
self.feature_dim = input_dim
|
||||
return None
|
||||
|
||||
assert isinstance(
|
||||
fc_dims, (list, tuple)
|
||||
), 'fc_dims must be either list or tuple, but got {}'.format(
|
||||
type(fc_dims)
|
||||
)
|
||||
|
||||
layers = []
|
||||
for dim in fc_dims:
|
||||
layers.append(nn.Linear(input_dim, dim))
|
||||
layers.append(nn.BatchNorm1d(dim))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
if dropout_p is not None:
|
||||
layers.append(nn.Dropout(p=dropout_p))
|
||||
input_dim = dim
|
||||
|
||||
self.feature_dim = fc_dims[-1]
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def featuremaps(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
f = self.featuremaps(x)
|
||||
v = self.avgpool(f)
|
||||
v = v.view(v.size(0), -1)
|
||||
if self.fc is not None:
|
||||
v = self.fc(v)
|
||||
if not self.training:
|
||||
return v
|
||||
y = self.classifier(v)
|
||||
if self.loss == 'softmax':
|
||||
return y
|
||||
elif self.loss == 'triplet':
|
||||
return y, v
|
||||
else:
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
||||
|
||||
|
||||
def init_pretrained_weights(model, model_url):
|
||||
"""Initializes model with pretrained weights.
|
||||
|
||||
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
||||
"""
|
||||
pretrain_dict = model_zoo.load_url(model_url)
|
||||
model_dict = model.state_dict()
|
||||
pretrain_dict = {
|
||||
k: v
|
||||
for k, v in pretrain_dict.items()
|
||||
if k in model_dict and model_dict[k].size() == v.size()
|
||||
}
|
||||
model_dict.update(pretrain_dict)
|
||||
model.load_state_dict(model_dict)
|
||||
|
||||
|
||||
def resnet50_ibn_b(num_classes, loss='softmax', pretrained=False, **kwargs):
|
||||
model = ResNet(
|
||||
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, loss=loss, **kwargs
|
||||
)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls['resnet50'])
|
||||
return model
|
||||
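In contrast to the IBN-a variant above, this IBN-b flavor keeps plain BatchNorm2d inside each bottleneck but instance-normalizes the stem (self.bn1) and the output of the last block in layer1 and layer2, applying InstanceNorm2d after the residual addition. A hedged usage sketch (num_classes=751 and the input size are assumed values for illustration):

import torch

# Hedged sketch: with fc_dims left at its default of None, eval-mode forward()
# returns the 2048-dim pooled feature (64 * 8 * Bottleneck.expansion).
model = resnet50_ibn_b(num_classes=751, loss='softmax', pretrained=False)
model.eval()
with torch.no_grad():
    feats = model(torch.randn(4, 3, 256, 128))
assert feats.shape == (4, 2048)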
Some files were not shown because too many files have changed in this diff.