Compare commits


7 commits

Author SHA1 Message Date
Pongsatorn Kanjanasantisak 48db3234ed resolve merge conflicts by accepting main branch versions 2025-08-11 00:57:21 +07:00
Pongsatorn Kanjanasantisak aa883e4bfa test 2025-08-11 00:24:08 +07:00
Pongsatorn Kanjanasantisak b7d8b3266f add StrongSORT Tacker 2025-08-10 01:23:09 +07:00
Pongsatorn Kanjanasantisak ffc2e99678 Merge remote-tracking branch 'origin/main' into taiworker 2025-08-09 12:42:53 +07:00
Pongsatorn e471ab03e9 update pympta.py to only receive from models folder 2025-08-08 23:22:04 +07:00
Pongsatorn 3d0aaab8b3 update Docker File to low vulnerabilities 2025-07-13 15:06:03 +07:00
Pongsatorn 7085a6e00f update gitignore .venv 2025-07-13 13:25:36 +07:00
152 changed files with 21461 additions and 14561 deletions


@@ -1,11 +0,0 @@
{
  "permissions": {
    "allow": [
      "Bash(dir:*)",
      "WebSearch",
      "Bash(mkdir:*)"
    ],
    "deny": [],
    "ask": []
  }
}


@@ -103,4 +103,10 @@ jobs:
- name: Deploy stack
run: |
echo "Pulling and starting containers on server..."
ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.staging.yml -f docker-compose.production.yml pull && docker compose -f docker-compose.staging.yml -f docker-compose.production.yml up -d"
if [ "${{ github.ref_name }}" = "main" ]; then
echo "Deploying production stack..."
ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.production.yml pull && docker compose -f docker-compose.production.yml up -d"
else
echo "Deploying staging stack..."
ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.staging.yml pull && docker compose -f docker-compose.staging.yml up -d"
fi

.gitignore (13 lines changed)

@@ -1,8 +1,11 @@
/models
# Do not know how to use
archive/
Dockerfile
# /models
app.log
*.pt
images
.venv/
# All pycache directories
__pycache__/
@@ -12,3 +15,7 @@ mptas
detector_worker.log
.gitignore
no_frame_debug.log
# Result from tracker
feeder/runs/


@@ -1,130 +1,15 @@
# Base image with complete ML and hardware acceleration stack
FROM pytorch/pytorch:2.8.0-cuda12.6-cudnn9-runtime
# Base image with all ML dependencies
FROM python:3.13-bookworm
# Install build dependencies and system libraries
RUN apt-get update && apt-get install -y \
# Build tools
build-essential \
cmake \
git \
pkg-config \
wget \
unzip \
yasm \
nasm \
# Additional dependencies for FFmpeg/NVIDIA build
libtool \
libc6 \
libc6-dev \
libnuma1 \
libnuma-dev \
# Essential compilation libraries
gcc \
g++ \
libc6-dev \
linux-libc-dev \
# System libraries
libgl1-mesa-glx \
libglib2.0-0 \
libgomp1 \
# Core media libraries (essential ones only)
libjpeg-dev \
libpng-dev \
libx264-dev \
libx265-dev \
libvpx-dev \
libmp3lame-dev \
libv4l-dev \
# TurboJPEG for fast JPEG encoding
libturbojpeg0-dev \
# Python development
python3-dev \
python3-numpy \
&& rm -rf /var/lib/apt/lists/*
# Install system dependencies
RUN apt update && apt install -y libgl1 && rm -rf /var/lib/apt/lists/*
# Add NVIDIA CUDA repository and install minimal development tools
RUN apt-get update && apt-get install -y wget gnupg && \
wget -O - https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub | apt-key add - && \
echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
apt-get update && \
apt-get install -y \
cuda-nvcc-12-6 \
cuda-cudart-dev-12-6 \
libnpp-dev-12-6 \
&& apt-get remove -y wget gnupg && \
apt-get autoremove -y && \
rm -rf /var/lib/apt/lists/*
# Ensure CUDA paths are available
ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
# Install NVIDIA Video Codec SDK headers (official method)
RUN cd /tmp && \
git clone https://git.videolan.org/git/ffmpeg/nv-codec-headers.git && \
cd nv-codec-headers && \
make install && \
cd / && rm -rf /tmp/*
# Build FFmpeg from source with NVIDIA CUDA support
RUN cd /tmp && \
echo "Building FFmpeg with NVIDIA CUDA support..." && \
# Download FFmpeg source (official method)
git clone https://git.ffmpeg.org/ffmpeg.git ffmpeg/ && \
cd ffmpeg && \
# Configure with NVIDIA support (simplified to avoid configure issues)
./configure \
--prefix=/usr/local \
--enable-shared \
--disable-static \
--enable-nonfree \
--enable-gpl \
--enable-cuda-nvcc \
--enable-cuvid \
--enable-nvdec \
--enable-nvenc \
--enable-libnpp \
--extra-cflags=-I/usr/local/cuda/include \
--extra-ldflags=-L/usr/local/cuda/lib64 \
--enable-libx264 \
--enable-libx265 \
--enable-libvpx \
--enable-libmp3lame && \
# Build and install
make -j$(nproc) && \
make install && \
ldconfig && \
# Verify CUVID decoders are available
echo "=== Verifying FFmpeg CUVID Support ===" && \
(ffmpeg -hide_banner -decoders 2>/dev/null | grep cuvid || echo "No CUVID decoders found") && \
echo "=== Verifying FFmpeg NVENC Support ===" && \
(ffmpeg -hide_banner -encoders 2>/dev/null | grep nvenc || echo "No NVENC encoders found") && \
cd / && rm -rf /tmp/*
# Set environment variables for maximum hardware acceleration
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/lib:${LD_LIBRARY_PATH}"
ENV PKG_CONFIG_PATH="/usr/local/lib/pkgconfig:${PKG_CONFIG_PATH}"
ENV PYTHONPATH="/usr/local/lib/python3.10/dist-packages:${PYTHONPATH}"
# Optimized environment variables for hardware acceleration
ENV OPENCV_FFMPEG_CAPTURE_OPTIONS="rtsp_transport;tcp|hwaccel;cuda|hwaccel_device;0|video_codec;h264_cuvid|hwaccel_output_format;cuda"
ENV OPENCV_FFMPEG_WRITER_OPTIONS="video_codec;h264_nvenc|preset;fast|tune;zerolatency|gpu;0"
ENV CUDA_VISIBLE_DEVICES=0
ENV NVIDIA_VISIBLE_DEVICES=all
ENV NVIDIA_DRIVER_CAPABILITIES=compute,video,utility
# Copy and install base requirements (exclude opencv-python since we built from source)
# Copy and install base requirements (ML dependencies that rarely change)
COPY requirements.base.txt .
RUN grep -v opencv-python requirements.base.txt > requirements.tmp && \
mv requirements.tmp requirements.base.txt && \
pip install --no-cache-dir -r requirements.base.txt
RUN pip install --no-cache-dir -r requirements.base.txt
# Set working directory
WORKDIR /app
# Create images directory for bind mount
RUN mkdir -p /app/images && \
chmod 755 /app/images
# This base image will be reused for all worker builds
CMD ["python3", "-m", "fastapi", "run", "--host", "0.0.0.0", "--port", "8000"]


@ -1,545 +0,0 @@
# Detector Worker Refactoring Plan
## Project Overview
Transform the current monolithic structure (~4000 lines across `app.py` and `siwatsystem/pympta.py`) into a modular, maintainable system with clear separation of concerns. The goal is to make the sophisticated computer vision pipeline easily understandable for other engineers while maintaining all existing functionality.
## Current System Flow Understanding
### Validated System Flow
1. **WebSocket Connection** → Backend connects and sends `setSubscriptionList`
2. **Model Management** → Download unique `.mpta` files to `models/` and extract
3. **Tracking Phase** → Continuous tracking with `front_rear_detection_v1.pt`
4. **Validation Phase** → Validate stable car (not just passing by)
5. **Pipeline Execution**
- Detect car with `yolo11m.pt`
- **Branch 1**: Front/rear detection → crop frontal → save to Redis + brand classification
- **Branch 2**: Body type classification from car crop
6. **Communication** → Send `imageDetection` → Backend generates `sessionId` → Fueling starts
7. **Post-Fueling** → Backend clears `sessionId` → Continue tracking same car to avoid re-pipeline
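The `imageDetection` message in step 6 follows the envelope used by the worker's WebSocket handler. A hypothetical example payload is sketched below; the subscription identifier, model id, and detection fields are illustrative values, not taken from a real session.

```python
# Hypothetical sketch of an `imageDetection` message as sent by the worker after a
# pipeline run. Field names mirror the detection_data structure used in app.py; the
# concrete identifier, model id, and detection values are made up for illustration.
import json
import time

detection_message = {
    "type": "imageDetection",
    "subscriptionIdentifier": "display-001;front-cam-01",  # displayId;cameraId (example)
    "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "data": {
        "detection": {"class": "car", "confidence": 0.92, "brand": "toyota"},
        "modelId": 42,                       # example id
        "modelName": "front_rear_detection_v1",
    },
    # "sessionId" is attached only once the backend has generated one.
}
print(json.dumps(detection_message, indent=2))
```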
### Core Responsibilities Identified
1. **WebSocket Communication** - Message handling and protocol compliance
2. **Stream Management** - RTSP/HTTP frame processing and buffering
3. **Model Management** - MPTA download, extraction, and loading
4. **Pipeline Configuration** - Parse `pipeline.json` and setup execution flow
5. **Vehicle Tracking** - Continuous tracking and car identification
6. **Validation Logic** - Stable car detection vs. passing-by cars
7. **Detection Pipeline** - Main ML pipeline with parallel branches
8. **Data Persistence** - Redis/PostgreSQL operations
9. **Session Management** - Handle session IDs and lifecycle
## Proposed Directory Structure
```
core/
├── communication/
│ ├── __init__.py
│ ├── websocket.py # WebSocket message handling & protocol
│ ├── messages.py # Message types and validation
│ ├── models.py # Message data structures
│ └── state.py # Worker state management
├── streaming/
│ ├── __init__.py
│ ├── manager.py # Stream coordination and lifecycle
│ ├── readers.py # RTSP/HTTP frame readers
│ └── buffers.py # Frame buffering and caching
├── models/
│ ├── __init__.py
│ ├── manager.py # MPTA download and model loading
│ ├── pipeline.py # Pipeline.json parser and config
│ └── inference.py # YOLO model wrapper and optimization
├── tracking/
│ ├── __init__.py
│ ├── tracker.py # Vehicle tracking with front_rear_detection_v1
│ ├── validator.py # Stable car validation logic
│ └── integration.py # Tracking-pipeline integration
├── detection/
│ ├── __init__.py
│ ├── pipeline.py # Main detection pipeline orchestration
│ └── branches.py # Parallel branch processing (brand/bodytype)
└── storage/
├── __init__.py
├── redis.py # Redis operations and image storage
└── database.py # PostgreSQL operations (existing - will be moved)
```
## Implementation Strategy (Feature-by-Feature Testing)
### Phase 1: Communication Layer
- WebSocket message handling (setSubscriptionList, sessionId management)
- HTTP API endpoints (camera image retrieval)
- Worker state reporting
### Phase 2: Pipeline Configuration Reader
- Parse `pipeline.json`
- Model dependency resolution
- Branch configuration setup
### Phase 3: Tracking System
- Continuous vehicle tracking
- Car identification and persistence
### Phase 4: Tracking Validator
- Stable car detection logic
- Passing-by vs. fueling car differentiation
### Phase 5: Model Pipeline Execution
- Main detection pipeline
- Parallel branch processing
- Redis/DB integration
### Phase 6: Post-Session Tracking Validation
- Same car validation after sessionId cleared
- Prevent duplicate pipeline execution
## Key Preservation Requirements
- **HTTP Endpoint**: `/camera/{camera_id}/image` must remain unchanged (a client-side usage sketch follows this list)
- **WebSocket Protocol**: Full compliance with `worker.md` specification
- **MPTA Format**: Maintain compatibility with existing model archives
- **Database Schema**: Keep existing PostgreSQL structure
- **Redis Integration**: Preserve image storage and pub/sub functionality
- **Configuration**: Maintain `config.json` compatibility
- **Logging**: Preserve structured logging format
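As a quick illustration of the preserved REST interface, the sketch below fetches the cached frame for one camera. The host, port, and identifier are assumptions; the only behaviour relied on is that the endpoint returns a JPEG and that the subscription identifier (which contains a semicolon) must be URL-encoded.

```python
# Minimal client-side sketch of the preserved /camera/{camera_id}/image endpoint,
# assuming the worker listens on localhost:8000. The camera_id is the full
# subscription identifier ("displayId;cameraId"), so the ';' must be URL-encoded.
from urllib.parse import quote
import requests

camera_id = "display-001;front-cam-01"                     # assumed example identifier
url = f"http://localhost:8000/camera/{quote(camera_id, safe='')}/image"

resp = requests.get(url, timeout=10)
if resp.status_code == 200:
    with open("latest_frame.jpg", "wb") as f:              # body is a JPEG-encoded frame
        f.write(resp.content)
else:
    print(f"No frame available: HTTP {resp.status_code}")
```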
## Expected Benefits
- **Maintainability**: Single responsibility modules (~200-400 lines each)
- **Testability**: Independent testing of each component
- **Readability**: Clear separation of concerns
- **Scalability**: Easy to extend and modify individual components
- **Documentation**: Self-documenting code structure
---
# Comprehensive TODO List
## ✅ Phase 1: Project Setup & Communication Layer - COMPLETED
### 1.1 Project Structure Setup
- ✅ Create `core/` directory structure
- ✅ Create all module directories and `__init__.py` files
- ✅ Set up logging configuration for new modules
- ✅ Update imports in existing files to prepare for migration
### 1.2 Communication Module (`core/communication/`)
- ✅ **Create `models.py`** - Message data structures
- ✅ Define WebSocket message models (SubscriptionList, StateReport, etc.)
- ✅ Add validation schemas for incoming messages
- ✅ Create response models for outgoing messages
- ✅ **Create `messages.py`** - Message types and validation
- ✅ Implement message type constants
- ✅ Add message validation functions
- ✅ Create message builders for common responses
- ✅ **Create `websocket.py`** - WebSocket message handling
- ✅ Extract WebSocket connection management from `app.py`
- ✅ Implement message routing and dispatching
- ✅ Add connection lifecycle management (connect, disconnect, reconnect)
- ✅ Handle `setSubscriptionList` message processing
- ✅ Handle `setSessionId` and `setProgressionStage` messages
- ✅ Handle `requestState` and `patchSessionResult` messages
- ✅ **Create `state.py`** - Worker state management
- ✅ Extract state reporting logic from `app.py`
- ✅ Implement system metrics collection (CPU, memory, GPU)
- ✅ Manage active subscriptions state
- ✅ Handle session ID mapping and storage
### 1.3 HTTP API Preservation
- ✅ **Preserve `/camera/{camera_id}/image` endpoint**
- ✅ Extract REST API logic from `app.py`
- ✅ Ensure frame caching mechanism works with new structure
- ✅ Maintain exact same response format and error handling
### 1.4 Testing Phase 1
- ✅ Test WebSocket connection and message handling
- ✅ Test HTTP API endpoint functionality
- ✅ Verify state reporting works correctly
- ✅ Test session management functionality
### 1.5 Phase 1 Results
- ✅ **Modular Architecture**: Transformed ~900 lines into 4 focused modules (~200 lines each)
- ✅ **WebSocket Protocol**: Full compliance with worker.md specification
- ✅ **System Metrics**: Real-time CPU, memory, GPU monitoring
- ✅ **State Management**: Thread-safe subscription and session tracking
- ✅ **Backward Compatibility**: All existing endpoints preserved
- ✅ **Modern FastAPI**: Lifespan events, Pydantic v2 compatibility
## ✅ Phase 2: Pipeline Configuration & Model Management - COMPLETED
### 2.1 Models Module (`core/models/`)
- ✅ **Create `pipeline.py`** - Pipeline.json parser
- ✅ Extract pipeline configuration parsing from `pympta.py`
- ✅ Implement pipeline validation
- ✅ Add configuration schema validation
- ✅ Handle Redis and PostgreSQL configuration parsing
- ✅ **Create `manager.py`** - MPTA download and model loading
- ✅ Extract MPTA download logic from `pympta.py`
- ✅ Implement ZIP extraction and validation
- ✅ Add model file management and caching
- ✅ Handle model loading with GPU optimization
- ✅ Implement model dependency resolution
- ✅ **Create `inference.py`** - YOLO model wrapper
- ✅ Create unified YOLO model interface
- ✅ Add inference optimization and caching
- ✅ Implement batch processing capabilities
- ✅ Handle model switching and memory management
### 2.2 Testing Phase 2
- ✅ Test MPTA file download and extraction
- ✅ Test pipeline.json parsing and validation
- ✅ Test model loading with different configurations
- ✅ Verify GPU optimization works correctly
### 2.3 Phase 2 Results
- ✅ **ModelManager**: Downloads, extracts, and manages MPTA files with model ID-based directory structure
- ✅ **PipelineParser**: Parses and validates pipeline.json with full support for Redis, PostgreSQL, tracking, and branches
- ✅ **YOLOWrapper**: Unified interface for YOLO models with caching, tracking, and classification support
- ✅ **Model Caching**: Shared model cache across instances to optimize memory usage
- ✅ **Dependency Resolution**: Automatically identifies and tracks all model file dependencies
## ✅ Phase 3: Streaming System - COMPLETED
### 3.1 Streaming Module (`core/streaming/`)
- ✅ **Create `readers.py`** - RTSP/HTTP frame readers
- ✅ Extract `frame_reader` function from `app.py`
- ✅ Extract `snapshot_reader` function from `app.py`
- ✅ Add connection management and retry logic
- ✅ Implement frame rate control and optimization
- ✅ **Create `buffers.py`** - Frame buffering and caching
- ✅ Extract frame buffer management from `app.py`
- ✅ Implement efficient frame caching for REST API
- ✅ Add buffer size management and memory optimization
- ✅ **Create `manager.py`** - Stream coordination
- ✅ Extract stream lifecycle management from `app.py`
- ✅ Implement shared stream optimization
- ✅ Add subscription reconciliation logic
- ✅ Handle stream sharing across multiple subscriptions
### 3.2 Testing Phase 3
- ✅ Test RTSP stream reading and buffering
- ✅ Test HTTP snapshot capture functionality
- ✅ Test shared stream optimization
- ✅ Verify frame caching for REST API access
### 3.3 Phase 3 Results
- ✅ **RTSPReader**: OpenCV-based RTSP stream reader with automatic reconnection and frame callbacks
- ✅ **HTTPSnapshotReader**: Periodic HTTP snapshot capture with HTTPBasicAuth and HTTPDigestAuth support
- ✅ **FrameBuffer**: Thread-safe frame storage with automatic aging and cleanup (a rough sketch follows this list)
- ✅ **CacheBuffer**: Enhanced frame cache with cropping support and highest quality JPEG encoding (default quality=100)
- ✅ **StreamManager**: Complete stream lifecycle management with shared optimization and subscription reconciliation
- ✅ **Authentication Support**: Proper handling of credentials in URLs with automatic auth type detection
- ✅ **Real Camera Testing**: Verified with authenticated RTSP (1280x720) and HTTP snapshot (2688x1520) cameras
- ✅ **Production Ready**: Stable concurrent streaming from multiple camera sources
- ✅ **Dependencies**: Added opencv-python, numpy, and requests to requirements.txt
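The FrameBuffer above is described as thread-safe storage with automatic aging; the following is a rough sketch of that idea, not the actual `core/streaming` implementation. The class name and the 5-second maximum age are assumptions.

```python
# Rough sketch of a thread-safe "latest frame" store with aging, in the spirit of
# the FrameBuffer described above. Names and the max-age value are assumptions.
import threading
import time

class LatestFrameBuffer:
    def __init__(self, max_age_seconds: float = 5.0):
        self._frames = {}                    # camera_id -> (frame, timestamp)
        self._lock = threading.Lock()
        self._max_age = max_age_seconds

    def put(self, camera_id, frame):
        with self._lock:
            self._frames[camera_id] = (frame, time.time())

    def get(self, camera_id):
        with self._lock:
            entry = self._frames.get(camera_id)
            if entry is None:
                return None
            frame, ts = entry
            if time.time() - ts > self._max_age:   # drop stale frames
                del self._frames[camera_id]
                return None
            return frame
```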
### 3.4 Recent Streaming Enhancements (Post-Phase 3)
- ✅ **Format-Specific Optimization**: Tailored for 1280x720@6fps RTSP streams and 2560x1440 HTTP snapshots
- ✅ **H.264 Error Recovery**: Enhanced error handling for corrupted frames with automatic stream recovery
- ✅ **Frame Validation**: Implemented corruption detection using edge density analysis
- ✅ **Buffer Size Optimization**: Adjusted buffer limits to 3MB for RTSP frames (1280x720x3 bytes)
- ✅ **FFMPEG Integration**: Added environment variables to suppress verbose H.264 decoder errors
- ✅ **URL Preservation**: Maintained clean RTSP URLs without parameter injection
- ✅ **Type Detection**: Automatic stream type detection based on frame dimensions
- ✅ **Quality Settings**: Format-specific JPEG quality (90% for RTSP, 95% for HTTP)
## ✅ Phase 4: Vehicle Tracking System - COMPLETED
### 4.1 Tracking Module (`core/tracking/`)
- ✅ **Create `tracker.py`** - Vehicle tracking implementation (305 lines)
- ✅ Implement continuous tracking with configurable model (front_rear_detection_v1.pt)
- ✅ Add vehicle identification and persistence with TrackedVehicle dataclass
- ✅ Implement tracking state management with thread-safe operations
- ✅ Add bounding box tracking and motion analysis with position history
- ✅ Multi-class tracking support for complex detection scenarios
- ✅ **Create `validator.py`** - Stable car validation (417 lines)
- ✅ Implement stable car detection algorithm with multiple validation criteria
- ✅ Add passing-by vs. fueling car differentiation using velocity and position analysis
- ✅ Implement validation thresholds and timing with configurable parameters
- ✅ Add confidence scoring for validation decisions with state history
- ✅ Advanced motion analysis with velocity smoothing and position variance
- ✅ **Create `integration.py`** - Tracking-pipeline integration (547 lines)
- ✅ Connect tracking system with main pipeline through TrackingPipelineIntegration
- ✅ Handle tracking state transitions and session management
- ✅ Implement post-session tracking validation with cooldown periods
- ✅ Add same-car validation after sessionId cleared with 30-second cooldown
- ✅ Car abandonment detection with automatic timeout monitoring
- ✅ Mock detection system for backend communication
- ✅ Async pipeline execution with proper error handling
### 4.2 Testing Phase 4
- ✅ Test continuous vehicle tracking functionality
- ✅ Test stable car validation logic
- ✅ Test integration with existing pipeline
- ✅ Verify tracking performance and accuracy
- ✅ Test car abandonment detection with null detection messages
- ✅ Verify session management and progression stage handling
### 4.3 Phase 4 Results
- ✅ **VehicleTracker**: Complete tracking implementation with YOLO tracking integration, position history, and stability calculations
- ✅ **StableCarValidator**: Sophisticated validation logic using velocity, position variance, and state consistency (an illustrative check is sketched after this list)
- ✅ **TrackingPipelineIntegration**: Full integration with pipeline system including session management and async processing
- ✅ **StreamManager Integration**: Updated streaming manager to process tracking on every frame with proper threading
- ✅ **Thread-Safe Operations**: All tracking operations are thread-safe with proper locking mechanisms
- ✅ **Configurable Parameters**: All tracking parameters are configurable through pipeline.json
- ✅ **Session Management**: Complete session lifecycle management with post-fueling validation
- ✅ **Statistics and Monitoring**: Comprehensive statistics collection for tracking performance
- ✅ **Car Abandonment Detection**: Automatic detection when cars leave without fueling, sends `detection: null` to backend
- ✅ **Message Protocol**: Fixed JSON serialization to include `detection: null` for abandonment notifications
- ✅ **Streaming Optimization**: Enhanced RTSP/HTTP readers for 1280x720@6fps RTSP and 2560x1440 HTTP snapshots
- ✅ **Error Recovery**: Improved H.264 error handling and corrupted frame detection
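To make the StableCarValidator idea concrete, here is an illustrative check based only on the criteria named above (velocity and position variance); the thresholds and window size are invented for the example and are not the tuned values from pipeline.json.

```python
# Illustrative stability check in the spirit of StableCarValidator: a track is
# "stable" when its recent centroids barely move. Thresholds and the window size
# are made-up example values.
import statistics

def is_stable(history, max_velocity_px=5.0, max_variance_px=20.0, window=10):
    """history: list of (x, y) centroid positions, oldest first."""
    if len(history) < window:
        return False
    recent = history[-window:]
    xs = [p[0] for p in recent]
    ys = [p[1] for p in recent]
    # Average per-frame displacement (a crude velocity estimate).
    velocity = sum(
        ((recent[i][0] - recent[i - 1][0]) ** 2 + (recent[i][1] - recent[i - 1][1]) ** 2) ** 0.5
        for i in range(1, len(recent))
    ) / (len(recent) - 1)
    variance = statistics.pvariance(xs) + statistics.pvariance(ys)
    return velocity <= max_velocity_px and variance <= max_variance_px
```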
## ✅ Phase 5: Detection Pipeline System - COMPLETED
### 5.1 Detection Module (`core/detection/`) ✅
- ✅ **Create `pipeline.py`** - Main detection orchestration (574 lines)
- ✅ Extracted main pipeline execution from `pympta.py` with full orchestration
- ✅ Implemented detection flow coordination with async execution
- ✅ Added pipeline state management with comprehensive statistics
- ✅ Handled pipeline result aggregation with branch synchronization
- ✅ Redis and database integration with error handling
- ✅ Immediate and parallel action execution with template resolution
- ✅ **Create `branches.py`** - Parallel branch processing (442 lines)
- ✅ Extracted parallel branch execution from `pympta.py`
- ✅ Implemented ThreadPoolExecutor-based parallel processing
- ✅ Added branch synchronization and result collection
- ✅ Handled branch failure and retry logic with graceful degradation
- ✅ Support for nested branches and model caching
- ✅ Both detection and classification model support
### 5.2 Storage Module (`core/storage/`) ✅
- ✅ **Create `redis.py`** - Redis operations (410 lines)
- ✅ Extracted Redis action execution from `pympta.py`
- ✅ Implemented async image storage with region cropping
- ✅ Added pub/sub messaging functionality with JSON support
- ✅ Handled Redis connection management and retry logic
- ✅ Added statistics tracking and health monitoring
- ✅ Support for various image formats (JPEG, PNG) with quality control
- ✅ **Move `database.py`** - PostgreSQL operations (339 lines)
- ✅ Moved existing `archive/siwatsystem/database.py` to `core/storage/`
- ✅ Updated imports and integration points
- ✅ Ensured compatibility with new module structure
- ✅ Added session management and statistics methods
- ✅ Enhanced error handling and connection management
### 5.3 Integration Updates ✅
- ✅ **Updated `core/tracking/integration.py`**
- ✅ Added DetectionPipeline integration
- ✅ Replaced placeholder `_execute_pipeline` with real implementation
- ✅ Added detection pipeline initialization and cleanup
- ✅ Integrated with existing tracking system flow
- ✅ Maintained backward compatibility with test mode
### 5.4 Testing Phase 5 ✅
- ✅ Verified module imports work correctly
- ✅ All new modules follow established coding patterns
- ✅ Integration points properly connected
- ✅ Error handling and cleanup methods implemented
- ✅ Statistics and monitoring capabilities added
### 5.5 Phase 5 Results ✅
- ✅ **DetectionPipeline**: Complete detection orchestration with Redis/PostgreSQL integration, async execution, and comprehensive error handling
- ✅ **BranchProcessor**: Parallel branch execution with ThreadPoolExecutor, model caching, and nested branch support (a simplified sketch follows this section)
- ✅ **RedisManager**: Async Redis operations with image storage, pub/sub messaging, and connection management
- ✅ **DatabaseManager**: Enhanced PostgreSQL operations with session management and statistics
- ✅ **Module Integration**: Seamless integration with existing tracking system while maintaining compatibility
- ✅ **Error Handling**: Comprehensive error handling and graceful degradation throughout all components
- ✅ **Performance**: Optimized parallel processing and caching for high-performance pipeline execution
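The parallel branch execution handled by BranchProcessor can be pictured with the simplified ThreadPoolExecutor sketch below; the branch callables, result shapes, and worker count are placeholders rather than the real `core/detection/branches.py` API.

```python
# Simplified sketch of running classification branches in parallel with a
# ThreadPoolExecutor, as BranchProcessor is described as doing above.
from concurrent.futures import ThreadPoolExecutor, as_completed

def run_branches(frame, branches, max_workers=2):
    """branches: dict of branch_name -> callable(frame) returning a result dict."""
    results, errors = {}, {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {pool.submit(fn, frame): name for name, fn in branches.items()}
        for future in as_completed(futures):
            name = futures[future]
            try:
                results[name] = future.result()
            except Exception as exc:           # a failed branch degrades gracefully
                errors[name] = str(exc)
    return results, errors

# Example usage with hypothetical branch callables:
# results, errors = run_branches(frame, {"brand": classify_brand, "bodytype": classify_bodytype})
```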
## ✅ Additional Implemented Features (Not in Original Plan)
### License Plate Recognition Integration (`core/storage/license_plate.py`) ✅
- ✅ **LicensePlateManager**: Subscribes to Redis channel `license_results` for external LPR service (a minimal subscriber sketch follows this list)
- ✅ **Multi-format Support**: Handles various message formats from LPR service
- ✅ **Result Caching**: 5-minute TTL for license plate results
- ✅ **WebSocket Integration**: Sends combined `imageDetection` messages with license data
- ✅ **Asynchronous Processing**: Non-blocking Redis pub/sub listener
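A minimal subscriber for the `license_results` channel might look like the sketch below. Only the channel name comes from the description above; the connection settings and the assumption that payloads are JSON are illustrative.

```python
# Hedged sketch of subscribing to the `license_results` Redis channel with redis-py.
# The exact message format handled by LicensePlateManager is not specified here,
# so the payload is simply parsed as JSON when possible.
import json
import redis

r = redis.Redis(host="localhost", port=6379)   # assumed connection settings
pubsub = r.pubsub()
pubsub.subscribe("license_results")

for message in pubsub.listen():
    if message["type"] != "message":
        continue
    raw = message["data"]
    try:
        payload = json.loads(raw)
    except (TypeError, ValueError):
        payload = raw
    print("LPR result received:", payload)
```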
### Advanced Session State Management (`core/communication/state.py`) ✅
- ✅ **Session ID Mapping**: Per-display session identifier tracking
- ✅ **Progression Stage Tracking**: Workflow state per display (welcome, car_wait_staff, finished, cleared)
- ✅ **Thread-Safe Operations**: RLock-based synchronization for concurrent access
- ✅ **Comprehensive State Reporting**: Full system state for debugging
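A tiny sketch of the per-display state idea described above, using an RLock as mentioned; the class and method names are assumptions, not the actual `core/communication/state.py` API.

```python
# Assumed-name sketch of per-display session and progression tracking behind an RLock.
import threading

class DisplayState:
    def __init__(self):
        self._lock = threading.RLock()
        self._session_ids = {}        # display_id -> session_id
        self._stages = {}             # display_id -> progression stage

    def set_session(self, display_id, session_id):
        with self._lock:
            if session_id is None:
                self._session_ids.pop(display_id, None)   # backend cleared the session
            else:
                self._session_ids[display_id] = session_id

    def set_stage(self, display_id, stage):
        with self._lock:
            self._stages[display_id] = stage   # e.g. "welcome", "car_wait_staff", "finished"

    def snapshot(self):
        with self._lock:
            return dict(self._session_ids), dict(self._stages)
```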
### Car Abandonment Detection (`core/tracking/integration.py`) ✅
- ✅ **Abandonment Monitoring**: Detects cars leaving without completing fueling
- ✅ **Timeout Configuration**: 3-second abandonment timeout
- ✅ **Null Detection Messages**: Sends `detection: null` to backend for abandoned cars
- ✅ **Automatic Cleanup**: Removes abandoned sessions from tracking
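The abandonment notification reuses the `imageDetection` envelope with `detection` set to null. A sketch, with illustrative identifier and model fields:

```python
# Sketch of the abandonment notification described above: the same imageDetection
# envelope, but with `detection` set to None (serialized as JSON null).
import time

abandonment_message = {
    "type": "imageDetection",
    "subscriptionIdentifier": "display-001;front-cam-01",   # example identifier
    "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "data": {
        "detection": None,            # car left without completing fueling
        "modelId": 42,                # example id
        "modelName": "front_rear_detection_v1",
    },
}
```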
### Enhanced Message Protocol (`core/communication/models.py`) ✅
- ✅ **PatchSessionResult**: Session data patching support
- ✅ **SetProgressionStage**: Workflow stage management messages
- ✅ **Null Detection Handling**: Support for abandonment notifications
- ✅ **Complex Detection Structure**: Supports both classification and null states
### Comprehensive Timeout and Cooldown Systems ✅
- ✅ **Post-Session Cooldown**: 30-second cooldown after session clearing
- ✅ **Processing Cooldown**: 10-second cooldown for repeated processing
- ✅ **Abandonment Timeout**: 3-second timeout for car abandonment detection
- ✅ **Vehicle Expiration**: 2-second timeout for tracking cleanup
- ✅ **Stream Timeouts**: 30-second connection timeout management
## 📋 Phase 6: Integration & Final Testing
### 6.1 Main Application Refactoring
- [ ] **Refactor `app.py`**
- [ ] Remove extracted functionality
- [ ] Update to use new modular structure
- [ ] Maintain FastAPI application structure
- [ ] Update imports and dependencies
- [ ] **Clean up `siwatsystem/pympta.py`**
- [ ] Remove extracted functionality
- [ ] Keep only necessary legacy compatibility code
- [ ] Update imports to use new modules
### 6.2 Post-Session Tracking Validation
- [ ] Implement same-car validation after sessionId cleared
- [ ] Add logic to prevent duplicate pipeline execution
- [ ] Test tracking persistence through session lifecycle
- [ ] Verify correct behavior during edge cases
### 6.3 Configuration & Documentation
- [ ] Update configuration handling for new structure
- [ ] Ensure `config.json` compatibility maintained
- [ ] Update logging configuration for all modules
- [ ] Add module-level documentation
### 6.4 Comprehensive Testing
- [ ] **Integration Testing**
- [ ] Test complete system flow end-to-end
- [ ] Test all WebSocket message types
- [ ] Test HTTP API endpoints
- [ ] Test error handling and recovery
- [ ] **Performance Testing**
- [ ] Verify system performance is maintained
- [ ] Test memory usage optimization
- [ ] Test GPU utilization efficiency
- [ ] Benchmark against original implementation
- [ ] **Edge Case Testing**
- [ ] Test connection failures and reconnection
- [ ] Test model loading failures
- [ ] Test stream interruption handling
- [ ] Test concurrent subscription management
### 6.5 Logging Optimization & Cleanup ✅
- ✅ **Removed Debug Frame Saving**
- ✅ Removed hard-coded debug frame saving in `core/detection/pipeline.py`
- ✅ Removed hard-coded debug frame saving in `core/detection/branches.py`
- ✅ Eliminated absolute debug paths for production use
- ✅ **Eliminated Test/Mock Functionality**
- ✅ Removed `save_frame_for_testing` function from `core/streaming/buffers.py`
- ✅ Removed `save_test_frames` configuration from `StreamConfig`
- ✅ Cleaned up test frame saving calls in stream manager
- ✅ Updated module exports to remove test functions
- ✅ **Reduced Verbose Logging**
- ✅ Commented out verbose frame storage logging (every frame)
- ✅ Converted debug-level info logs to proper debug level
- ✅ Reduced repetitive frame dimension logging
- ✅ Maintained important model results and detection confidence logging
- ✅ Kept critical pipeline execution and error messages
- ✅ **Production-Ready Logging**
- ✅ Clean startup and initialization messages
- ✅ Clear model loading and pipeline status
- ✅ Preserved detection results with confidence scores
- ✅ Maintained session management and tracking messages
- ✅ Kept important error and warning messages
### 6.6 Final Cleanup
- [ ] Remove any remaining duplicate code
- [ ] Optimize imports across all modules
- [ ] Clean up temporary files and debugging code
- [ ] Update project documentation
## 📋 Post-Refactoring Tasks
### Documentation Updates
- [ ] Update `CLAUDE.md` with new architecture
- [ ] Create module-specific documentation
- [ ] Update installation and deployment guides
- [ ] Add troubleshooting guide for new structure
### Code Quality
- [ ] Add type hints to all new modules
- [ ] Implement proper error handling patterns
- [ ] Add logging consistency across modules
- [ ] Ensure proper resource cleanup
### Future Enhancements (Optional)
- [ ] Add unit tests for each module
- [ ] Implement monitoring and metrics collection
- [ ] Add configuration validation
- [ ] Consider adding dependency injection container
---
## Success Criteria
**Modularity**: Each module has a single, clear responsibility
**Testability**: Each phase can be tested independently
**Maintainability**: Code is easy to understand and modify
**Compatibility**: All existing functionality preserved
**Performance**: System performance is maintained or improved
**Documentation**: Clear documentation for new architecture
## Risk Mitigation
- **Feature-by-feature testing** ensures functionality is preserved at each step
- **Gradual migration** minimizes risk of breaking existing functionality
- **Preserve critical interfaces** (WebSocket protocol, HTTP endpoints)
- **Maintain backward compatibility** with existing configurations
- **Comprehensive testing** at each phase before proceeding
---
## 🎯 Current Status Summary
### ✅ Completed Phases (95% Complete)
- **Phase 1**: Communication Layer - ✅ COMPLETED
- **Phase 2**: Pipeline Configuration & Model Management - ✅ COMPLETED
- **Phase 3**: Streaming System - ✅ COMPLETED
- **Phase 4**: Vehicle Tracking System - ✅ COMPLETED
- **Phase 5**: Detection Pipeline System - ✅ COMPLETED
- **Additional Features**: License Plate Recognition, Car Abandonment, Session Management - ✅ COMPLETED
### 📋 Remaining Work (5%)
- **Phase 6**: Final Integration & Testing
- Main application cleanup (`app.py` and `pympta.py`)
- Comprehensive integration testing
- Performance benchmarking
- Documentation updates
### 🚀 Production Ready Features
- ✅ **Modular Architecture**: ~4000 lines refactored into 20+ focused modules
- ✅ **WebSocket Protocol**: Full compliance with all message types
- ✅ **License Plate Recognition**: External LPR service integration via Redis
- ✅ **Car Abandonment Detection**: Automatic detection and notification
- ✅ **Session Management**: Complete lifecycle with progression stages
- ✅ **Parallel Processing**: ThreadPoolExecutor for branch execution
- ✅ **Redis Integration**: Pub/sub, image storage, LPR subscription
- ✅ **PostgreSQL Integration**: Automatic schema management, combined updates
- ✅ **Stream Optimization**: Shared streams, format-specific handling
- ✅ **Error Recovery**: H.264 corruption detection, automatic reconnection
- ✅ **Production Logging**: Clean, informative logging without debug clutter
### 📊 Metrics
- **Modules Created**: 20+ specialized modules
- **Lines Per Module**: ~200-500 (highly maintainable)
- **Test Coverage**: Feature-by-feature validation completed
- **Performance**: Maintained or improved from original implementation
- **Backward Compatibility**: 100% preserved

app.py (1331 lines changed)

File diff suppressed because it is too large.


@@ -1,903 +0,0 @@
from typing import Any, Dict
import os
import json
import time
import queue
import torch
import cv2
import numpy as np
import base64
import logging
import threading
import requests
import asyncio
import psutil
import zipfile
from urllib.parse import urlparse
from fastapi import FastAPI, WebSocket, HTTPException
from fastapi.websockets import WebSocketDisconnect
from fastapi.responses import Response
from websockets.exceptions import ConnectionClosedError
from ultralytics import YOLO
# Import shared pipeline functions
from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline
app = FastAPI()
# Global dictionaries to keep track of models and streams
# "models" now holds a nested dict: { camera_id: { modelId: model_tree } }
models: Dict[str, Dict[str, Any]] = {}
streams: Dict[str, Dict[str, Any]] = {}
# Store session IDs per display
session_ids: Dict[str, int] = {}
# Track shared camera streams by camera URL
camera_streams: Dict[str, Dict[str, Any]] = {}
# Map subscriptions to their camera URL
subscription_to_camera: Dict[str, str] = {}
# Store latest frames for REST API access (separate from processing buffer)
latest_frames: Dict[str, Any] = {}
with open("config.json", "r") as f:
config = json.load(f)
poll_interval = config.get("poll_interval_ms", 100)
reconnect_interval = config.get("reconnect_interval_sec", 5)
TARGET_FPS = config.get("target_fps", 10)
poll_interval = 1000 / TARGET_FPS
logging.info(f"Poll interval: {poll_interval}ms")
max_streams = config.get("max_streams", 5)
max_retries = config.get("max_retries", 3)
# Configure logging
logging.basicConfig(
level=logging.INFO, # Set to INFO level for less verbose output
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
handlers=[
logging.FileHandler("detector_worker.log"), # Write logs to a file
logging.StreamHandler() # Also output to console
]
)
# Create a logger specifically for this application
logger = logging.getLogger("detector_worker")
logger.setLevel(logging.DEBUG) # Set app-specific logger to DEBUG level
# Ensure all other libraries (including root) use at least INFO level
logging.getLogger().setLevel(logging.INFO)
logger.info("Starting detector worker application")
logger.info(f"Configuration: Target FPS: {TARGET_FPS}, Max streams: {max_streams}, Max retries: {max_retries}")
# Ensure the models directory exists
os.makedirs("models", exist_ok=True)
logger.info("Ensured models directory exists")
# Constants for heartbeat and timeouts
HEARTBEAT_INTERVAL = 2 # seconds
WORKER_TIMEOUT_MS = 10000
logger.debug(f"Heartbeat interval set to {HEARTBEAT_INTERVAL} seconds")
# Locks for thread-safe operations
streams_lock = threading.Lock()
models_lock = threading.Lock()
logger.debug("Initialized thread locks")
# Add helper to download mpta ZIP file from a remote URL
def download_mpta(url: str, dest_path: str) -> str:
try:
logger.info(f"Starting download of model from {url} to {dest_path}")
os.makedirs(os.path.dirname(dest_path), exist_ok=True)
response = requests.get(url, stream=True)
if response.status_code == 200:
file_size = int(response.headers.get('content-length', 0))
logger.info(f"Model file size: {file_size/1024/1024:.2f} MB")
downloaded = 0
with open(dest_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
downloaded += len(chunk)
if file_size > 0 and downloaded % (file_size // 10) < 8192: # Log approximately every 10%
logger.debug(f"Download progress: {downloaded/file_size*100:.1f}%")
logger.info(f"Successfully downloaded mpta file from {url} to {dest_path}")
return dest_path
else:
logger.error(f"Failed to download mpta file (status code {response.status_code}): {response.text}")
return None
except Exception as e:
logger.error(f"Exception downloading mpta file from {url}: {str(e)}", exc_info=True)
return None
# Add helper to fetch snapshot image from HTTP/HTTPS URL
def fetch_snapshot(url: str):
try:
from requests.auth import HTTPBasicAuth, HTTPDigestAuth
# Parse URL to extract credentials
parsed = urlparse(url)
# Prepare headers - some cameras require User-Agent
headers = {
'User-Agent': 'Mozilla/5.0 (compatible; DetectorWorker/1.0)'
}
# Reconstruct URL without credentials
clean_url = f"{parsed.scheme}://{parsed.hostname}"
if parsed.port:
clean_url += f":{parsed.port}"
clean_url += parsed.path
if parsed.query:
clean_url += f"?{parsed.query}"
auth = None
if parsed.username and parsed.password:
# Try HTTP Digest authentication first (common for IP cameras)
try:
auth = HTTPDigestAuth(parsed.username, parsed.password)
response = requests.get(clean_url, auth=auth, headers=headers, timeout=10)
if response.status_code == 200:
logger.debug(f"Successfully authenticated using HTTP Digest for {clean_url}")
elif response.status_code == 401:
# If Digest fails, try Basic auth
logger.debug(f"HTTP Digest failed, trying Basic auth for {clean_url}")
auth = HTTPBasicAuth(parsed.username, parsed.password)
response = requests.get(clean_url, auth=auth, headers=headers, timeout=10)
if response.status_code == 200:
logger.debug(f"Successfully authenticated using HTTP Basic for {clean_url}")
except Exception as auth_error:
logger.debug(f"Authentication setup error: {auth_error}")
# Fallback to original URL with embedded credentials
response = requests.get(url, headers=headers, timeout=10)
else:
# No credentials in URL, make request as-is
response = requests.get(url, headers=headers, timeout=10)
if response.status_code == 200:
# Convert response content to numpy array
nparr = np.frombuffer(response.content, np.uint8)
# Decode image
frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if frame is not None:
logger.debug(f"Successfully fetched snapshot from {clean_url}, shape: {frame.shape}")
return frame
else:
logger.error(f"Failed to decode image from snapshot URL: {clean_url}")
return None
else:
logger.error(f"Failed to fetch snapshot (status code {response.status_code}): {clean_url}")
return None
except Exception as e:
logger.error(f"Exception fetching snapshot from {url}: {str(e)}")
return None
# Helper to get crop coordinates from stream
def get_crop_coords(stream):
return {
"cropX1": stream.get("cropX1"),
"cropY1": stream.get("cropY1"),
"cropX2": stream.get("cropX2"),
"cropY2": stream.get("cropY2")
}
####################################################
# REST API endpoint for image retrieval
####################################################
@app.get("/camera/{camera_id}/image")
async def get_camera_image(camera_id: str):
"""
Get the current frame from a camera as JPEG image
"""
try:
# URL decode the camera_id to handle encoded characters like %3B for semicolon
from urllib.parse import unquote
original_camera_id = camera_id
camera_id = unquote(camera_id)
logger.debug(f"REST API request: original='{original_camera_id}', decoded='{camera_id}'")
with streams_lock:
if camera_id not in streams:
logger.warning(f"Camera ID '{camera_id}' not found in streams. Current streams: {list(streams.keys())}")
raise HTTPException(status_code=404, detail=f"Camera {camera_id} not found or not active")
# Check if we have a cached frame for this camera
if camera_id not in latest_frames:
logger.warning(f"No cached frame available for camera '{camera_id}'.")
raise HTTPException(status_code=404, detail=f"No frame available for camera {camera_id}")
frame = latest_frames[camera_id]
logger.debug(f"Retrieved cached frame for camera '{camera_id}', frame shape: {frame.shape}")
# Encode frame as JPEG
success, buffer_img = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
if not success:
raise HTTPException(status_code=500, detail="Failed to encode image as JPEG")
# Return image as binary response
return Response(content=buffer_img.tobytes(), media_type="image/jpeg")
except HTTPException:
raise
except Exception as e:
logger.error(f"Error retrieving image for camera {camera_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
####################################################
# Detection and frame processing functions
####################################################
@app.websocket("/")
async def detect(websocket: WebSocket):
logger.info("WebSocket connection accepted")
persistent_data_dict = {}
async def handle_detection(camera_id, stream, frame, websocket, model_tree, persistent_data):
try:
# Apply crop if specified
cropped_frame = frame
if all(coord is not None for coord in [stream.get("cropX1"), stream.get("cropY1"), stream.get("cropX2"), stream.get("cropY2")]):
cropX1, cropY1, cropX2, cropY2 = stream["cropX1"], stream["cropY1"], stream["cropX2"], stream["cropY2"]
cropped_frame = frame[cropY1:cropY2, cropX1:cropX2]
logger.debug(f"Applied crop coordinates ({cropX1}, {cropY1}, {cropX2}, {cropY2}) to frame for camera {camera_id}")
logger.debug(f"Processing frame for camera {camera_id} with model {stream['modelId']}")
start_time = time.time()
# Extract display identifier for session ID lookup
subscription_parts = stream["subscriptionIdentifier"].split(';')
display_identifier = subscription_parts[0] if subscription_parts else None
session_id = session_ids.get(display_identifier) if display_identifier else None
# Create context for pipeline execution
pipeline_context = {
"camera_id": camera_id,
"display_id": display_identifier,
"session_id": session_id
}
detection_result = run_pipeline(cropped_frame, model_tree, context=pipeline_context)
process_time = (time.time() - start_time) * 1000
logger.debug(f"Detection for camera {camera_id} completed in {process_time:.2f}ms")
# Log the raw detection result for debugging
logger.debug(f"Raw detection result for camera {camera_id}:\n{json.dumps(detection_result, indent=2, default=str)}")
# Direct class result (no detections/classifications structure)
if detection_result and isinstance(detection_result, dict) and "class" in detection_result and "confidence" in detection_result:
highest_confidence_detection = {
"class": detection_result.get("class", "none"),
"confidence": detection_result.get("confidence", 1.0),
"box": [0, 0, 0, 0] # Empty bounding box for classifications
}
# Handle case when no detections found or result is empty
elif not detection_result or not detection_result.get("detections"):
# Check if we have classification results
if detection_result and detection_result.get("classifications"):
# Get the highest confidence classification
classifications = detection_result.get("classifications", [])
highest_confidence_class = max(classifications, key=lambda x: x.get("confidence", 0)) if classifications else None
if highest_confidence_class:
highest_confidence_detection = {
"class": highest_confidence_class.get("class", "none"),
"confidence": highest_confidence_class.get("confidence", 1.0),
"box": [0, 0, 0, 0] # Empty bounding box for classifications
}
else:
highest_confidence_detection = {
"class": "none",
"confidence": 1.0,
"box": [0, 0, 0, 0]
}
else:
highest_confidence_detection = {
"class": "none",
"confidence": 1.0,
"box": [0, 0, 0, 0]
}
else:
# Find detection with highest confidence
detections = detection_result.get("detections", [])
highest_confidence_detection = max(detections, key=lambda x: x.get("confidence", 0)) if detections else {
"class": "none",
"confidence": 1.0,
"box": [0, 0, 0, 0]
}
# Convert detection format to match protocol - flatten detection attributes
detection_dict = {}
# Handle different detection result formats
if isinstance(highest_confidence_detection, dict):
# Copy all fields from the detection result
for key, value in highest_confidence_detection.items():
if key not in ["box", "id"]: # Skip internal fields
detection_dict[key] = value
detection_data = {
"type": "imageDetection",
"subscriptionIdentifier": stream["subscriptionIdentifier"],
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.%fZ", time.gmtime()),
"data": {
"detection": detection_dict,
"modelId": stream["modelId"],
"modelName": stream["modelName"]
}
}
# Add session ID if available
if session_id is not None:
detection_data["sessionId"] = session_id
if highest_confidence_detection["class"] != "none":
logger.info(f"Camera {camera_id}: Detected {highest_confidence_detection['class']} with confidence {highest_confidence_detection['confidence']:.2f} using model {stream['modelName']}")
# Log session ID if available
if session_id:
logger.debug(f"Detection associated with session ID: {session_id}")
await websocket.send_json(detection_data)
logger.debug(f"Sent detection data to client for camera {camera_id}")
return persistent_data
except Exception as e:
logger.error(f"Error in handle_detection for camera {camera_id}: {str(e)}", exc_info=True)
return persistent_data
def frame_reader(camera_id, cap, buffer, stop_event):
retries = 0
logger.info(f"Starting frame reader thread for camera {camera_id}")
frame_count = 0
last_log_time = time.time()
try:
# Log initial camera status and properties
if cap.isOpened():
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
logger.info(f"Camera {camera_id} opened successfully with resolution {width}x{height}, FPS: {fps}")
else:
logger.error(f"Camera {camera_id} failed to open initially")
while not stop_event.is_set():
try:
if not cap.isOpened():
logger.error(f"Camera {camera_id} is not open before trying to read")
# Attempt to reopen
cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"])
time.sleep(reconnect_interval)
continue
logger.debug(f"Attempting to read frame from camera {camera_id}")
ret, frame = cap.read()
if not ret:
logger.warning(f"Connection lost for camera: {camera_id}, retry {retries+1}/{max_retries}")
cap.release()
time.sleep(reconnect_interval)
retries += 1
if retries > max_retries and max_retries != -1:
logger.error(f"Max retries reached for camera: {camera_id}, stopping frame reader")
break
# Re-open
logger.info(f"Attempting to reopen RTSP stream for camera: {camera_id}")
cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"])
if not cap.isOpened():
logger.error(f"Failed to reopen RTSP stream for camera: {camera_id}")
continue
logger.info(f"Successfully reopened RTSP stream for camera: {camera_id}")
continue
# Successfully read a frame
frame_count += 1
current_time = time.time()
# Log frame stats every 5 seconds
if current_time - last_log_time > 5:
logger.info(f"Camera {camera_id}: Read {frame_count} frames in the last {current_time - last_log_time:.1f} seconds")
frame_count = 0
last_log_time = current_time
logger.debug(f"Successfully read frame from camera {camera_id}, shape: {frame.shape}")
retries = 0
# Overwrite old frame if buffer is full
if not buffer.empty():
try:
buffer.get_nowait()
logger.debug(f"[frame_reader] Removed old frame from buffer for camera {camera_id}")
except queue.Empty:
pass
buffer.put(frame)
logger.debug(f"[frame_reader] Added new frame to buffer for camera {camera_id}. Buffer size: {buffer.qsize()}")
# Short sleep to avoid CPU overuse
time.sleep(0.01)
except cv2.error as e:
logger.error(f"OpenCV error for camera {camera_id}: {e}", exc_info=True)
cap.release()
time.sleep(reconnect_interval)
retries += 1
if retries > max_retries and max_retries != -1:
logger.error(f"Max retries reached after OpenCV error for camera {camera_id}")
break
logger.info(f"Attempting to reopen RTSP stream after OpenCV error for camera: {camera_id}")
cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"])
if not cap.isOpened():
logger.error(f"Failed to reopen RTSP stream for camera {camera_id} after OpenCV error")
continue
logger.info(f"Successfully reopened RTSP stream after OpenCV error for camera: {camera_id}")
except Exception as e:
logger.error(f"Unexpected error for camera {camera_id}: {str(e)}", exc_info=True)
cap.release()
break
except Exception as e:
logger.error(f"Error in frame_reader thread for camera {camera_id}: {str(e)}", exc_info=True)
finally:
logger.info(f"Frame reader thread for camera {camera_id} is exiting")
if cap and cap.isOpened():
cap.release()
def snapshot_reader(camera_id, snapshot_url, snapshot_interval, buffer, stop_event):
"""Frame reader that fetches snapshots from HTTP/HTTPS URL at specified intervals"""
retries = 0
logger.info(f"Starting snapshot reader thread for camera {camera_id} from {snapshot_url}")
frame_count = 0
last_log_time = time.time()
try:
interval_seconds = snapshot_interval / 1000.0 # Convert milliseconds to seconds
logger.info(f"Snapshot interval for camera {camera_id}: {interval_seconds}s")
while not stop_event.is_set():
try:
start_time = time.time()
frame = fetch_snapshot(snapshot_url)
if frame is None:
logger.warning(f"Failed to fetch snapshot for camera: {camera_id}, retry {retries+1}/{max_retries}")
retries += 1
if retries > max_retries and max_retries != -1:
logger.error(f"Max retries reached for snapshot camera: {camera_id}, stopping reader")
break
time.sleep(min(interval_seconds, reconnect_interval))
continue
# Successfully fetched a frame
frame_count += 1
current_time = time.time()
# Log frame stats every 5 seconds
if current_time - last_log_time > 5:
logger.info(f"Camera {camera_id}: Fetched {frame_count} snapshots in the last {current_time - last_log_time:.1f} seconds")
frame_count = 0
last_log_time = current_time
logger.debug(f"Successfully fetched snapshot from camera {camera_id}, shape: {frame.shape}")
retries = 0
# Overwrite old frame if buffer is full
if not buffer.empty():
try:
buffer.get_nowait()
logger.debug(f"[snapshot_reader] Removed old snapshot from buffer for camera {camera_id}")
except queue.Empty:
pass
buffer.put(frame)
logger.debug(f"[snapshot_reader] Added new snapshot to buffer for camera {camera_id}. Buffer size: {buffer.qsize()}")
# Wait for the specified interval
elapsed = time.time() - start_time
sleep_time = max(interval_seconds - elapsed, 0)
if sleep_time > 0:
time.sleep(sleep_time)
except Exception as e:
logger.error(f"Unexpected error fetching snapshot for camera {camera_id}: {str(e)}", exc_info=True)
retries += 1
if retries > max_retries and max_retries != -1:
logger.error(f"Max retries reached after error for snapshot camera {camera_id}")
break
time.sleep(min(interval_seconds, reconnect_interval))
except Exception as e:
logger.error(f"Error in snapshot_reader thread for camera {camera_id}: {str(e)}", exc_info=True)
finally:
logger.info(f"Snapshot reader thread for camera {camera_id} is exiting")
async def process_streams():
logger.info("Started processing streams")
try:
while True:
start_time = time.time()
with streams_lock:
current_streams = list(streams.items())
if current_streams:
logger.debug(f"Processing {len(current_streams)} active streams")
else:
logger.debug("No active streams to process")
for camera_id, stream in current_streams:
buffer = stream["buffer"]
if buffer.empty():
logger.debug(f"Frame buffer is empty for camera {camera_id}")
continue
logger.debug(f"Got frame from buffer for camera {camera_id}")
frame = buffer.get()
# Cache the frame for REST API access
latest_frames[camera_id] = frame.copy()
logger.debug(f"Cached frame for REST API access for camera {camera_id}")
with models_lock:
model_tree = models.get(camera_id, {}).get(stream["modelId"])
if not model_tree:
logger.warning(f"Model not found for camera {camera_id}, modelId {stream['modelId']}")
continue
logger.debug(f"Found model tree for camera {camera_id}, modelId {stream['modelId']}")
key = (camera_id, stream["modelId"])
persistent_data = persistent_data_dict.get(key, {})
logger.debug(f"Starting detection for camera {camera_id} with modelId {stream['modelId']}")
updated_persistent_data = await handle_detection(
camera_id, stream, frame, websocket, model_tree, persistent_data
)
persistent_data_dict[key] = updated_persistent_data
elapsed_time = (time.time() - start_time) * 1000 # ms
sleep_time = max(poll_interval - elapsed_time, 0)
logger.debug(f"Frame processing cycle: {elapsed_time:.2f}ms, sleeping for: {sleep_time:.2f}ms")
await asyncio.sleep(sleep_time / 1000.0)
except asyncio.CancelledError:
logger.info("Stream processing task cancelled")
except Exception as e:
logger.error(f"Error in process_streams: {str(e)}", exc_info=True)
async def send_heartbeat():
while True:
try:
cpu_usage = psutil.cpu_percent()
memory_usage = psutil.virtual_memory().percent
if torch.cuda.is_available():
gpu_usage = torch.cuda.utilization() if hasattr(torch.cuda, 'utilization') else None
gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)
else:
gpu_usage = None
gpu_memory_usage = None
camera_connections = [
{
"subscriptionIdentifier": stream["subscriptionIdentifier"],
"modelId": stream["modelId"],
"modelName": stream["modelName"],
"online": True,
**{k: v for k, v in get_crop_coords(stream).items() if v is not None}
}
for camera_id, stream in streams.items()
]
state_report = {
"type": "stateReport",
"cpuUsage": cpu_usage,
"memoryUsage": memory_usage,
"gpuUsage": gpu_usage,
"gpuMemoryUsage": gpu_memory_usage,
"cameraConnections": camera_connections
}
await websocket.send_text(json.dumps(state_report))
logger.debug(f"Sent stateReport as heartbeat: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, {len(camera_connections)} active cameras")
await asyncio.sleep(HEARTBEAT_INTERVAL)
except Exception as e:
logger.error(f"Error sending stateReport heartbeat: {e}")
break
async def on_message():
while True:
try:
msg = await websocket.receive_text()
logger.debug(f"Received message: {msg}")
data = json.loads(msg)
msg_type = data.get("type")
if msg_type == "subscribe":
payload = data.get("payload", {})
subscriptionIdentifier = payload.get("subscriptionIdentifier")
rtsp_url = payload.get("rtspUrl")
snapshot_url = payload.get("snapshotUrl")
snapshot_interval = payload.get("snapshotInterval")
model_url = payload.get("modelUrl")
modelId = payload.get("modelId")
modelName = payload.get("modelName")
cropX1 = payload.get("cropX1")
cropY1 = payload.get("cropY1")
cropX2 = payload.get("cropX2")
cropY2 = payload.get("cropY2")
# Extract camera_id from subscriptionIdentifier (format: displayIdentifier;cameraIdentifier)
parts = subscriptionIdentifier.split(';')
if len(parts) != 2:
logger.error(f"Invalid subscriptionIdentifier format: {subscriptionIdentifier}")
continue
display_identifier, camera_identifier = parts
camera_id = subscriptionIdentifier # Use full subscriptionIdentifier as camera_id for mapping
if model_url:
with models_lock:
if (camera_id not in models) or (modelId not in models[camera_id]):
logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}")
extraction_dir = os.path.join("models", camera_identifier, str(modelId))
os.makedirs(extraction_dir, exist_ok=True)
# If model_url is remote, download it first.
parsed = urlparse(model_url)
if parsed.scheme in ("http", "https"):
logger.info(f"Downloading remote .mpta file from {model_url}")
filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta"
local_mpta = os.path.join(extraction_dir, filename)
logger.debug(f"Download destination: {local_mpta}")
local_path = download_mpta(model_url, local_mpta)
if not local_path:
logger.error(f"Failed to download the remote .mpta file from {model_url}")
error_response = {
"type": "error",
"subscriptionIdentifier": subscriptionIdentifier,
"error": f"Failed to download model from {model_url}"
}
await websocket.send_json(error_response)
continue
model_tree = load_pipeline_from_zip(local_path, extraction_dir)
else:
logger.info(f"Loading local .mpta file from {model_url}")
# Check if file exists before attempting to load
if not os.path.exists(model_url):
logger.error(f"Local .mpta file not found: {model_url}")
logger.debug(f"Current working directory: {os.getcwd()}")
error_response = {
"type": "error",
"subscriptionIdentifier": subscriptionIdentifier,
"error": f"Model file not found: {model_url}"
}
await websocket.send_json(error_response)
continue
model_tree = load_pipeline_from_zip(model_url, extraction_dir)
if model_tree is None:
logger.error(f"Failed to load model {modelId} from .mpta file for camera {camera_id}")
error_response = {
"type": "error",
"subscriptionIdentifier": subscriptionIdentifier,
"error": f"Failed to load model {modelId}"
}
await websocket.send_json(error_response)
continue
if camera_id not in models:
models[camera_id] = {}
models[camera_id][modelId] = model_tree
logger.info(f"Successfully loaded model {modelId} for camera {camera_id}")
logger.debug(f"Model extraction directory: {extraction_dir}")
if camera_id and (rtsp_url or snapshot_url):
with streams_lock:
# Determine camera URL for shared stream management
camera_url = snapshot_url if snapshot_url else rtsp_url
if camera_id not in streams and len(streams) < max_streams:
# Check if we already have a stream for this camera URL
shared_stream = camera_streams.get(camera_url)
if shared_stream:
# Reuse existing stream
logger.info(f"Reusing existing stream for camera URL: {camera_url}")
buffer = shared_stream["buffer"]
stop_event = shared_stream["stop_event"]
thread = shared_stream["thread"]
mode = shared_stream["mode"]
# Increment reference count
shared_stream["ref_count"] = shared_stream.get("ref_count", 0) + 1
else:
# Create new stream
buffer = queue.Queue(maxsize=1)
stop_event = threading.Event()
if snapshot_url and snapshot_interval:
logger.info(f"Creating new snapshot stream for camera {camera_id}: {snapshot_url}")
thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event))
thread.daemon = True
thread.start()
mode = "snapshot"
# Store shared stream info
shared_stream = {
"buffer": buffer,
"thread": thread,
"stop_event": stop_event,
"mode": mode,
"url": snapshot_url,
"snapshot_interval": snapshot_interval,
"ref_count": 1
}
camera_streams[camera_url] = shared_stream
elif rtsp_url:
logger.info(f"Creating new RTSP stream for camera {camera_id}: {rtsp_url}")
cap = cv2.VideoCapture(rtsp_url)
if not cap.isOpened():
logger.error(f"Failed to open RTSP stream for camera {camera_id}")
continue
thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
thread.daemon = True
thread.start()
mode = "rtsp"
# Store shared stream info
shared_stream = {
"buffer": buffer,
"thread": thread,
"stop_event": stop_event,
"mode": mode,
"url": rtsp_url,
"cap": cap,
"ref_count": 1
}
camera_streams[camera_url] = shared_stream
else:
logger.error(f"No valid URL provided for camera {camera_id}")
continue
# Create stream info for this subscription
stream_info = {
"buffer": buffer,
"thread": thread,
"stop_event": stop_event,
"modelId": modelId,
"modelName": modelName,
"subscriptionIdentifier": subscriptionIdentifier,
"cropX1": cropX1,
"cropY1": cropY1,
"cropX2": cropX2,
"cropY2": cropY2,
"mode": mode,
"camera_url": camera_url
}
if mode == "snapshot":
stream_info["snapshot_url"] = snapshot_url
stream_info["snapshot_interval"] = snapshot_interval
elif mode == "rtsp":
stream_info["rtsp_url"] = rtsp_url
stream_info["cap"] = shared_stream["cap"]
streams[camera_id] = stream_info
subscription_to_camera[camera_id] = camera_url
elif camera_id and camera_id in streams:
# Already subscribed; keep the existing shared stream rather than recreating it
logger.info(f"Resubscribing to camera {camera_id}")
# Note: Keep models in memory for reuse across subscriptions
elif msg_type == "unsubscribe":
payload = data.get("payload", {})
subscriptionIdentifier = payload.get("subscriptionIdentifier")
camera_id = subscriptionIdentifier
with streams_lock:
if camera_id and camera_id in streams:
stream = streams.pop(camera_id)
camera_url = subscription_to_camera.pop(camera_id, None)
if camera_url and camera_url in camera_streams:
shared_stream = camera_streams[camera_url]
shared_stream["ref_count"] -= 1
# If no more references, stop the shared stream
if shared_stream["ref_count"] <= 0:
logger.info(f"Stopping shared stream for camera URL: {camera_url}")
shared_stream["stop_event"].set()
shared_stream["thread"].join()
if "cap" in shared_stream:
shared_stream["cap"].release()
del camera_streams[camera_url]
else:
logger.info(f"Shared stream for {camera_url} still has {shared_stream['ref_count']} references")
# Clean up cached frame
latest_frames.pop(camera_id, None)
logger.info(f"Unsubscribed from camera {camera_id}")
# Note: Keep models in memory for potential reuse
elif msg_type == "requestState":
cpu_usage = psutil.cpu_percent()
memory_usage = psutil.virtual_memory().percent
if torch.cuda.is_available():
gpu_usage = torch.cuda.utilization() if hasattr(torch.cuda, 'utilization') else None
gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)
else:
gpu_usage = None
gpu_memory_usage = None
camera_connections = [
{
"subscriptionIdentifier": stream["subscriptionIdentifier"],
"modelId": stream["modelId"],
"modelName": stream["modelName"],
"online": True,
**{k: v for k, v in get_crop_coords(stream).items() if v is not None}
}
for camera_id, stream in streams.items()
]
state_report = {
"type": "stateReport",
"cpuUsage": cpu_usage,
"memoryUsage": memory_usage,
"gpuUsage": gpu_usage,
"gpuMemoryUsage": gpu_memory_usage,
"cameraConnections": camera_connections
}
await websocket.send_text(json.dumps(state_report))
elif msg_type == "setSessionId":
payload = data.get("payload", {})
display_identifier = payload.get("displayIdentifier")
session_id = payload.get("sessionId")
if display_identifier:
# Store session ID for this display
if session_id is None:
session_ids.pop(display_identifier, None)
logger.info(f"Cleared session ID for display {display_identifier}")
else:
session_ids[display_identifier] = session_id
logger.info(f"Set session ID {session_id} for display {display_identifier}")
elif msg_type == "patchSession":
session_id = data.get("sessionId")
patch_data = data.get("data", {})
# For now, just acknowledge the patch - actual implementation depends on backend requirements
response = {
"type": "patchSessionResult",
"payload": {
"sessionId": session_id,
"success": True,
"message": "Session patch acknowledged"
}
}
await websocket.send_json(response)
logger.info(f"Acknowledged patch for session {session_id}")
else:
logger.error(f"Unknown message type: {msg_type}")
except json.JSONDecodeError:
logger.error("Received invalid JSON message")
except (WebSocketDisconnect, ConnectionClosedError) as e:
logger.warning(f"WebSocket disconnected: {e}")
break
except Exception as e:
logger.error(f"Error handling message: {e}")
break
stream_task = None
try:
await websocket.accept()
stream_task = asyncio.create_task(process_streams())
heartbeat_task = asyncio.create_task(send_heartbeat())
message_task = asyncio.create_task(on_message())
await asyncio.gather(heartbeat_task, message_task)
except Exception as e:
logger.error(f"Error in detect websocket: {e}")
finally:
if stream_task:
    stream_task.cancel()
    try:
        await stream_task
    except asyncio.CancelledError:
        pass
with streams_lock:
# Clean up shared camera streams
for camera_url, shared_stream in camera_streams.items():
shared_stream["stop_event"].set()
shared_stream["thread"].join()
if "cap" in shared_stream:
shared_stream["cap"].release()
while not shared_stream["buffer"].empty():
try:
shared_stream["buffer"].get_nowait()
except queue.Empty:
pass
logger.info(f"Released shared camera stream for {camera_url}")
streams.clear()
camera_streams.clear()
subscription_to_camera.clear()
with models_lock:
models.clear()
latest_frames.clear()
session_ids.clear()
logger.info("WebSocket connection closed")

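For reference, a minimal sketch of the "subscribe" payload this handler expects; the identifier, URLs, and crop values below are illustrative placeholders, not values from a real deployment.

subscribe_message = {
    "type": "subscribe",
    "payload": {
        "subscriptionIdentifier": "display-001;cam-001",  # displayIdentifier;cameraIdentifier
        "rtspUrl": "rtsp://example-host/stream",           # or use snapshotUrl + snapshotInterval
        "modelUrl": "models/example.mpta",                  # local path or http(s) URL to the .mpta file
        "modelId": 1,
        "modelName": "example-detector",
        "cropX1": 0, "cropY1": 0, "cropX2": 640, "cropY2": 480,
    },
}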
View file

@ -1,9 +1,7 @@
  {
  "poll_interval_ms": 100,
- "max_streams": 20,
+ "max_streams": 5,
  "target_fps": 2,
- "reconnect_interval_sec": 10,
- "max_retries": -1,
- "rtsp_buffer_size": 3,
- "rtsp_tcp_transport": true
+ "reconnect_interval_sec": 5,
+ "max_retries": -1
  }

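A minimal sketch of reading these values with fallbacks, assuming the worker loads them from a config.json in its working directory (the filename is not visible in this hunk); the defaults mirror the new values above.

import json

DEFAULTS = {
    "poll_interval_ms": 100,
    "max_streams": 5,
    "target_fps": 2,
    "reconnect_interval_sec": 5,
    "max_retries": -1,  # -1 means retry forever
}

def load_config(path: str = "config.json") -> dict:
    # Merge file values over defaults so missing keys keep sane fallbacks.
    try:
        with open(path) as f:
            file_cfg = json.load(f)
    except FileNotFoundError:
        file_cfg = {}
    return {**DEFAULTS, **file_cfg}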
View file

@ -1 +0,0 @@
# Core package for detector worker

View file

@ -1 +0,0 @@
# Communication module for WebSocket and HTTP handling

View file

@ -1,212 +0,0 @@
"""
Message types, constants, and validation functions for WebSocket communication.
"""
import json
import logging
from typing import Dict, Any, Optional, Union
from .models import (
IncomingMessage, OutgoingMessage,
SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage,
RequestStateMessage, PatchSessionResultMessage,
StateReportMessage, ImageDetectionMessage, PatchSessionMessage
)
logger = logging.getLogger(__name__)
# Message type constants
class MessageTypes:
"""WebSocket message type constants."""
# Incoming from backend
SET_SUBSCRIPTION_LIST = "setSubscriptionList"
SET_SESSION_ID = "setSessionId"
SET_PROGRESSION_STAGE = "setProgressionStage"
REQUEST_STATE = "requestState"
PATCH_SESSION_RESULT = "patchSessionResult"
# Outgoing to backend
STATE_REPORT = "stateReport"
IMAGE_DETECTION = "imageDetection"
PATCH_SESSION = "patchSession"
def parse_incoming_message(raw_message: str) -> Optional[IncomingMessage]:
"""
Parse incoming WebSocket message and validate against known types.
Args:
raw_message: Raw JSON string from WebSocket
Returns:
Parsed message object or None if invalid
"""
try:
data = json.loads(raw_message)
message_type = data.get("type")
if not message_type:
logger.error("Message missing 'type' field")
return None
# Route to appropriate message class
if message_type == MessageTypes.SET_SUBSCRIPTION_LIST:
return SetSubscriptionListMessage(**data)
elif message_type == MessageTypes.SET_SESSION_ID:
return SetSessionIdMessage(**data)
elif message_type == MessageTypes.SET_PROGRESSION_STAGE:
return SetProgressionStageMessage(**data)
elif message_type == MessageTypes.REQUEST_STATE:
return RequestStateMessage(**data)
elif message_type == MessageTypes.PATCH_SESSION_RESULT:
return PatchSessionResultMessage(**data)
else:
logger.warning(f"Unknown message type: {message_type}")
return None
except json.JSONDecodeError as e:
logger.error(f"Failed to decode JSON message: {e}")
return None
except Exception as e:
logger.error(f"Failed to parse incoming message: {e}")
return None
def serialize_outgoing_message(message: OutgoingMessage) -> str:
"""
Serialize outgoing message to JSON string.
Args:
message: Message object to serialize
Returns:
JSON string representation
"""
try:
# For ImageDetectionMessage, we need to include None values for abandonment detection
from .models import ImageDetectionMessage
if isinstance(message, ImageDetectionMessage):
return message.model_dump_json(exclude_none=False)
else:
return message.model_dump_json(exclude_none=True)
except Exception as e:
logger.error(f"Failed to serialize outgoing message: {e}")
raise
def validate_subscription_identifier(identifier: str) -> bool:
"""
Validate subscription identifier format (displayId;cameraId).
Args:
identifier: Subscription identifier to validate
Returns:
True if valid format, False otherwise
"""
if not identifier or not isinstance(identifier, str):
return False
parts = identifier.split(';')
if len(parts) != 2:
logger.error(f"Invalid subscription identifier format: {identifier}")
return False
display_id, camera_id = parts
if not display_id or not camera_id:
logger.error(f"Empty display or camera ID in identifier: {identifier}")
return False
return True
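# Examples (hypothetical identifiers):
#   validate_subscription_identifier("display-001;cam-001")  -> True
#   validate_subscription_identifier("display-001")          -> False (no camera part)
#   validate_subscription_identifier("display-001;")         -> False (empty camera ID)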
def extract_display_identifier(subscription_identifier: str) -> Optional[str]:
"""
Extract display identifier from subscription identifier.
Args:
subscription_identifier: Full subscription identifier (displayId;cameraId)
Returns:
Display identifier or None if invalid format
"""
if not validate_subscription_identifier(subscription_identifier):
return None
return subscription_identifier.split(';')[0]
def create_state_report(cpu_usage: float, memory_usage: float,
gpu_usage: Optional[float] = None,
gpu_memory_usage: Optional[float] = None,
camera_connections: Optional[list] = None) -> StateReportMessage:
"""
Create a state report message with system metrics.
Args:
cpu_usage: CPU usage percentage
memory_usage: Memory usage percentage
gpu_usage: GPU usage percentage (optional)
gpu_memory_usage: GPU memory usage in MB (optional)
camera_connections: List of active camera connections
Returns:
StateReportMessage object
"""
return StateReportMessage(
cpuUsage=cpu_usage,
memoryUsage=memory_usage,
gpuUsage=gpu_usage,
gpuMemoryUsage=gpu_memory_usage,
cameraConnections=camera_connections or []
)
def create_image_detection(subscription_identifier: str, detection_data: Union[Dict[str, Any], None],
model_id: int, model_name: str) -> ImageDetectionMessage:
"""
Create an image detection message.
Args:
subscription_identifier: Camera subscription identifier
detection_data: Detection results - Dict for data, {} for empty, None for abandonment
model_id: Model identifier
model_name: Model name
Returns:
ImageDetectionMessage object
"""
from .models import DetectionData
from typing import Union
# Handle three cases:
# 1. None = car abandonment (detection: null)
# 2. {} = empty detection (triggers session creation)
# 3. {...} = full detection data (updates session)
data = DetectionData(
detection=detection_data,
modelId=model_id,
modelName=model_name
)
return ImageDetectionMessage(
subscriptionIdentifier=subscription_identifier,
data=data
)
def create_patch_session(session_id: int, patch_data: Dict[str, Any]) -> PatchSessionMessage:
"""
Create a patch session message.
Args:
session_id: Session ID to patch
patch_data: Partial session data to update
Returns:
PatchSessionMessage object
"""
return PatchSessionMessage(
sessionId=session_id,
data=patch_data
)

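A brief usage sketch of the helpers above (imported relatively elsewhere in this package); the subscription identifier and model details are placeholders.

# from .messages import parse_incoming_message, create_image_detection, serialize_outgoing_message

raw = '{"type": "requestState"}'
message = parse_incoming_message(raw)            # -> RequestStateMessage, or None on bad input

detection = create_image_detection(
    subscription_identifier="display-001;cam-001",
    detection_data={},                            # {} = empty detection (triggers session creation)
    model_id=1,
    model_name="example-detector",
)
payload = serialize_outgoing_message(detection)  # JSON string ready for websocket.send_text()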
View file

@ -1,150 +0,0 @@
"""
Message data structures for WebSocket communication.
Based on worker.md protocol specification.
"""
from typing import Dict, Any, List, Optional, Union, Literal
from pydantic import BaseModel, Field
from datetime import datetime
class SubscriptionObject(BaseModel):
"""Individual camera subscription configuration."""
subscriptionIdentifier: str = Field(..., description="Format: displayId;cameraId")
rtspUrl: Optional[str] = Field(None, description="RTSP stream URL")
snapshotUrl: Optional[str] = Field(None, description="HTTP snapshot URL")
snapshotInterval: Optional[int] = Field(None, description="Snapshot interval in milliseconds")
modelUrl: str = Field(..., description="Pre-signed URL to .mpta file")
modelId: int = Field(..., description="Unique model identifier")
modelName: str = Field(..., description="Human-readable model name")
cropX1: Optional[int] = Field(None, description="Crop region X1 coordinate")
cropY1: Optional[int] = Field(None, description="Crop region Y1 coordinate")
cropX2: Optional[int] = Field(None, description="Crop region X2 coordinate")
cropY2: Optional[int] = Field(None, description="Crop region Y2 coordinate")
class CameraConnection(BaseModel):
"""Camera connection status for state reporting."""
subscriptionIdentifier: str
modelId: int
modelName: str
online: bool
cropX1: Optional[int] = None
cropY1: Optional[int] = None
cropX2: Optional[int] = None
cropY2: Optional[int] = None
class DetectionData(BaseModel):
"""
Detection result data structure.
Supports three cases:
1. Empty detection: detection = {} (triggers session creation)
2. Full detection: detection = {"carBrand": "Honda", ...} (updates session)
3. Null detection: detection = None (car abandonment)
"""
model_config = {
"json_encoders": {type(None): lambda v: None},
"arbitrary_types_allowed": True
}
detection: Union[Dict[str, Any], None] = Field(
default_factory=dict,
description="Detection results: {} for empty, {...} for data, None/null for abandonment"
)
modelId: int
modelName: str
# Incoming Messages from Backend to Worker
class SetSubscriptionListMessage(BaseModel):
"""Complete subscription list for declarative state management."""
type: Literal["setSubscriptionList"] = "setSubscriptionList"
subscriptions: List[SubscriptionObject]
class SetSessionIdPayload(BaseModel):
"""Session ID association payload."""
displayIdentifier: str
sessionId: Optional[int] = None
class SetSessionIdMessage(BaseModel):
"""Associate session ID with display."""
type: Literal["setSessionId"] = "setSessionId"
payload: SetSessionIdPayload
class SetProgressionStagePayload(BaseModel):
"""Progression stage payload."""
displayIdentifier: str
progressionStage: Optional[str] = None
class SetProgressionStageMessage(BaseModel):
"""Set progression stage for display."""
type: Literal["setProgressionStage"] = "setProgressionStage"
payload: SetProgressionStagePayload
class RequestStateMessage(BaseModel):
"""Request current worker state."""
type: Literal["requestState"] = "requestState"
class PatchSessionResultPayload(BaseModel):
"""Patch session result payload."""
sessionId: int
success: bool
message: str
class PatchSessionResultMessage(BaseModel):
"""Response to patch session request."""
type: Literal["patchSessionResult"] = "patchSessionResult"
payload: PatchSessionResultPayload
# Outgoing Messages from Worker to Backend
class StateReportMessage(BaseModel):
"""Periodic heartbeat with system metrics."""
type: Literal["stateReport"] = "stateReport"
cpuUsage: float
memoryUsage: float
gpuUsage: Optional[float] = None
gpuMemoryUsage: Optional[float] = None
cameraConnections: List[CameraConnection]
class ImageDetectionMessage(BaseModel):
"""Detection event message."""
type: Literal["imageDetection"] = "imageDetection"
subscriptionIdentifier: str
timestamp: str = Field(default_factory=lambda: datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ"))
data: DetectionData
class PatchSessionMessage(BaseModel):
"""Request to modify session data."""
type: Literal["patchSession"] = "patchSession"
sessionId: int
data: Dict[str, Any] = Field(..., description="Partial DisplayPersistentData structure")
# Union type for all incoming messages
IncomingMessage = Union[
SetSubscriptionListMessage,
SetSessionIdMessage,
SetProgressionStageMessage,
RequestStateMessage,
PatchSessionResultMessage
]
# Union type for all outgoing messages
OutgoingMessage = Union[
StateReportMessage,
ImageDetectionMessage,
PatchSessionMessage
]

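The three detection cases handled by DetectionData, side by side; the model ID and name are placeholders.

# 1. Empty detection -> backend creates a session
empty = DetectionData(detection={}, modelId=1, modelName="example-detector")

# 2. Full detection -> backend updates the session
full = DetectionData(detection={"carBrand": "Honda"}, modelId=1, modelName="example-detector")

# 3. Null detection -> car abandonment signal
abandoned = DetectionData(detection=None, modelId=1, modelName="example-detector")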
View file

@ -1,234 +0,0 @@
"""
Worker state management for system metrics and subscription tracking.
"""
import logging
import psutil
import threading
from typing import Dict, Set, Optional, List
from dataclasses import dataclass, field
from .models import CameraConnection, SubscriptionObject
logger = logging.getLogger(__name__)
# Try to import torch and pynvml for GPU monitoring
try:
import torch
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
logger.warning("PyTorch not available, GPU metrics will not be collected")
try:
import pynvml
PYNVML_AVAILABLE = True
pynvml.nvmlInit()
logger.info("NVIDIA ML Python (pynvml) initialized successfully")
except ImportError:
PYNVML_AVAILABLE = False
logger.debug("pynvml not available, falling back to PyTorch GPU monitoring")
except Exception as e:
PYNVML_AVAILABLE = False
logger.warning(f"Failed to initialize pynvml: {e}")
@dataclass
class WorkerState:
"""Central state management for the detector worker."""
# Active subscriptions indexed by subscription identifier
subscriptions: Dict[str, SubscriptionObject] = field(default_factory=dict)
# Session ID mapping: display_identifier -> session_id
session_ids: Dict[str, int] = field(default_factory=dict)
# Progression stage mapping: display_identifier -> stage
progression_stages: Dict[str, str] = field(default_factory=dict)
# Active camera connections for state reporting
camera_connections: List[CameraConnection] = field(default_factory=list)
# Thread lock for state synchronization
_lock: threading.RLock = field(default_factory=threading.RLock)
def set_subscriptions(self, new_subscriptions: List[SubscriptionObject]) -> None:
"""
Update active subscriptions with declarative list from backend.
Args:
new_subscriptions: Complete list of desired subscriptions
"""
with self._lock:
# Convert to dict for easy lookup
new_sub_dict = {sub.subscriptionIdentifier: sub for sub in new_subscriptions}
# Log changes for debugging
current_ids = set(self.subscriptions.keys())
new_ids = set(new_sub_dict.keys())
added = new_ids - current_ids
removed = current_ids - new_ids
updated = current_ids & new_ids
if added:
logger.info(f"[State Update] Adding subscriptions: {added}")
if removed:
logger.info(f"[State Update] Removing subscriptions: {removed}")
if updated:
logger.info(f"[State Update] Updating subscriptions: {updated}")
# Replace entire subscription dict
self.subscriptions = new_sub_dict
# Update camera connections for state reporting
self._update_camera_connections()
def get_subscription(self, subscription_identifier: str) -> Optional[SubscriptionObject]:
"""Get subscription by identifier."""
with self._lock:
return self.subscriptions.get(subscription_identifier)
def get_all_subscriptions(self) -> List[SubscriptionObject]:
"""Get all active subscriptions."""
with self._lock:
return list(self.subscriptions.values())
def set_session_id(self, display_identifier: str, session_id: Optional[int]) -> None:
"""
Set or clear session ID for a display.
Args:
display_identifier: Display identifier
session_id: Session ID to set, or None to clear
"""
with self._lock:
if session_id is None:
self.session_ids.pop(display_identifier, None)
logger.info(f"[State Update] Cleared session ID for display {display_identifier}")
else:
self.session_ids[display_identifier] = session_id
logger.info(f"[State Update] Set session ID {session_id} for display {display_identifier}")
def get_session_id(self, display_identifier: str) -> Optional[int]:
"""Get session ID for display identifier."""
with self._lock:
return self.session_ids.get(display_identifier)
def get_session_id_for_subscription(self, subscription_identifier: str) -> Optional[int]:
"""Get session ID for subscription by extracting display identifier."""
from .messages import extract_display_identifier
display_id = extract_display_identifier(subscription_identifier)
if display_id:
return self.get_session_id(display_id)
return None
def set_progression_stage(self, display_identifier: str, stage: Optional[str]) -> None:
"""
Set or clear progression stage for a display.
Args:
display_identifier: Display identifier
stage: Progression stage to set, or None to clear
"""
with self._lock:
if stage is None:
self.progression_stages.pop(display_identifier, None)
logger.info(f"[State Update] Cleared progression stage for display {display_identifier}")
else:
self.progression_stages[display_identifier] = stage
logger.info(f"[State Update] Set progression stage '{stage}' for display {display_identifier}")
def get_progression_stage(self, display_identifier: str) -> Optional[str]:
"""Get progression stage for display identifier."""
with self._lock:
return self.progression_stages.get(display_identifier)
def _update_camera_connections(self) -> None:
"""Update camera connections list for state reporting."""
connections = []
for sub in self.subscriptions.values():
connection = CameraConnection(
subscriptionIdentifier=sub.subscriptionIdentifier,
modelId=sub.modelId,
modelName=sub.modelName,
online=True, # TODO: Add actual online status tracking
cropX1=sub.cropX1,
cropY1=sub.cropY1,
cropX2=sub.cropX2,
cropY2=sub.cropY2
)
connections.append(connection)
self.camera_connections = connections
def get_camera_connections(self) -> List[CameraConnection]:
"""Get current camera connections for state reporting."""
with self._lock:
return self.camera_connections.copy()
class SystemMetrics:
"""System metrics collection for state reporting."""
@staticmethod
def get_cpu_usage() -> float:
"""Get current CPU usage percentage."""
try:
return psutil.cpu_percent(interval=0.1)
except Exception as e:
logger.error(f"Failed to get CPU usage: {e}")
return 0.0
@staticmethod
def get_memory_usage() -> float:
"""Get current memory usage percentage."""
try:
return psutil.virtual_memory().percent
except Exception as e:
logger.error(f"Failed to get memory usage: {e}")
return 0.0
@staticmethod
def get_gpu_usage() -> Optional[float]:
"""Get current GPU usage percentage."""
try:
# Prefer pynvml for accurate GPU utilization
if PYNVML_AVAILABLE:
handle = pynvml.nvmlDeviceGetHandleByIndex(0) # First GPU
utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
return float(utilization.gpu)
# Fallback to PyTorch memory-based estimation
elif TORCH_AVAILABLE and torch.cuda.is_available():
if hasattr(torch.cuda, 'utilization'):
return torch.cuda.utilization()
else:
# Estimate based on memory usage
allocated = torch.cuda.memory_allocated()
reserved = torch.cuda.memory_reserved()
if reserved > 0:
return (allocated / reserved) * 100
return None
except Exception as e:
logger.error(f"Failed to get GPU usage: {e}")
return None
@staticmethod
def get_gpu_memory_usage() -> Optional[float]:
"""Get current GPU memory usage in MB."""
if not TORCH_AVAILABLE:
return None
try:
if torch.cuda.is_available():
return torch.cuda.memory_reserved() / (1024 ** 2) # Convert to MB
return None
except Exception as e:
logger.error(f"Failed to get GPU memory usage: {e}")
return None
# Global worker state instance
worker_state = WorkerState()

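A short sketch of how the global worker_state and SystemMetrics combine when building a state report (create_state_report comes from the messages module above); the display identifier and session ID are placeholders.

worker_state.set_session_id("display-001", 42)   # associate a backend session with a display

report = create_state_report(
    cpu_usage=SystemMetrics.get_cpu_usage(),
    memory_usage=SystemMetrics.get_memory_usage(),
    gpu_usage=SystemMetrics.get_gpu_usage(),
    gpu_memory_usage=SystemMetrics.get_gpu_memory_usage(),
    camera_connections=worker_state.get_camera_connections(),
)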
View file

@ -1,677 +0,0 @@
"""
WebSocket message handling and protocol implementation.
"""
import asyncio
import json
import logging
import os
import cv2
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Optional
from fastapi import WebSocket, WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError
from .messages import (
parse_incoming_message, serialize_outgoing_message,
MessageTypes, create_state_report
)
from .models import (
SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage,
RequestStateMessage, PatchSessionResultMessage
)
from .state import worker_state, SystemMetrics
from ..models import ModelManager
from ..streaming.manager import shared_stream_manager
from ..tracking.integration import TrackingPipelineIntegration
logger = logging.getLogger(__name__)
# Constants
HEARTBEAT_INTERVAL = 2.0 # seconds
WORKER_TIMEOUT_MS = 10000
# Global model manager instance
model_manager = ModelManager()
class WebSocketHandler:
"""
Handles WebSocket connection lifecycle and message processing.
"""
def __init__(self, websocket: WebSocket):
self.websocket = websocket
self.connected = False
self._heartbeat_task: Optional[asyncio.Task] = None
self._message_task: Optional[asyncio.Task] = None
self._heartbeat_count = 0
self._last_processed_models: set = set() # Cache of last processed model IDs
async def handle_connection(self) -> None:
"""
Main connection handler that manages the WebSocket lifecycle.
Based on the original architecture from archive/app.py
"""
client_info = f"{self.websocket.client.host}:{self.websocket.client.port}" if self.websocket.client else "unknown"
logger.info(f"Starting WebSocket handler for {client_info}")
stream_task = None
try:
logger.info(f"Accepting WebSocket connection from {client_info}")
await self.websocket.accept()
self.connected = True
logger.info(f"WebSocket connection accepted and established for {client_info}")
# Send immediate heartbeat to show connection is alive
await self._send_immediate_heartbeat()
# Start background tasks (matching original architecture)
stream_task = asyncio.create_task(self._process_streams())
heartbeat_task = asyncio.create_task(self._send_heartbeat())
message_task = asyncio.create_task(self._handle_messages())
logger.info(f"WebSocket background tasks started for {client_info} (stream + heartbeat + message handler)")
# Wait for heartbeat and message tasks (stream runs independently)
await asyncio.gather(heartbeat_task, message_task)
except Exception as e:
logger.error(f"Error in WebSocket connection for {client_info}: {e}", exc_info=True)
finally:
logger.info(f"Cleaning up connection for {client_info}")
# Cancel stream task
if stream_task and not stream_task.done():
stream_task.cancel()
try:
await stream_task
except asyncio.CancelledError:
logger.debug(f"Stream task cancelled for {client_info}")
await self._cleanup()
async def _send_immediate_heartbeat(self) -> None:
"""Send immediate heartbeat on connection to show we're alive."""
try:
cpu_usage = SystemMetrics.get_cpu_usage()
memory_usage = SystemMetrics.get_memory_usage()
gpu_usage = SystemMetrics.get_gpu_usage()
gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
camera_connections = worker_state.get_camera_connections()
state_report = create_state_report(
cpu_usage=cpu_usage,
memory_usage=memory_usage,
gpu_usage=gpu_usage,
gpu_memory_usage=gpu_memory_usage,
camera_connections=camera_connections
)
await self._send_message(state_report)
logger.info(f"[TX → Backend] Initial stateReport: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, "
f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras")
except Exception as e:
logger.error(f"Error sending immediate heartbeat: {e}")
async def _send_heartbeat(self) -> None:
"""Send periodic state reports as heartbeat."""
while self.connected:
try:
# Collect system metrics
cpu_usage = SystemMetrics.get_cpu_usage()
memory_usage = SystemMetrics.get_memory_usage()
gpu_usage = SystemMetrics.get_gpu_usage()
gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
camera_connections = worker_state.get_camera_connections()
# Create and send state report
state_report = create_state_report(
cpu_usage=cpu_usage,
memory_usage=memory_usage,
gpu_usage=gpu_usage,
gpu_memory_usage=gpu_memory_usage,
camera_connections=camera_connections
)
await self._send_message(state_report)
# Only log full details every 10th heartbeat, otherwise just show a dot
self._heartbeat_count += 1
if self._heartbeat_count % 10 == 0:
logger.info(f"[TX → Backend] Heartbeat #{self._heartbeat_count}: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, "
f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras")
else:
print(".", end="", flush=True) # Just show a dot to indicate heartbeat activity
await asyncio.sleep(HEARTBEAT_INTERVAL)
except Exception as e:
logger.error(f"Error sending heartbeat: {e}")
break
async def _handle_messages(self) -> None:
"""Handle incoming WebSocket messages."""
while self.connected:
try:
raw_message = await self.websocket.receive_text()
logger.info(f"[RX ← Backend] {raw_message}")
# Parse incoming message
message = parse_incoming_message(raw_message)
if not message:
logger.warning("Failed to parse incoming message")
continue
# Route message to appropriate handler
await self._route_message(message)
except (WebSocketDisconnect, ConnectionClosedError) as e:
logger.warning(f"WebSocket disconnected: {e}")
break
except json.JSONDecodeError:
logger.error("Received invalid JSON message")
except Exception as e:
logger.error(f"Error handling message: {e}")
break
async def _route_message(self, message) -> None:
"""Route parsed message to appropriate handler."""
message_type = message.type
try:
if message_type == MessageTypes.SET_SUBSCRIPTION_LIST:
await self._handle_set_subscription_list(message)
elif message_type == MessageTypes.SET_SESSION_ID:
await self._handle_set_session_id(message)
elif message_type == MessageTypes.SET_PROGRESSION_STAGE:
await self._handle_set_progression_stage(message)
elif message_type == MessageTypes.REQUEST_STATE:
await self._handle_request_state(message)
elif message_type == MessageTypes.PATCH_SESSION_RESULT:
await self._handle_patch_session_result(message)
else:
logger.warning(f"Unknown message type: {message_type}")
except Exception as e:
logger.error(f"Error handling {message_type} message: {e}")
async def _handle_set_subscription_list(self, message: SetSubscriptionListMessage) -> None:
"""Handle setSubscriptionList message for declarative subscription management."""
logger.info(f"[RX Processing] setSubscriptionList with {len(message.subscriptions)} subscriptions")
# Update worker state with new subscriptions
worker_state.set_subscriptions(message.subscriptions)
# Phase 2: Download and manage models
await self._ensure_models(message.subscriptions)
# Phase 3 & 4: Integrate with streaming management and tracking
await self._update_stream_subscriptions(message.subscriptions)
logger.info("Subscription list updated successfully")
async def _ensure_models(self, subscriptions) -> None:
"""Ensure all required models are downloaded and available."""
# Extract unique model requirements
unique_models = {}
for subscription in subscriptions:
model_id = subscription.modelId
if model_id not in unique_models:
unique_models[model_id] = {
'model_url': subscription.modelUrl,
'model_name': subscription.modelName
}
# Check if model set has changed to avoid redundant processing
current_model_ids = set(unique_models.keys())
if current_model_ids == self._last_processed_models:
logger.debug(f"[Model Management] Model set unchanged {list(current_model_ids)}, skipping checks")
return
logger.info(f"[Model Management] Processing {len(unique_models)} unique models: {list(unique_models.keys())}")
self._last_processed_models = current_model_ids
# Check and download models concurrently
download_tasks = []
for model_id, model_info in unique_models.items():
task = asyncio.create_task(
self._ensure_single_model(model_id, model_info['model_url'], model_info['model_name'])
)
download_tasks.append(task)
# Wait for all downloads to complete
if download_tasks:
results = await asyncio.gather(*download_tasks, return_exceptions=True)
# Log results
success_count = 0
for i, result in enumerate(results):
model_id = list(unique_models.keys())[i]
if isinstance(result, Exception):
logger.error(f"[Model Management] Failed to ensure model {model_id}: {result}")
elif result:
success_count += 1
logger.info(f"[Model Management] Model {model_id} ready for use")
else:
logger.error(f"[Model Management] Failed to ensure model {model_id}")
logger.info(f"[Model Management] Successfully ensured {success_count}/{len(unique_models)} models")
async def _update_stream_subscriptions(self, subscriptions) -> None:
"""Update streaming subscriptions with tracking integration."""
try:
# Convert subscriptions to the format expected by StreamManager
subscription_payloads = []
for subscription in subscriptions:
payload = {
'subscriptionIdentifier': subscription.subscriptionIdentifier,
'rtspUrl': subscription.rtspUrl,
'snapshotUrl': subscription.snapshotUrl,
'snapshotInterval': subscription.snapshotInterval,
'modelId': subscription.modelId,
'modelUrl': subscription.modelUrl,
'modelName': subscription.modelName
}
# Add crop coordinates if present
if hasattr(subscription, 'cropX1'):
payload.update({
'cropX1': subscription.cropX1,
'cropY1': subscription.cropY1,
'cropX2': subscription.cropX2,
'cropY2': subscription.cropY2
})
subscription_payloads.append(payload)
# Reconcile subscriptions with StreamManager
logger.info("[Streaming] Reconciling stream subscriptions with tracking")
reconcile_result = await self._reconcile_subscriptions_with_tracking(subscription_payloads)
logger.info(f"[Streaming] Subscription reconciliation complete: "
f"added={reconcile_result.get('added', 0)}, "
f"removed={reconcile_result.get('removed', 0)}, "
f"failed={reconcile_result.get('failed', 0)}")
except Exception as e:
logger.error(f"Error updating stream subscriptions: {e}", exc_info=True)
async def _reconcile_subscriptions_with_tracking(self, target_subscriptions) -> dict:
"""Reconcile subscriptions with tracking integration."""
try:
# Create separate tracking integrations for each subscription (camera isolation)
tracking_integrations = {}
for subscription_payload in target_subscriptions:
subscription_id = subscription_payload['subscriptionIdentifier']
model_id = subscription_payload['modelId']
# Create separate tracking integration per subscription for camera isolation
# Get pipeline configuration for this model
pipeline_parser = model_manager.get_pipeline_config(model_id)
if pipeline_parser:
# Create tracking integration with message sender (separate instance per camera)
tracking_integration = TrackingPipelineIntegration(
pipeline_parser, model_manager, model_id, self._send_message
)
# Initialize tracking model
success = await tracking_integration.initialize_tracking_model()
if success:
tracking_integrations[subscription_id] = tracking_integration
logger.info(f"[Tracking] Created isolated tracking integration for subscription {subscription_id} (model {model_id})")
else:
logger.warning(f"[Tracking] Failed to initialize tracking for subscription {subscription_id} (model {model_id})")
else:
logger.warning(f"[Tracking] No pipeline config found for model {model_id} in subscription {subscription_id}")
# Now reconcile with StreamManager, adding tracking integrations
current_subscription_ids = set()
for subscription_info in shared_stream_manager.get_all_subscriptions():
current_subscription_ids.add(subscription_info.subscription_id)
target_subscription_ids = {sub['subscriptionIdentifier'] for sub in target_subscriptions}
# Find subscriptions to remove and add
to_remove = current_subscription_ids - target_subscription_ids
to_add = target_subscription_ids - current_subscription_ids
# Remove old subscriptions
removed_count = 0
for subscription_id in to_remove:
if shared_stream_manager.remove_subscription(subscription_id):
removed_count += 1
logger.info(f"[Streaming] Removed subscription {subscription_id}")
# Add new subscriptions with tracking
added_count = 0
failed_count = 0
for subscription_payload in target_subscriptions:
subscription_id = subscription_payload['subscriptionIdentifier']
if subscription_id in to_add:
success = await self._add_subscription_with_tracking(
subscription_payload, tracking_integrations
)
if success:
added_count += 1
logger.info(f"[Streaming] Added subscription {subscription_id} with tracking")
else:
failed_count += 1
logger.error(f"[Streaming] Failed to add subscription {subscription_id}")
return {
'removed': removed_count,
'added': added_count,
'failed': failed_count,
'total_active': len(shared_stream_manager.get_all_subscriptions())
}
except Exception as e:
logger.error(f"Error in subscription reconciliation with tracking: {e}", exc_info=True)
return {'removed': 0, 'added': 0, 'failed': 0, 'total_active': 0}
async def _add_subscription_with_tracking(self, payload, tracking_integrations) -> bool:
"""Add a subscription with tracking integration."""
try:
from ..streaming.manager import StreamConfig
subscription_id = payload['subscriptionIdentifier']
camera_id = subscription_id.split(';')[-1]
model_id = payload['modelId']
logger.info(f"[SUBSCRIPTION_MAPPING] subscription_id='{subscription_id}' → camera_id='{camera_id}'")
# Get tracking integration for this subscription (camera-isolated)
tracking_integration = tracking_integrations.get(subscription_id)
# Extract crop coordinates if present
crop_coords = None
if all(key in payload for key in ['cropX1', 'cropY1', 'cropX2', 'cropY2']):
crop_coords = (
payload['cropX1'],
payload['cropY1'],
payload['cropX2'],
payload['cropY2']
)
# Create stream configuration
stream_config = StreamConfig(
camera_id=camera_id,
rtsp_url=payload.get('rtspUrl'),
snapshot_url=payload.get('snapshotUrl'),
snapshot_interval=payload.get('snapshotInterval', 5000),
max_retries=3,
)
# Add subscription to StreamManager with tracking
success = shared_stream_manager.add_subscription(
subscription_id=subscription_id,
stream_config=stream_config,
crop_coords=crop_coords,
model_id=model_id,
model_url=payload.get('modelUrl'),
tracking_integration=tracking_integration
)
if success and tracking_integration:
logger.info(f"[Tracking] Subscription {subscription_id} configured with isolated tracking for model {model_id}")
return success
except Exception as e:
logger.error(f"Error adding subscription with tracking: {e}", exc_info=True)
return False
async def _ensure_single_model(self, model_id: int, model_url: str, model_name: str) -> bool:
"""Ensure a single model is downloaded and available."""
try:
# Check if model is already available
if model_manager.is_model_downloaded(model_id):
logger.info(f"[Model Management] Model {model_id} ({model_name}) already available")
return True
# Download and extract model in a thread pool to avoid blocking the event loop
logger.info(f"[Model Management] Downloading model {model_id} ({model_name}) from {model_url}")
# Use asyncio.to_thread for CPU-bound operations (Python 3.9+)
# For compatibility, we'll use run_in_executor
loop = asyncio.get_event_loop()
model_path = await loop.run_in_executor(
None,
model_manager.ensure_model,
model_id,
model_url,
model_name
)
if model_path:
logger.info(f"[Model Management] Successfully prepared model {model_id} at {model_path}")
return True
else:
logger.error(f"[Model Management] Failed to prepare model {model_id}")
return False
except Exception as e:
logger.error(f"[Model Management] Exception ensuring model {model_id}: {str(e)}", exc_info=True)
return False
async def _save_snapshot(self, display_identifier: str, session_id: int) -> None:
"""
Save snapshot image to images folder after receiving sessionId.
Args:
display_identifier: Display identifier to match with subscriptionIdentifier
session_id: Session ID to include in filename
"""
try:
# Find subscription that matches the displayIdentifier
matching_subscription = None
for subscription in worker_state.get_all_subscriptions():
# Extract display ID from subscriptionIdentifier (format: displayId;cameraId)
from .messages import extract_display_identifier
sub_display_id = extract_display_identifier(subscription.subscriptionIdentifier)
if sub_display_id == display_identifier:
matching_subscription = subscription
break
if not matching_subscription:
logger.error(f"[Snapshot Save] No subscription found for display {display_identifier}")
return
if not matching_subscription.snapshotUrl:
logger.error(f"[Snapshot Save] No snapshotUrl found for display {display_identifier}")
return
# Ensure images directory exists (relative path for Docker bind mount)
images_dir = Path("images")
images_dir.mkdir(exist_ok=True)
# Generate filename with timestamp and session ID
timestamp = datetime.now(tz=timezone(timedelta(hours=7))).strftime("%Y%m%d_%H%M%S")
filename = f"{session_id}_{display_identifier}_{timestamp}.jpg"
filepath = images_dir / filename
# Use existing HTTPSnapshotReader to fetch snapshot
logger.info(f"[Snapshot Save] Fetching snapshot from {matching_subscription.snapshotUrl}")
# Run snapshot fetch in thread pool to avoid blocking async loop
loop = asyncio.get_event_loop()
frame = await loop.run_in_executor(None, self._fetch_snapshot_sync, matching_subscription.snapshotUrl)
if frame is not None:
# Save the image using OpenCV
success = cv2.imwrite(str(filepath), frame)
if success:
logger.info(f"[Snapshot Save] Successfully saved snapshot to {filepath}")
else:
logger.error(f"[Snapshot Save] Failed to save image file {filepath}")
else:
logger.error(f"[Snapshot Save] Failed to fetch snapshot from {matching_subscription.snapshotUrl}")
except Exception as e:
logger.error(f"[Snapshot Save] Error saving snapshot for display {display_identifier}: {e}", exc_info=True)
def _fetch_snapshot_sync(self, snapshot_url: str):
"""
Synchronous snapshot fetching using existing HTTPSnapshotReader infrastructure.
Args:
snapshot_url: URL to fetch snapshot from
Returns:
np.ndarray or None: Fetched frame or None on error
"""
try:
from ..streaming.readers import HTTPSnapshotReader
# Create temporary snapshot reader for single fetch
snapshot_reader = HTTPSnapshotReader(
camera_id="temp_snapshot",
snapshot_url=snapshot_url,
interval_ms=5000 # Not used for single fetch
)
# Use existing fetch_single_snapshot method
return snapshot_reader.fetch_single_snapshot()
except Exception as e:
logger.error(f"Error in sync snapshot fetch: {e}")
return None
async def _handle_set_session_id(self, message: SetSessionIdMessage) -> None:
"""Handle setSessionId message."""
display_identifier = message.payload.displayIdentifier
session_id = str(message.payload.sessionId) if message.payload.sessionId is not None else None
logger.info(f"[RX Processing] setSessionId for display {display_identifier}: {session_id}")
# Update worker state
worker_state.set_session_id(display_identifier, session_id)
# Update tracking integrations with session ID
shared_stream_manager.set_session_id(display_identifier, session_id)
async def _handle_set_progression_stage(self, message: SetProgressionStageMessage) -> None:
"""Handle setProgressionStage message."""
display_identifier = message.payload.displayIdentifier
stage = message.payload.progressionStage
logger.info(f"[RX Processing] setProgressionStage for display {display_identifier}: {stage}")
# Update worker state
worker_state.set_progression_stage(display_identifier, stage)
# Update tracking integration for car abandonment detection
session_id = worker_state.get_session_id(display_identifier)
if session_id:
shared_stream_manager.set_progression_stage(session_id, stage)
# Save snapshot image when progression stage is car_fueling
if stage == 'car_fueling' and session_id:
await self._save_snapshot(display_identifier, session_id)
# If stage indicates session is cleared/finished, clear from tracking
if stage in ['finished', 'cleared', 'idle']:
# Get session ID for this display and clear it
if session_id:
shared_stream_manager.clear_session_id(session_id)
logger.info(f"[Tracking] Cleared session {session_id} due to progression stage: {stage}")
async def _handle_request_state(self, message: RequestStateMessage) -> None:
"""Handle requestState message by sending immediate state report."""
logger.debug("[RX Processing] requestState - sending immediate state report")
# Collect metrics and send state report
cpu_usage = SystemMetrics.get_cpu_usage()
memory_usage = SystemMetrics.get_memory_usage()
gpu_usage = SystemMetrics.get_gpu_usage()
gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
camera_connections = worker_state.get_camera_connections()
state_report = create_state_report(
cpu_usage=cpu_usage,
memory_usage=memory_usage,
gpu_usage=gpu_usage,
gpu_memory_usage=gpu_memory_usage,
camera_connections=camera_connections
)
await self._send_message(state_report)
async def _handle_patch_session_result(self, message: PatchSessionResultMessage) -> None:
"""Handle patchSessionResult message."""
payload = message.payload
logger.info(f"[RX Processing] patchSessionResult for session {payload.sessionId}: "
f"success={payload.success}, message='{payload.message}'")
# TODO: Handle patch session result if needed
# For now, just log the response
async def _send_message(self, message) -> None:
"""Send message to backend via WebSocket."""
if not self.connected:
logger.warning("Cannot send message: WebSocket not connected")
return
try:
json_message = serialize_outgoing_message(message)
await self.websocket.send_text(json_message)
# Log non-heartbeat messages only (heartbeats are logged in their respective functions)
if not (hasattr(message, 'type') and message.type == 'stateReport'):
logger.info(f"[TX → Backend] {json_message}")
except Exception as e:
logger.error(f"Failed to send WebSocket message: {e}")
raise
async def _process_streams(self) -> None:
"""
Stream processing task that handles frame processing and detection.
This is a placeholder for Phase 2 - currently just logs that it's running.
"""
logger.info("Stream processing task started")
try:
while self.connected:
# Get current subscriptions
subscriptions = worker_state.get_all_subscriptions()
# TODO: Phase 2 - Add actual frame processing logic here
# This will include:
# - Frame reading from RTSP/HTTP streams
# - Model inference using loaded pipelines
# - Detection result sending via WebSocket
# Sleep to prevent excessive CPU usage (similar to old poll_interval)
await asyncio.sleep(0.1) # 100ms polling interval
except asyncio.CancelledError:
logger.info("Stream processing task cancelled")
except Exception as e:
logger.error(f"Error in stream processing: {e}", exc_info=True)
async def _cleanup(self) -> None:
"""Clean up resources when connection closes."""
logger.info("Cleaning up WebSocket connection")
self.connected = False
# Cancel background tasks
if self._heartbeat_task and not self._heartbeat_task.done():
self._heartbeat_task.cancel()
if self._message_task and not self._message_task.done():
self._message_task.cancel()
# Clear worker state
worker_state.set_subscriptions([])
worker_state.session_ids.clear()
worker_state.progression_stages.clear()
logger.info("WebSocket connection cleanup completed")
# Factory function for FastAPI integration
async def websocket_endpoint(websocket: WebSocket) -> None:
"""
FastAPI WebSocket endpoint handler.
Args:
websocket: FastAPI WebSocket connection
"""
handler = WebSocketHandler(websocket)
await handler.handle_connection()

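The factory function above plugs into FastAPI roughly like this; the route path here is an assumption, not taken from the repository.

from fastapi import FastAPI, WebSocket

app = FastAPI()

@app.websocket("/")                       # actual path is defined by the worker's app module
async def detect_ws(websocket: WebSocket):
    await websocket_endpoint(websocket)   # delegates connection lifecycle to WebSocketHandler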
View file

@ -1,10 +0,0 @@
"""
Detection module for the Python Detector Worker.
This module provides the main detection pipeline orchestration and parallel branch processing
for advanced computer vision detection systems.
"""
from .pipeline import DetectionPipeline
from .branches import BranchProcessor
__all__ = ['DetectionPipeline', 'BranchProcessor']

View file

@ -1,936 +0,0 @@
"""
Parallel Branch Processing Module.
Handles concurrent execution of classification branches and result synchronization.
"""
import logging
import asyncio
import time
from typing import Dict, List, Optional, Any, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import cv2
from ..models.inference import YOLOWrapper
logger = logging.getLogger(__name__)
class BranchProcessor:
"""
Handles parallel processing of classification branches.
Manages branch synchronization and result collection.
"""
def __init__(self, model_manager: Any, model_id: int):
"""
Initialize branch processor.
Args:
model_manager: Model manager for loading models
model_id: The model ID to use for loading models
"""
self.model_manager = model_manager
self.model_id = model_id
# Branch models cache
self.branch_models: Dict[str, YOLOWrapper] = {}
# Dynamic field mapping: branch_id → output_field_name (e.g., {"car_brand_cls_v3": "brand"})
self.branch_output_fields: Dict[str, str] = {}
# Thread pool for parallel execution
self.executor = ThreadPoolExecutor(max_workers=4)
# Storage managers (set during initialization)
self.redis_manager = None
# self.db_manager = None # Disabled - PostgreSQL operations moved to microservices
# Branch execution timeout (seconds)
self.branch_timeout = 30.0
# Statistics
self.stats = {
'branches_processed': 0,
'parallel_executions': 0,
'total_processing_time': 0.0,
'models_loaded': 0,
'branches_timed_out': 0,
'branches_failed': 0
}
logger.info("BranchProcessor initialized")
async def initialize(self, pipeline_config: Any, redis_manager: Any, db_manager: Any = None) -> bool:
"""
Initialize branch processor with pipeline configuration.
Args:
pipeline_config: Pipeline configuration object
redis_manager: Redis manager instance
db_manager: Database manager instance (deprecated, not used)
Returns:
True if successful, False otherwise
"""
try:
self.redis_manager = redis_manager
# self.db_manager = db_manager # Disabled - PostgreSQL operations moved to microservices
# Parse field mappings from parallelActions to enable dynamic field extraction
self._parse_branch_output_fields(pipeline_config)
# Pre-load branch models if they exist
branches = getattr(pipeline_config, 'branches', [])
if branches:
await self._preload_branch_models(branches)
logger.info(f"BranchProcessor initialized with {len(self.branch_models)} models")
return True
except Exception as e:
logger.error(f"Error initializing branch processor: {e}", exc_info=True)
return False
async def _preload_branch_models(self, branches: List[Any]) -> None:
"""
Pre-load all branch models for faster execution.
Args:
branches: List of branch configurations
"""
for branch in branches:
try:
await self._load_branch_model(branch)
# Recursively load nested branches
nested_branches = getattr(branch, 'branches', [])
if nested_branches:
await self._preload_branch_models(nested_branches)
except Exception as e:
logger.error(f"Error preloading branch model {getattr(branch, 'model_id', 'unknown')}: {e}")
async def _load_branch_model(self, branch_config: Any) -> Optional[YOLOWrapper]:
"""
Load a branch model if not already loaded.
Args:
branch_config: Branch configuration object
Returns:
Loaded YOLO model wrapper or None
"""
try:
model_id = getattr(branch_config, 'model_id', None)
model_file = getattr(branch_config, 'model_file', None)
if not model_id or not model_file:
logger.warning(f"Invalid branch config: model_id={model_id}, model_file={model_file}")
return None
# Check if model is already loaded
if model_id in self.branch_models:
logger.debug(f"Branch model {model_id} already loaded")
return self.branch_models[model_id]
# Load model
logger.info(f"Loading branch model: {model_id} ({model_file})")
# Load model using the proper model ID
model = self.model_manager.get_yolo_model(self.model_id, model_file)
if model:
self.branch_models[model_id] = model
self.stats['models_loaded'] += 1
logger.info(f"Branch model {model_id} loaded successfully")
return model
else:
logger.error(f"Failed to load branch model {model_id}")
return None
except Exception as e:
logger.error(f"Error loading branch model {getattr(branch_config, 'model_id', 'unknown')}: {e}")
return None
def _parse_branch_output_fields(self, pipeline_config: Any) -> None:
"""
Parse parallelActions.fields to determine what output field each branch produces.
Creates dynamic mapping from branch_id to output field name.
Example:
Input: parallelActions.fields = {"car_brand": "{car_brand_cls_v3.brand}"}
Output: self.branch_output_fields = {"car_brand_cls_v3": "brand"}
Args:
pipeline_config: Pipeline configuration object
"""
try:
if not pipeline_config or not hasattr(pipeline_config, 'parallel_actions'):
logger.debug("[FIELD MAPPING] No parallelActions found in pipeline config")
return
for action in pipeline_config.parallel_actions:
# Skip PostgreSQL actions - they are disabled
if action.type.value == 'postgresql_update_combined':
logger.debug(f"[FIELD MAPPING] Skipping PostgreSQL action (disabled)")
continue # Skip field parsing for disabled PostgreSQL operations
# fields = action.params.get('fields', {})
#
# # Parse each field template to extract branch_id and field_name
# for db_field_name, template in fields.items():
# # Template format: "{branch_id.field_name}"
# if template.startswith('{') and template.endswith('}'):
# var_name = template[1:-1] # Remove { }
#
# if '.' in var_name:
# branch_id, field_name = var_name.split('.', 1)
#
# # Store the mapping
# self.branch_output_fields[branch_id] = field_name
#
# logger.info(f"[FIELD MAPPING] Branch '{branch_id}' → outputs field '{field_name}'")
logger.info(f"[FIELD MAPPING] Parsed {len(self.branch_output_fields)} branch output field mappings")
except Exception as e:
logger.error(f"[FIELD MAPPING] Error parsing branch output fields: {e}", exc_info=True)
async def execute_branches(self,
frame: np.ndarray,
branches: List[Any],
detected_regions: Dict[str, Any],
detection_context: Dict[str, Any]) -> Dict[str, Any]:
"""
Execute all branches in parallel and collect results.
Args:
frame: Input frame
branches: List of branch configurations
detected_regions: Dictionary of detected regions from main detection
detection_context: Detection context data
Returns:
Dictionary with branch execution results
"""
start_time = time.time()
branch_results = {}
try:
# Separate parallel and sequential branches
parallel_branches = []
sequential_branches = []
for branch in branches:
if getattr(branch, 'parallel', False):
parallel_branches.append(branch)
else:
sequential_branches.append(branch)
# Execute parallel branches concurrently
if parallel_branches:
logger.info(f"Executing {len(parallel_branches)} branches in parallel")
parallel_results = await self._execute_parallel_branches(
frame, parallel_branches, detected_regions, detection_context
)
branch_results.update(parallel_results)
self.stats['parallel_executions'] += 1
# Execute sequential branches one by one
if sequential_branches:
logger.info(f"Executing {len(sequential_branches)} branches sequentially")
sequential_results = await self._execute_sequential_branches(
frame, sequential_branches, detected_regions, detection_context
)
branch_results.update(sequential_results)
# Update statistics
self.stats['branches_processed'] += len(branches)
processing_time = time.time() - start_time
self.stats['total_processing_time'] += processing_time
logger.info(f"Branch execution completed in {processing_time:.3f}s with {len(branch_results)} results")
except Exception as e:
logger.error(f"Error in branch execution: {e}", exc_info=True)
return branch_results
async def _execute_parallel_branches(self,
frame: np.ndarray,
branches: List[Any],
detected_regions: Dict[str, Any],
detection_context: Dict[str, Any]) -> Dict[str, Any]:
"""
Execute branches in parallel using ThreadPoolExecutor.
Args:
frame: Input frame
branches: List of parallel branch configurations
detected_regions: Dictionary of detected regions
detection_context: Detection context data
Returns:
Dictionary with parallel branch results
"""
results = {}
# Submit all branches for parallel execution
future_to_branch = {}
for branch in branches:
branch_id = getattr(branch, 'model_id', 'unknown')
logger.info(f"[PARALLEL SUBMIT] {branch_id}: Submitting branch to thread pool")
future = self.executor.submit(
self._execute_single_branch_sync,
frame, branch, detected_regions, detection_context
)
future_to_branch[future] = branch
# Collect results as they complete with timeout
try:
for future in as_completed(future_to_branch, timeout=self.branch_timeout):
branch = future_to_branch[future]
branch_id = getattr(branch, 'model_id', 'unknown')
try:
# Get result with timeout to prevent indefinite hanging
result = future.result(timeout=self.branch_timeout)
results[branch_id] = result
logger.info(f"[PARALLEL COMPLETE] {branch_id}: Branch completed successfully")
except TimeoutError:
logger.error(f"[TIMEOUT] Branch {branch_id} exceeded timeout of {self.branch_timeout}s")
self.stats['branches_timed_out'] += 1
results[branch_id] = {
'status': 'timeout',
'message': f'Branch execution timeout after {self.branch_timeout}s',
'processing_time': self.branch_timeout
}
except Exception as e:
logger.error(f"[ERROR] Error in parallel branch {branch_id}: {e}", exc_info=True)
self.stats['branches_failed'] += 1
results[branch_id] = {
'status': 'error',
'message': str(e),
'processing_time': 0.0
}
except TimeoutError:
# as_completed iterator timed out - mark remaining futures as timed out
logger.error(f"[TIMEOUT] Branch execution timeout after {self.branch_timeout}s - some branches did not complete")
for future, branch in future_to_branch.items():
branch_id = getattr(branch, 'model_id', 'unknown')
if branch_id not in results:
logger.error(f"[TIMEOUT] Branch {branch_id} did not complete within timeout")
self.stats['branches_timed_out'] += 1
results[branch_id] = {
'status': 'timeout',
'message': f'Branch did not complete within {self.branch_timeout}s timeout',
'processing_time': self.branch_timeout
}
# Flatten nested branch results to top level for database access
flattened_results = {}
for branch_id, branch_result in results.items():
# Add the branch result itself
flattened_results[branch_id] = branch_result
# If this branch has nested branches, add them to the top level too
if isinstance(branch_result, dict) and 'nested_branches' in branch_result:
nested_branches = branch_result['nested_branches']
for nested_branch_id, nested_result in nested_branches.items():
flattened_results[nested_branch_id] = nested_result
logger.info(f"[FLATTEN] Added nested branch {nested_branch_id} to top-level results")
# Log summary of branch execution results
succeeded = [bid for bid, res in results.items() if res.get('status') == 'success']
failed = [bid for bid, res in results.items() if res.get('status') == 'error']
timed_out = [bid for bid, res in results.items() if res.get('status') == 'timeout']
skipped = [bid for bid, res in results.items() if res.get('status') == 'skipped']
summary_parts = []
if succeeded:
summary_parts.append(f"{len(succeeded)} succeeded: {', '.join(succeeded)}")
if failed:
summary_parts.append(f"{len(failed)} FAILED: {', '.join(failed)}")
if timed_out:
summary_parts.append(f"{len(timed_out)} TIMED OUT: {', '.join(timed_out)}")
if skipped:
summary_parts.append(f"{len(skipped)} skipped: {', '.join(skipped)}")
logger.info(f"[PARALLEL SUMMARY] Branch execution completed: {' | '.join(summary_parts) if summary_parts else 'no branches'}")
return flattened_results
async def _execute_sequential_branches(self,
frame: np.ndarray,
branches: List[Any],
detected_regions: Dict[str, Any],
detection_context: Dict[str, Any]) -> Dict[str, Any]:
"""
Execute branches sequentially.
Args:
frame: Input frame
branches: List of sequential branch configurations
detected_regions: Dictionary of detected regions
detection_context: Detection context data
Returns:
Dictionary with sequential branch results
"""
results = {}
for branch in branches:
branch_id = getattr(branch, 'model_id', 'unknown')
try:
result = await asyncio.get_event_loop().run_in_executor(
self.executor,
self._execute_single_branch_sync,
frame, branch, detected_regions, detection_context
)
results[branch_id] = result
logger.debug(f"Sequential branch {branch_id} completed successfully")
except Exception as e:
logger.error(f"Error in sequential branch {branch_id}: {e}")
results[branch_id] = {
'status': 'error',
'message': str(e),
'processing_time': 0.0
}
# Flatten nested branch results to top level for database access
flattened_results = {}
for branch_id, branch_result in results.items():
# Add the branch result itself
flattened_results[branch_id] = branch_result
# If this branch has nested branches, add them to the top level too
if isinstance(branch_result, dict) and 'nested_branches' in branch_result:
nested_branches = branch_result['nested_branches']
for nested_branch_id, nested_result in nested_branches.items():
flattened_results[nested_branch_id] = nested_result
logger.info(f"[FLATTEN] Added nested branch {nested_branch_id} to top-level results")
return flattened_results
def _execute_single_branch_sync(self,
frame: np.ndarray,
branch_config: Any,
detected_regions: Dict[str, Any],
detection_context: Dict[str, Any]) -> Dict[str, Any]:
"""
Synchronous execution of a single branch (for ThreadPoolExecutor).
Args:
frame: Input frame
branch_config: Branch configuration object
detected_regions: Dictionary of detected regions
detection_context: Detection context data
Returns:
Dictionary with branch execution result
"""
start_time = time.time()
branch_id = getattr(branch_config, 'model_id', 'unknown')
logger.info(f"[BRANCH START] {branch_id}: Starting branch execution")
logger.debug(f"[BRANCH CONFIG] {branch_id}: crop={getattr(branch_config, 'crop', False)}, "
f"trigger_classes={getattr(branch_config, 'trigger_classes', [])}, "
f"min_confidence={getattr(branch_config, 'min_confidence', 0.6)}")
# Check if branch should execute based on triggerClasses (execution conditions)
trigger_classes = getattr(branch_config, 'trigger_classes', [])
logger.info(f"[DETECTED REGIONS] {branch_id}: Available parent detections: {list(detected_regions.keys())}")
for region_name, region_data in detected_regions.items():
# Handle both list (new) and single dict (backward compat)
if isinstance(region_data, list):
for i, region in enumerate(region_data):
logger.debug(f"[REGION DATA] {branch_id}: '{region_name}[{i}]' -> bbox={region.get('bbox')}, conf={region.get('confidence')}")
else:
logger.debug(f"[REGION DATA] {branch_id}: '{region_name}' -> bbox={region_data.get('bbox')}, conf={region_data.get('confidence')}")
if trigger_classes:
# Check if any parent detection matches our trigger classes (case-insensitive)
should_execute = False
for trigger_class in trigger_classes:
# Case-insensitive comparison for robustness
if trigger_class.lower() in [k.lower() for k in detected_regions.keys()]:
should_execute = True
logger.info(f"[TRIGGER CHECK] {branch_id}: Found '{trigger_class}' in parent detections - branch will execute")
break
if not should_execute:
logger.warning(f"[TRIGGER CHECK] {branch_id}: None of trigger classes {trigger_classes} found in parent detections {list(detected_regions.keys())} - skipping branch")
return {
'status': 'skipped',
'branch_id': branch_id,
'message': f'No trigger classes {trigger_classes} found in parent detections',
'processing_time': time.time() - start_time
}
result = {
'status': 'success',
'branch_id': branch_id,
'result': {},
'processing_time': 0.0,
'timestamp': time.time()
}
try:
# Get or load branch model
if branch_id not in self.branch_models:
logger.warning(f"Branch model {branch_id} not preloaded, loading now...")
# This should be rare since models are preloaded
return {
'status': 'error',
'message': f'Branch model {branch_id} not available',
'processing_time': time.time() - start_time
}
model = self.branch_models[branch_id]
# Get configuration values first
min_confidence = getattr(branch_config, 'min_confidence', 0.6)
# Prepare input frame for this branch
input_frame = frame
# Handle cropping if required - select the biggest bbox among the parent detections for the configured crop class(es)
if getattr(branch_config, 'crop', False):
crop_classes = getattr(branch_config, 'crop_class', [])
if isinstance(crop_classes, str):
crop_classes = [crop_classes]
# Find the biggest bbox across all parent detections of the crop classes
best_region = None
best_class = None
best_area = 0.0
for crop_class in crop_classes:
if crop_class in detected_regions:
regions = detected_regions[crop_class]
# Handle both list (new) and single dict (backward compat)
if not isinstance(regions, list):
regions = [regions]
# Find largest bbox from all detections of this class
for region in regions:
confidence = region.get('confidence', 0.0)
bbox = region['bbox']
area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) # width * height
# Choose biggest bbox among all available detections
if area > best_area:
best_region = region
best_class = crop_class
best_area = area
logger.debug(f"[CROP] Selected larger bbox for '{crop_class}': area={area:.0f}px², conf={confidence:.3f}")
if best_region:
bbox = best_region['bbox']
x1, y1, x2, y2 = [int(coord) for coord in bbox]
cropped = frame[y1:y2, x1:x2]
if cropped.size > 0:
input_frame = cropped
confidence = best_region.get('confidence', 0.0)
logger.info(f"[CROP SUCCESS] {branch_id}: cropped '{best_class}' region (conf={confidence:.3f}, area={int(best_area)}) -> shape={cropped.shape}")
else:
logger.warning(f"Branch {branch_id}: empty crop, using full frame")
else:
logger.warning(f"Branch {branch_id}: no crop regions found for classes {crop_classes}, using full frame")
logger.info(f"[INFERENCE START] {branch_id}: Running inference on {'cropped' if input_frame is not frame else 'full'} frame "
f"({input_frame.shape[1]}x{input_frame.shape[0]}) with confidence={min_confidence}")
# Use .predict() method for both detection and classification models
inference_start = time.time()
try:
detection_results = model.model.predict(input_frame, conf=min_confidence, verbose=False)
inference_time = time.time() - inference_start
logger.info(f"[INFERENCE DONE] {branch_id}: Predict completed in {inference_time:.3f}s using .predict() method")
except Exception as inference_error:
inference_time = time.time() - inference_start
logger.error(f"[INFERENCE ERROR] {branch_id}: Model inference failed after {inference_time:.3f}s: {inference_error}", exc_info=True)
return {
'status': 'error',
'branch_id': branch_id,
'message': f'Model inference failed: {str(inference_error)}',
'processing_time': time.time() - start_time
}
# Initialize branch_detections outside the conditional
branch_detections = []
# Process results using clean, unified logic
if detection_results and len(detection_results) > 0:
result_obj = detection_results[0]
# Handle detection models (have .boxes attribute)
if hasattr(result_obj, 'boxes') and result_obj.boxes is not None:
logger.info(f"[RAW DETECTIONS] {branch_id}: Found {len(result_obj.boxes)} raw detections")
for i, box in enumerate(result_obj.boxes):
class_id = int(box.cls[0])
confidence = float(box.conf[0])
bbox = box.xyxy[0].cpu().numpy().tolist() # [x1, y1, x2, y2]
class_name = model.model.names[class_id]
logger.debug(f"[RAW DETECTION {i+1}] {branch_id}: '{class_name}', conf={confidence:.3f}")
# All detections are included - no filtering by trigger_classes here
branch_detections.append({
'class_name': class_name,
'confidence': confidence,
'bbox': bbox
})
# Handle classification models (have .probs attribute)
elif hasattr(result_obj, 'probs') and result_obj.probs is not None:
logger.info(f"[RAW CLASSIFICATION] {branch_id}: Processing classification results")
probs = result_obj.probs
top_indices = probs.top5 # Get top 5 predictions
top_conf = probs.top5conf.cpu().numpy()
# For classification: take only TOP-1 prediction (not all top-5)
# This prevents empty results when all top-5 predictions are below threshold
if len(top_indices) > 0 and len(top_conf) > 0:
top_idx = top_indices[0]
top_confidence = float(top_conf[0])
# Apply minConfidence threshold to top-1 only
if top_confidence >= min_confidence:
class_name = model.model.names[int(top_idx)]
logger.info(f"[CLASSIFICATION TOP-1] {branch_id}: '{class_name}', conf={top_confidence:.3f}")
# For classification, use full input frame dimensions as bbox
branch_detections.append({
'class_name': class_name,
'confidence': top_confidence,
'bbox': [0, 0, input_frame.shape[1], input_frame.shape[0]]
})
else:
logger.warning(f"[CLASSIFICATION FILTERED] {branch_id}: Top prediction conf={top_confidence:.3f} < threshold={min_confidence}")
else:
logger.warning(f"[UNKNOWN MODEL] {branch_id}: Model results have no .boxes or .probs")
result['result'] = {
'detections': branch_detections,
'detection_count': len(branch_detections)
}
logger.info(f"[FINAL RESULTS] {branch_id}: {len(branch_detections)} detections processed")
# Determine output field name from dynamic mapping (parsed from parallelActions.fields)
output_field = self.branch_output_fields.get(branch_id)
# Always initialize the field (even if None) to ensure it exists for database update
if output_field:
result['result'][output_field] = None
logger.debug(f"[FIELD INIT] {branch_id}: Initialized field '{output_field}' = None")
# Extract best detection if available
if branch_detections:
best_detection = max(branch_detections, key=lambda x: x['confidence'])
logger.info(f"[BEST DETECTION] {branch_id}: '{best_detection['class_name']}' with confidence {best_detection['confidence']:.3f}")
# Set the output field value using dynamic mapping
if output_field:
result['result'][output_field] = best_detection['class_name']
logger.info(f"[FIELD SET] {branch_id}: Set field '{output_field}' = '{best_detection['class_name']}'")
else:
logger.warning(f"[NO MAPPING] {branch_id}: No output field defined in parallelActions.fields")
else:
logger.warning(f"[NO RESULTS] {branch_id}: No detections found, field '{output_field}' remains None")
# Execute branch actions if this branch found valid detections
actions_executed = []
branch_actions = getattr(branch_config, 'actions', [])
if branch_actions and branch_detections:
logger.info(f"[BRANCH ACTIONS] {branch_id}: Executing {len(branch_actions)} actions")
# Create detected_regions from THIS branch's detections for actions
branch_detected_regions = {}
for detection in branch_detections:
branch_detected_regions[detection['class_name']] = {
'bbox': detection['bbox'],
'confidence': detection['confidence']
}
for action in branch_actions:
try:
action_type = action.type.value # Access the enum value
logger.info(f"[ACTION EXECUTE] {branch_id}: Executing action '{action_type}'")
if action_type == 'redis_save_image':
action_result = self._execute_redis_save_image_sync(
action, input_frame, branch_detected_regions, detection_context
)
elif action_type == 'redis_publish':
action_result = self._execute_redis_publish_sync(
action, detection_context
)
else:
logger.warning(f"[ACTION UNKNOWN] {branch_id}: Unknown action type '{action_type}'")
action_result = {'status': 'error', 'message': f'Unknown action type: {action_type}'}
actions_executed.append({
'action_type': action_type,
'result': action_result
})
logger.info(f"[ACTION COMPLETE] {branch_id}: Action '{action_type}' result: {action_result.get('status')}")
except Exception as e:
action_type = getattr(action, 'type', None)
if action_type:
action_type = action_type.value if hasattr(action_type, 'value') else str(action_type)
logger.error(f"[ACTION ERROR] {branch_id}: Error executing action '{action_type}': {e}", exc_info=True)
actions_executed.append({
'action_type': action_type,
'result': {'status': 'error', 'message': str(e)}
})
# Add actions executed to result
if actions_executed:
result['actions_executed'] = actions_executed
# Handle nested branches ONLY if parent found valid detections
nested_branches = getattr(branch_config, 'branches', [])
if nested_branches:
# Check if parent branch found any valid detections
if not branch_detections:
logger.warning(f"[BRANCH SKIP] {branch_id}: Skipping {len(nested_branches)} nested branches - parent found no valid detections")
else:
logger.debug(f"Branch {branch_id}: executing {len(nested_branches)} nested branches")
# Create detected_regions from THIS branch's detections for nested branches
# Nested branches should see their immediate parent's detections, not the root pipeline
nested_detected_regions = {}
for detection in branch_detections:
nested_detected_regions[detection['class_name']] = {
'bbox': detection['bbox'],
'confidence': detection['confidence']
}
logger.info(f"[NESTED REGIONS] {branch_id}: Passing {list(nested_detected_regions.keys())} to nested branches")
# Note: For simplicity, nested branches are executed sequentially in this sync method
# In a full async implementation, these could also be parallelized
nested_results = {}
for nested_branch in nested_branches:
nested_result = self._execute_single_branch_sync(
input_frame, nested_branch, nested_detected_regions, detection_context
)
nested_branch_id = getattr(nested_branch, 'model_id', 'unknown')
nested_results[nested_branch_id] = nested_result
result['nested_branches'] = nested_results
except Exception as e:
logger.error(f"[BRANCH ERROR] {branch_id}: Error in execution: {e}", exc_info=True)
result['status'] = 'error'
result['message'] = str(e)
result['processing_time'] = time.time() - start_time
# Summary log
logger.info(f"[BRANCH COMPLETE] {branch_id}: status={result['status']}, "
f"processing_time={result['processing_time']:.3f}s, "
f"result_keys={list(result['result'].keys()) if result['result'] else 'none'}")
return result
def _execute_redis_save_image_sync(self,
action: Dict,
frame: np.ndarray,
detected_regions: Dict[str, Any],
context: Dict[str, Any]) -> Dict[str, Any]:
"""Execute redis_save_image action synchronously."""
if not self.redis_manager:
return {'status': 'error', 'message': 'Redis not available'}
try:
# Get image to save (cropped or full frame)
image_to_save = frame
region_name = action.params.get('region')
bbox = None
if region_name and region_name in detected_regions:
# Crop the specified region
# Handle both list (new) and single dict (backward compat)
regions = detected_regions[region_name]
if isinstance(regions, list):
# Multiple detections - select largest bbox
if regions:
best_region = max(regions, key=lambda r: (r['bbox'][2] - r['bbox'][0]) * (r['bbox'][3] - r['bbox'][1]))
bbox = best_region['bbox']
else:
bbox = regions['bbox']
elif region_name and region_name.lower() == 'frontal' and 'front_rear' in detected_regions:
# Special case: "frontal" region maps to "front_rear" detection
# Handle both list (new) and single dict (backward compat)
regions = detected_regions['front_rear']
if isinstance(regions, list):
# Multiple detections - select largest bbox
if regions:
best_region = max(regions, key=lambda r: (r['bbox'][2] - r['bbox'][0]) * (r['bbox'][3] - r['bbox'][1]))
bbox = best_region['bbox']
else:
bbox = regions['bbox']
if bbox is not None:
x1, y1, x2, y2 = [int(coord) for coord in bbox]
cropped = frame[y1:y2, x1:x2]
if cropped.size > 0:
image_to_save = cropped
logger.debug(f"Cropped region '{region_name}' for redis_save_image")
else:
logger.warning(f"Empty crop for region '{region_name}', using full frame")
# Format key with context
key = action.params['key'].format(**context)
# Read encoding parameters; the actual image encode happens below, just before the Redis write
image_format = action.params.get('format', 'jpeg')
quality = action.params.get('quality', 90)
# Save to Redis synchronously using a sync Redis client
try:
import redis
import cv2
# Create a synchronous Redis client with same connection details
sync_redis = redis.Redis(
host=self.redis_manager.host,
port=self.redis_manager.port,
password=self.redis_manager.password,
db=self.redis_manager.db,
decode_responses=False, # We're storing binary data
socket_timeout=self.redis_manager.socket_timeout,
socket_connect_timeout=self.redis_manager.socket_connect_timeout
)
# Encode the image
if image_format.lower() == 'jpeg':
encode_param = [cv2.IMWRITE_JPEG_QUALITY, quality]
success, encoded_image = cv2.imencode('.jpg', image_to_save, encode_param)
else:
success, encoded_image = cv2.imencode('.png', image_to_save)
if not success:
return {'status': 'error', 'message': 'Failed to encode image'}
# Save to Redis with expiration
expire_seconds = action.params.get('expire_seconds', 600)
result = sync_redis.setex(key, expire_seconds, encoded_image.tobytes())
sync_redis.close() # Clean up connection
if result:
# Add image_key to context for subsequent actions
context['image_key'] = key
return {'status': 'success', 'key': key}
else:
return {'status': 'error', 'message': 'Failed to save image to Redis'}
except Exception as redis_error:
logger.error(f"Error calling Redis from sync context: {redis_error}")
return {'status': 'error', 'message': f'Redis operation failed: {redis_error}'}
except Exception as e:
logger.error(f"Error in redis_save_image action: {e}", exc_info=True)
return {'status': 'error', 'message': str(e)}
def _execute_redis_publish_sync(self, action: Dict, context: Dict[str, Any]) -> Dict[str, Any]:
"""Execute redis_publish action synchronously."""
if not self.redis_manager:
return {'status': 'error', 'message': 'Redis not available'}
try:
channel = action.params['channel']
message_template = action.params['message']
# Debug the message template
logger.debug(f"Message template: {repr(message_template)}")
logger.debug(f"Context keys: {list(context.keys())}")
# Format message with context - handle JSON string formatting carefully
# The message template contains JSON which causes issues with .format()
# Use string replacement instead of format to avoid JSON brace conflicts
try:
# Ensure image_key is available for message formatting
if 'image_key' not in context:
context['image_key'] = '' # Default empty value if redis_save_image failed
# Use string replacement to avoid JSON formatting issues
message = message_template
for key, value in context.items():
placeholder = '{' + key + '}'
message = message.replace(placeholder, str(value))
logger.debug(f"Formatted message using replacement: {message}")
except Exception as e:
logger.error(f"Message formatting failed: {e}")
logger.error(f"Template: {repr(message_template)}")
logger.error(f"Context: {context}")
return {'status': 'error', 'message': f'Message formatting failed: {e}'}
# Publish message synchronously using a sync Redis client
try:
import redis
# Create a synchronous Redis client with same connection details
sync_redis = redis.Redis(
host=self.redis_manager.host,
port=self.redis_manager.port,
password=self.redis_manager.password,
db=self.redis_manager.db,
decode_responses=True, # For publishing text messages
socket_timeout=self.redis_manager.socket_timeout,
socket_connect_timeout=self.redis_manager.socket_connect_timeout
)
# Publish message
result = sync_redis.publish(channel, message)
sync_redis.close() # Clean up connection
if result >= 0: # Redis publish returns number of subscribers
return {'status': 'success', 'subscribers': result, 'channel': channel}
else:
return {'status': 'error', 'message': 'Failed to publish message to Redis'}
except Exception as redis_error:
logger.error(f"Error calling Redis from sync context: {redis_error}")
return {'status': 'error', 'message': f'Redis operation failed: {redis_error}'}
except Exception as e:
logger.error(f"Error in redis_publish action: {e}", exc_info=True)
return {'status': 'error', 'message': str(e)}
def get_statistics(self) -> Dict[str, Any]:
"""Get branch processor statistics."""
return {
**self.stats,
'loaded_models': list(self.branch_models.keys()),
'model_count': len(self.branch_models)
}
def cleanup(self):
"""Cleanup resources."""
if self.executor:
self.executor.shutdown(wait=False)
# Clear model cache
self.branch_models.clear()
logger.info("BranchProcessor cleaned up")
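# --- Illustrative sketch (not part of the original module) ---
# A self-contained reproduction of the timeout pattern used by
# _execute_parallel_branches above: submit each branch to a thread pool, drain
# results with as_completed(), and mark anything that misses the deadline as
# timed out instead of blocking forever. All names below (_fake_branch,
# run_branches_with_timeout, the branch IDs and delays) are hypothetical.
import time
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError as FuturesTimeout

def _fake_branch(branch_id: str, delay: float) -> dict:
    """Stand-in for _execute_single_branch_sync: sleep, then report success."""
    time.sleep(delay)
    return {'status': 'success', 'branch_id': branch_id}

def run_branches_with_timeout(branch_delays: dict, timeout: float = 0.5) -> dict:
    results = {}
    pool = ThreadPoolExecutor(max_workers=4)
    future_to_branch = {pool.submit(_fake_branch, bid, d): bid for bid, d in branch_delays.items()}
    try:
        for future in as_completed(future_to_branch, timeout=timeout):
            results[future_to_branch[future]] = future.result()
    except FuturesTimeout:
        # Deadline reached: every branch still missing is reported as a timeout,
        # mirroring the per-branch 'timeout' entries built in the method above.
        for future, bid in future_to_branch.items():
            if bid not in results:
                results[bid] = {'status': 'timeout', 'branch_id': bid}
    pool.shutdown(wait=False)  # like BranchProcessor.cleanup(): do not block on stragglers
    return results

if __name__ == '__main__':
    print(run_branches_with_timeout({'car_brand_cls_v1': 0.1, 'slow_branch': 2.0}, timeout=0.5))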

File diff suppressed because it is too large


@@ -1,42 +0,0 @@
"""
Models Module - MPTA management, pipeline configuration, and YOLO inference
"""
from .manager import ModelManager
from .pipeline import (
PipelineParser,
PipelineConfig,
TrackingConfig,
ModelBranch,
Action,
ActionType,
RedisConfig,
# PostgreSQLConfig # Disabled - moved to microservices
)
from .inference import (
YOLOWrapper,
ModelInferenceManager,
Detection,
InferenceResult
)
__all__ = [
# Manager
'ModelManager',
# Pipeline
'PipelineParser',
'PipelineConfig',
'TrackingConfig',
'ModelBranch',
'Action',
'ActionType',
'RedisConfig',
# 'PostgreSQLConfig', # Disabled - moved to microservices
# Inference
'YOLOWrapper',
'ModelInferenceManager',
'Detection',
'InferenceResult',
]


@@ -1,447 +0,0 @@
"""
YOLO Model Inference Wrapper - Handles model loading and inference optimization
"""
import logging
import torch
import numpy as np
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple, Union
from threading import Lock
from dataclasses import dataclass
import cv2
logger = logging.getLogger(__name__)
@dataclass
class Detection:
"""Represents a single detection result"""
bbox: List[float] # [x1, y1, x2, y2]
confidence: float
class_id: int
class_name: str
track_id: Optional[int] = None
@dataclass
class InferenceResult:
"""Result from model inference"""
detections: List[Detection]
image_shape: Tuple[int, int] # (height, width)
inference_time: float
model_id: str
class YOLOWrapper:
"""Wrapper for YOLO models with caching and optimization"""
# Class-level model cache shared across all instances
_model_cache: Dict[str, Any] = {}
_cache_lock = Lock()
def __init__(self, model_path: Path, model_id: str, device: Optional[str] = None):
"""
Initialize YOLO wrapper
Args:
model_path: Path to the .pt model file
model_id: Unique identifier for the model
device: Device to run inference on ('cuda', 'cpu', or None for auto)
"""
self.model_path = model_path
self.model_id = model_id
# Auto-detect device if not specified
if device is None:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self.device = device
self.model = None
self._class_names = {}  # class_id -> name mapping, populated by _extract_class_names()
self._load_model()
logger.info(f"Initialized YOLO wrapper for {model_id} on {self.device}")
def _load_model(self) -> None:
"""Load the YOLO model with caching"""
cache_key = str(self.model_path)
with self._cache_lock:
# Check if model is already cached
if cache_key in self._model_cache:
logger.info(f"Loading model {self.model_id} from cache")
self.model = self._model_cache[cache_key]
self._extract_class_names()
return
# Load model
try:
from ultralytics import YOLO
logger.info(f"Loading YOLO model from {self.model_path}")
self.model = YOLO(str(self.model_path))
# Move model to device
if self.device == 'cuda' and torch.cuda.is_available():
self.model.to('cuda')
logger.info(f"Model {self.model_id} moved to GPU")
# Cache the model
self._model_cache[cache_key] = self.model
self._extract_class_names()
logger.info(f"Successfully loaded model {self.model_id}")
except ImportError:
logger.error("Ultralytics YOLO not installed. Install with: pip install ultralytics")
raise
except Exception as e:
logger.error(f"Failed to load YOLO model {self.model_id}: {str(e)}", exc_info=True)
raise
def _extract_class_names(self) -> None:
"""Extract class names from the model"""
try:
if hasattr(self.model, 'names'):
self._class_names = self.model.names
elif hasattr(self.model, 'model') and hasattr(self.model.model, 'names'):
self._class_names = self.model.model.names
else:
logger.warning(f"Could not extract class names from model {self.model_id}")
self._class_names = {}
except Exception as e:
logger.error(f"Failed to extract class names: {str(e)}")
self._class_names = {}
def infer(
self,
image: np.ndarray,
confidence_threshold: float = 0.5,
trigger_classes: Optional[List[str]] = None,
iou_threshold: float = 0.45
) -> InferenceResult:
"""
Run inference on an image
Args:
image: Input image as numpy array (BGR format)
confidence_threshold: Minimum confidence for detections
trigger_classes: List of class names to filter (None = all classes)
iou_threshold: IoU threshold for NMS
Returns:
InferenceResult containing detections
"""
if self.model is None:
raise RuntimeError(f"Model {self.model_id} not loaded")
try:
import time
start_time = time.time()
# Run inference
results = self.model(
image,
conf=confidence_threshold,
iou=iou_threshold,
verbose=False
)
inference_time = time.time() - start_time
# Parse results
detections = self._parse_results(results[0], trigger_classes)
return InferenceResult(
detections=detections,
image_shape=(image.shape[0], image.shape[1]),
inference_time=inference_time,
model_id=self.model_id
)
except Exception as e:
logger.error(f"Inference failed for model {self.model_id}: {str(e)}", exc_info=True)
raise
def _parse_results(
self,
result: Any,
trigger_classes: Optional[List[str]] = None
) -> List[Detection]:
"""
Parse YOLO results into Detection objects
Args:
result: YOLO result object
trigger_classes: Optional list of class names to filter
Returns:
List of Detection objects
"""
detections = []
try:
if result.boxes is None:
return detections
boxes = result.boxes
for i in range(len(boxes)):
# Get box coordinates
box = boxes.xyxy[i].cpu().numpy()
x1, y1, x2, y2 = box
# Get confidence and class
conf = float(boxes.conf[i])
cls_id = int(boxes.cls[i])
# Get class name
class_name = self._class_names.get(cls_id, f"class_{cls_id}")
# Filter by trigger classes if specified
if trigger_classes and class_name not in trigger_classes:
continue
# Get track ID if available
track_id = None
if hasattr(boxes, 'id') and boxes.id is not None:
track_id = int(boxes.id[i])
detection = Detection(
bbox=[float(x1), float(y1), float(x2), float(y2)],
confidence=conf,
class_id=cls_id,
class_name=class_name,
track_id=track_id
)
detections.append(detection)
except Exception as e:
logger.error(f"Failed to parse results: {str(e)}", exc_info=True)
return detections
def track(
self,
image: np.ndarray,
confidence_threshold: float = 0.5,
trigger_classes: Optional[List[str]] = None,
persist: bool = True,
camera_id: Optional[str] = None
) -> InferenceResult:
"""
Run detection (tracking will be handled by external tracker)
Args:
image: Input image as numpy array (BGR format)
confidence_threshold: Minimum confidence for detections
trigger_classes: List of class names to filter
persist: Ignored - tracking handled externally
camera_id: Ignored - tracking handled externally
Returns:
InferenceResult containing detections (no track IDs from YOLO)
"""
# Just do detection - no YOLO tracking
return self.infer(image, confidence_threshold, trigger_classes)
def predict_classification(
self,
image: np.ndarray,
top_k: int = 1
) -> Dict[str, float]:
"""
Run classification on an image
Args:
image: Input image as numpy array (BGR format)
top_k: Number of top predictions to return
Returns:
Dictionary of class_name -> confidence scores
"""
if self.model is None:
raise RuntimeError(f"Model {self.model_id} not loaded")
try:
# Run inference
results = self.model(image, verbose=False)
# For classification models, extract probabilities
if hasattr(results[0], 'probs'):
probs = results[0].probs
top_indices = probs.top5[:top_k]
top_conf = probs.top5conf[:top_k].cpu().numpy()
predictions = {}
for idx, conf in zip(top_indices, top_conf):
class_name = self._class_names.get(int(idx), f"class_{idx}")
predictions[class_name] = float(conf)
return predictions
else:
logger.warning(f"Model {self.model_id} does not support classification")
return {}
except Exception as e:
logger.error(f"Classification failed for model {self.model_id}: {str(e)}", exc_info=True)
raise
def crop_detection(
self,
image: np.ndarray,
detection: Detection,
padding: int = 0
) -> np.ndarray:
"""
Crop image to detection bounding box
Args:
image: Original image
detection: Detection to crop
padding: Additional padding around the box
Returns:
Cropped image region
"""
h, w = image.shape[:2]
x1, y1, x2, y2 = detection.bbox
# Add padding and clip to image boundaries
x1 = max(0, int(x1) - padding)
y1 = max(0, int(y1) - padding)
x2 = min(w, int(x2) + padding)
y2 = min(h, int(y2) + padding)
return image[y1:y2, x1:x2]
def get_class_names(self) -> Dict[int, str]:
"""Get the class names dictionary"""
return self._class_names.copy()
def get_num_classes(self) -> int:
"""Get the number of classes the model can detect"""
return len(self._class_names)
def clear_cache(self) -> None:
"""Clear the model cache"""
with self._cache_lock:
cache_key = str(self.model_path)
if cache_key in self._model_cache:
del self._model_cache[cache_key]
logger.info(f"Cleared cache for model {self.model_id}")
@classmethod
def clear_all_cache(cls) -> None:
"""Clear all cached models"""
with cls._cache_lock:
cls._model_cache.clear()
logger.info("Cleared all model cache")
def warmup(self, image_size: Tuple[int, int] = (640, 640)) -> None:
"""
Warmup the model with a dummy inference
Args:
image_size: Size of dummy image (height, width)
"""
try:
dummy_image = np.zeros((image_size[0], image_size[1], 3), dtype=np.uint8)
self.infer(dummy_image, confidence_threshold=0.5)
logger.info(f"Model {self.model_id} warmed up")
except Exception as e:
logger.warning(f"Failed to warmup model {self.model_id}: {str(e)}")
class ModelInferenceManager:
"""Manages multiple YOLO models for a pipeline"""
def __init__(self, model_dir: Path):
"""
Initialize the inference manager
Args:
model_dir: Directory containing model files
"""
self.model_dir = model_dir
self.models: Dict[str, YOLOWrapper] = {}
self._lock = Lock()
logger.info(f"Initialized ModelInferenceManager with model directory: {model_dir}")
def load_model(
self,
model_id: str,
model_file: str,
device: Optional[str] = None
) -> YOLOWrapper:
"""
Load a model for inference
Args:
model_id: Unique identifier for the model
model_file: Filename of the model
device: Device to run on
Returns:
YOLOWrapper instance
"""
with self._lock:
# Check if already loaded
if model_id in self.models:
logger.debug(f"Model {model_id} already loaded")
return self.models[model_id]
# Load the model
model_path = self.model_dir / model_file
if not model_path.exists():
raise FileNotFoundError(f"Model file not found: {model_path}")
wrapper = YOLOWrapper(model_path, model_id, device)
self.models[model_id] = wrapper
return wrapper
def get_model(self, model_id: str) -> Optional[YOLOWrapper]:
"""
Get a loaded model
Args:
model_id: Model identifier
Returns:
YOLOWrapper instance or None if not loaded
"""
return self.models.get(model_id)
def unload_model(self, model_id: str) -> bool:
"""
Unload a model to free memory
Args:
model_id: Model identifier
Returns:
True if unloaded, False if not found
"""
with self._lock:
if model_id in self.models:
self.models[model_id].clear_cache()
del self.models[model_id]
logger.info(f"Unloaded model {model_id}")
return True
return False
def unload_all(self) -> None:
"""Unload all models"""
with self._lock:
for model_id in list(self.models.keys()):
self.models[model_id].clear_cache()
self.models.clear()
logger.info("Unloaded all models")
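# --- Illustrative usage sketch (not part of the original module) ---
# Minimal way to exercise the classes above on their own; relies on the cv2 and
# Path imports at the top of this module. The model directory, model filename
# and image path are hypothetical placeholders: any YOLO .pt file and a
# readable BGR image will do. Guarded so nothing runs on import.
if __name__ == '__main__':
    manager = ModelInferenceManager(Path('models/demo'))
    detector = manager.load_model(model_id='demo_detector', model_file='yolo_detector.pt')
    detector.warmup()

    frame = cv2.imread('sample.jpg')
    result = detector.infer(frame, confidence_threshold=0.5)
    print(f'{len(result.detections)} detections in {result.inference_time:.3f}s')
    for det in result.detections:
        crop = detector.crop_detection(frame, det, padding=4)
        print(det.class_name, f'{det.confidence:.2f}', crop.shape)

    manager.unload_all()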


@@ -1,439 +0,0 @@
"""
Model Manager Module - Handles MPTA download, extraction, and model loading
"""
import os
import logging
import zipfile
import json
import hashlib
import requests
from pathlib import Path
from typing import Dict, Optional, Any, Set
from threading import Lock
from urllib.parse import urlparse, parse_qs
logger = logging.getLogger(__name__)
class ModelManager:
"""Manages MPTA model downloads, extraction, and caching"""
def __init__(self, models_dir: str = "models"):
"""
Initialize the Model Manager
Args:
models_dir: Base directory for storing models
"""
self.models_dir = Path(models_dir)
self.models_dir.mkdir(parents=True, exist_ok=True)
# Track downloaded models to avoid duplicates
self._downloaded_models: Set[int] = set()
self._model_paths: Dict[int, Path] = {}
self._download_lock = Lock()
# Scan existing models
self._scan_existing_models()
logger.info(f"ModelManager initialized with models directory: {self.models_dir}")
logger.info(f"Found existing models: {list(self._downloaded_models)}")
def _scan_existing_models(self) -> None:
"""Scan the models directory for existing downloaded models"""
if not self.models_dir.exists():
return
for model_dir in self.models_dir.iterdir():
if model_dir.is_dir() and model_dir.name.isdigit():
model_id = int(model_dir.name)
# Check if extraction was successful by looking for pipeline.json
extracted_dirs = list(model_dir.glob("*/pipeline.json"))
if extracted_dirs:
self._downloaded_models.add(model_id)
# Store path to the extracted model directory
self._model_paths[model_id] = extracted_dirs[0].parent
logger.debug(f"Found existing model {model_id} at {extracted_dirs[0].parent}")
def get_model_path(self, model_id: int) -> Optional[Path]:
"""
Get the path to an extracted model directory
Args:
model_id: The model ID
Returns:
Path to the extracted model directory or None if not found
"""
return self._model_paths.get(model_id)
def is_model_downloaded(self, model_id: int) -> bool:
"""
Check if a model has already been downloaded and extracted
Args:
model_id: The model ID to check
Returns:
True if the model is already available
"""
return model_id in self._downloaded_models
def ensure_model(self, model_id: int, model_url: str, model_name: str = None) -> Optional[Path]:
"""
Ensure a model is downloaded and extracted, downloading if necessary
Args:
model_id: The model ID
model_url: URL to download the MPTA file from
model_name: Optional model name for logging
Returns:
Path to the extracted model directory or None if failed
"""
# Check if already downloaded
if self.is_model_downloaded(model_id):
logger.info(f"Model {model_id} already available at {self._model_paths[model_id]}")
return self._model_paths[model_id]
# Download and extract with lock to prevent concurrent downloads of same model
with self._download_lock:
# Double-check after acquiring lock
if self.is_model_downloaded(model_id):
return self._model_paths[model_id]
logger.info(f"Model {model_id} not found locally, downloading from {model_url}")
# Create model directory
model_dir = self.models_dir / str(model_id)
model_dir.mkdir(parents=True, exist_ok=True)
# Extract filename from URL
mpta_filename = self._extract_filename_from_url(model_url, model_name, model_id)
mpta_path = model_dir / mpta_filename
# Download MPTA file
if not self._download_mpta(model_url, mpta_path):
logger.error(f"Failed to download model {model_id}")
return None
# Extract MPTA file
extracted_path = self._extract_mpta(mpta_path, model_dir)
if not extracted_path:
logger.error(f"Failed to extract model {model_id}")
return None
# Mark as downloaded and store path
self._downloaded_models.add(model_id)
self._model_paths[model_id] = extracted_path
logger.info(f"Successfully prepared model {model_id} at {extracted_path}")
return extracted_path
def _extract_filename_from_url(self, url: str, model_name: str = None, model_id: int = None) -> str:
"""
Extract a suitable filename from the URL
Args:
url: The URL to extract filename from
model_name: Optional model name
model_id: Optional model ID
Returns:
A suitable filename for the MPTA file
"""
parsed = urlparse(url)
path = parsed.path
# Try to get filename from path
if path:
filename = os.path.basename(path)
if filename and filename.endswith('.mpta'):
return filename
# Fallback to constructed name
if model_name:
return f"{model_name}-{model_id}.mpta"
else:
return f"model-{model_id}.mpta"
def _download_mpta(self, url: str, dest_path: Path) -> bool:
"""
Download an MPTA file from a URL
Args:
url: URL to download from
dest_path: Destination path for the file
Returns:
True if successful, False otherwise
"""
try:
logger.info(f"Starting download of model from {url}")
logger.debug(f"Download destination: {dest_path}")
response = requests.get(url, stream=True, timeout=300)
if response.status_code != 200:
logger.error(f"Failed to download MPTA file (status {response.status_code})")
return False
file_size = int(response.headers.get('content-length', 0))
logger.info(f"Model file size: {file_size/1024/1024:.2f} MB")
downloaded = 0
last_log_percent = 0
with open(dest_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
downloaded += len(chunk)
# Log progress every 10%
if file_size > 0:
percent = int(downloaded * 100 / file_size)
if percent >= last_log_percent + 10:
logger.debug(f"Download progress: {percent}%")
last_log_percent = percent
logger.info(f"Successfully downloaded MPTA file to {dest_path}")
return True
except requests.RequestException as e:
logger.error(f"Network error downloading MPTA: {str(e)}", exc_info=True)
# Clean up partial download
if dest_path.exists():
dest_path.unlink()
return False
except Exception as e:
logger.error(f"Unexpected error downloading MPTA: {str(e)}", exc_info=True)
# Clean up partial download
if dest_path.exists():
dest_path.unlink()
return False
def _extract_mpta(self, mpta_path: Path, target_dir: Path) -> Optional[Path]:
"""
Extract an MPTA (ZIP) file to the target directory
Args:
mpta_path: Path to the MPTA file
target_dir: Directory to extract to
Returns:
Path to the extracted model directory containing pipeline.json, or None if failed
"""
try:
if not mpta_path.exists():
logger.error(f"MPTA file not found: {mpta_path}")
return None
logger.info(f"Extracting MPTA file from {mpta_path} to {target_dir}")
with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
# Get list of files
file_list = zip_ref.namelist()
logger.debug(f"Files in MPTA archive: {len(file_list)} files")
# Extract all files
zip_ref.extractall(target_dir)
logger.info(f"Successfully extracted MPTA file to {target_dir}")
# Find the directory containing pipeline.json
pipeline_files = list(target_dir.glob("*/pipeline.json"))
if not pipeline_files:
# Check if pipeline.json is in root
if (target_dir / "pipeline.json").exists():
logger.info(f"Found pipeline.json in root of {target_dir}")
return target_dir
logger.error(f"No pipeline.json found after extraction in {target_dir}")
return None
# Return the directory containing pipeline.json
extracted_dir = pipeline_files[0].parent
logger.info(f"Extracted model to {extracted_dir}")
# Keep the MPTA file for reference but could delete if space is a concern
# mpta_path.unlink()
# logger.debug(f"Removed MPTA file after extraction: {mpta_path}")
return extracted_dir
except zipfile.BadZipFile as e:
logger.error(f"Invalid ZIP/MPTA file {mpta_path}: {str(e)}", exc_info=True)
return None
except Exception as e:
logger.error(f"Failed to extract MPTA file {mpta_path}: {str(e)}", exc_info=True)
return None
def load_pipeline_config(self, model_id: int) -> Optional[Dict[str, Any]]:
"""
Load the pipeline.json configuration for a model
Args:
model_id: The model ID
Returns:
The pipeline configuration dictionary or None if not found
"""
model_path = self.get_model_path(model_id)
if not model_path:
logger.error(f"Model {model_id} not found")
return None
pipeline_path = model_path / "pipeline.json"
if not pipeline_path.exists():
logger.error(f"pipeline.json not found for model {model_id}")
return None
try:
with open(pipeline_path, 'r') as f:
config = json.load(f)
logger.debug(f"Loaded pipeline config for model {model_id}")
return config
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON in pipeline.json for model {model_id}: {str(e)}")
return None
except Exception as e:
logger.error(f"Failed to load pipeline.json for model {model_id}: {str(e)}")
return None
def get_model_file_path(self, model_id: int, filename: str) -> Optional[Path]:
"""
Get the full path to a model file (e.g., .pt file)
Args:
model_id: The model ID
filename: The filename within the model directory
Returns:
Full path to the model file or None if not found
"""
model_path = self.get_model_path(model_id)
if not model_path:
return None
file_path = model_path / filename
if not file_path.exists():
logger.error(f"Model file {filename} not found in model {model_id}")
return None
return file_path
def cleanup_model(self, model_id: int) -> bool:
"""
Remove a downloaded model to free up space
Args:
model_id: The model ID to remove
Returns:
True if successful, False otherwise
"""
if model_id not in self._downloaded_models:
logger.warning(f"Model {model_id} not in downloaded models")
return False
try:
model_dir = self.models_dir / str(model_id)
if model_dir.exists():
import shutil
shutil.rmtree(model_dir)
logger.info(f"Removed model directory: {model_dir}")
self._downloaded_models.discard(model_id)
self._model_paths.pop(model_id, None)
return True
except Exception as e:
logger.error(f"Failed to cleanup model {model_id}: {str(e)}")
return False
def get_all_downloaded_models(self) -> Set[int]:
"""
Get a set of all downloaded model IDs
Returns:
Set of model IDs that are currently downloaded
"""
return self._downloaded_models.copy()
def get_pipeline_config(self, model_id: int) -> Optional[Any]:
"""
Get the pipeline configuration for a model.
Args:
model_id: The model ID
Returns:
PipelineConfig object if found, None otherwise
"""
try:
if model_id not in self._downloaded_models:
logger.warning(f"Model {model_id} not downloaded")
return None
model_path = self._model_paths.get(model_id)
if not model_path:
logger.warning(f"Model path not found for model {model_id}")
return None
# Import here to avoid circular imports
from .pipeline import PipelineParser
# Load pipeline.json
pipeline_file = model_path / "pipeline.json"
if not pipeline_file.exists():
logger.warning(f"No pipeline.json found for model {model_id}")
return None
# Create PipelineParser object and parse the configuration
pipeline_parser = PipelineParser()
success = pipeline_parser.parse(pipeline_file)
if success:
return pipeline_parser
else:
logger.error(f"Failed to parse pipeline.json for model {model_id}")
return None
except Exception as e:
logger.error(f"Error getting pipeline config for model {model_id}: {e}", exc_info=True)
return None
def get_yolo_model(self, model_id: int, model_filename: str) -> Optional[Any]:
"""
Create a YOLOWrapper instance for a specific model file.
Args:
model_id: The model ID
model_filename: The .pt model filename
Returns:
YOLOWrapper instance if successful, None otherwise
"""
try:
# Get the model file path
model_file_path = self.get_model_file_path(model_id, model_filename)
if not model_file_path or not model_file_path.exists():
logger.error(f"Model file {model_filename} not found for model {model_id}")
return None
# Import here to avoid circular imports
from .inference import YOLOWrapper
# Create YOLOWrapper instance
yolo_model = YOLOWrapper(
model_path=model_file_path,
model_id=f"{model_id}_{model_filename}",
device=None # Auto-detect device
)
logger.info(f"Created YOLOWrapper for model {model_id}: {model_filename}")
return yolo_model
except Exception as e:
logger.error(f"Error creating YOLO model for {model_id}:{model_filename}: {e}", exc_info=True)
return None
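# --- Illustrative usage sketch (not part of the original module) ---
# The end-to-end flow this class exists for: ensure the .mpta bundle is
# downloaded and extracted, parse its pipeline.json, then load the main .pt
# file. The model ID, URL and names below are hypothetical placeholders.
# Guarded so nothing runs on import.
if __name__ == '__main__':
    manager = ModelManager(models_dir='models')

    extracted = manager.ensure_model(
        model_id=8,
        model_url='https://example.com/models/demo_pipeline.mpta',
        model_name='demo_pipeline'
    )
    if extracted is None:
        raise SystemExit('download or extraction failed')

    parser = manager.get_pipeline_config(8)
    if parser and parser.pipeline_config:
        main_model = manager.get_yolo_model(8, parser.pipeline_config.model_file)
        print('main pipeline model loaded:', parser.pipeline_config.model_id, main_model is not None)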


@@ -1,373 +0,0 @@
"""
Pipeline Configuration Parser - Handles pipeline.json parsing and validation
"""
import json
import logging
from pathlib import Path
from typing import Dict, List, Any, Optional, Set
from dataclasses import dataclass, field
from enum import Enum
logger = logging.getLogger(__name__)
class ActionType(Enum):
"""Supported action types in pipeline"""
REDIS_SAVE_IMAGE = "redis_save_image"
REDIS_PUBLISH = "redis_publish"
# PostgreSQL actions below are DEPRECATED - kept for backward compatibility only
# These actions will be silently skipped during pipeline execution
POSTGRESQL_UPDATE = "postgresql_update"
POSTGRESQL_UPDATE_COMBINED = "postgresql_update_combined"
POSTGRESQL_INSERT = "postgresql_insert"
@dataclass
class RedisConfig:
"""Redis connection configuration"""
host: str
port: int = 6379
password: Optional[str] = None
db: int = 0
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'RedisConfig':
return cls(
host=data['host'],
port=data.get('port', 6379),
password=data.get('password'),
db=data.get('db', 0)
)
@dataclass
class PostgreSQLConfig:
"""
PostgreSQL connection configuration - DISABLED
NOTE: This configuration is kept for backward compatibility with existing
pipeline.json files, but PostgreSQL operations are disabled. All database
operations have been moved to microservices architecture.
This config will be parsed but not used for any database connections.
"""
host: str
port: int
database: str
username: str
password: str
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'PostgreSQLConfig':
"""Parse PostgreSQL config from dict (kept for backward compatibility)"""
return cls(
host=data['host'],
port=data.get('port', 5432),
database=data['database'],
username=data['username'],
password=data['password']
)
@dataclass
class Action:
"""Represents an action in the pipeline"""
type: ActionType
params: Dict[str, Any] = field(default_factory=dict)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Action':
action_type = ActionType(data['type'])
params = {k: v for k, v in data.items() if k != 'type'}
return cls(type=action_type, params=params)
@dataclass
class ModelBranch:
"""Represents a branch in the pipeline with its own model"""
model_id: str
model_file: str
trigger_classes: List[str]
min_confidence: float = 0.5
crop: bool = False
crop_class: Optional[Any] = None # Can be string or list
parallel: bool = False
actions: List[Action] = field(default_factory=list)
branches: List['ModelBranch'] = field(default_factory=list)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'ModelBranch':
actions = [Action.from_dict(a) for a in data.get('actions', [])]
branches = [cls.from_dict(b) for b in data.get('branches', [])]
return cls(
model_id=data['modelId'],
model_file=data['modelFile'],
trigger_classes=data.get('triggerClasses', []),
min_confidence=data.get('minConfidence', 0.5),
crop=data.get('crop', False),
crop_class=data.get('cropClass'),
parallel=data.get('parallel', False),
actions=actions,
branches=branches
)
@dataclass
class TrackingConfig:
"""Configuration for the tracking phase"""
model_id: str
model_file: str
trigger_classes: List[str]
min_confidence: float = 0.6
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'TrackingConfig':
return cls(
model_id=data['modelId'],
model_file=data['modelFile'],
trigger_classes=data.get('triggerClasses', []),
min_confidence=data.get('minConfidence', 0.6)
)
@dataclass
class PipelineConfig:
"""Main pipeline configuration"""
model_id: str
model_file: str
trigger_classes: List[str]
min_confidence: float = 0.5
crop: bool = False
branches: List[ModelBranch] = field(default_factory=list)
parallel_actions: List[Action] = field(default_factory=list)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'PipelineConfig':
branches = [ModelBranch.from_dict(b) for b in data.get('branches', [])]
parallel_actions = [Action.from_dict(a) for a in data.get('parallelActions', [])]
return cls(
model_id=data['modelId'],
model_file=data['modelFile'],
trigger_classes=data.get('triggerClasses', []),
min_confidence=data.get('minConfidence', 0.5),
crop=data.get('crop', False),
branches=branches,
parallel_actions=parallel_actions
)
class PipelineParser:
"""Parser for pipeline.json configuration files"""
def __init__(self):
self.redis_config: Optional[RedisConfig] = None
self.postgresql_config: Optional[PostgreSQLConfig] = None
self.tracking_config: Optional[TrackingConfig] = None
self.pipeline_config: Optional[PipelineConfig] = None
self._model_dependencies: Set[str] = set()
def parse(self, config_path: Path) -> bool:
"""
Parse a pipeline.json configuration file
Args:
config_path: Path to the pipeline.json file
Returns:
True if parsing was successful, False otherwise
"""
try:
if not config_path.exists():
logger.error(f"Pipeline config not found: {config_path}")
return False
with open(config_path, 'r') as f:
data = json.load(f)
return self.parse_dict(data)
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON in pipeline config: {str(e)}")
return False
except Exception as e:
logger.error(f"Failed to parse pipeline config: {str(e)}", exc_info=True)
return False
def parse_dict(self, data: Dict[str, Any]) -> bool:
"""
Parse a pipeline configuration from a dictionary
Args:
data: The configuration dictionary
Returns:
True if parsing was successful, False otherwise
"""
try:
# Parse Redis configuration
if 'redis' in data:
self.redis_config = RedisConfig.from_dict(data['redis'])
logger.debug(f"Parsed Redis config: {self.redis_config.host}:{self.redis_config.port}")
# Parse PostgreSQL configuration
if 'postgresql' in data:
self.postgresql_config = PostgreSQLConfig.from_dict(data['postgresql'])
logger.debug(f"Parsed PostgreSQL config: {self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}")
# Parse tracking configuration
if 'tracking' in data:
self.tracking_config = TrackingConfig.from_dict(data['tracking'])
self._model_dependencies.add(self.tracking_config.model_file)
logger.debug(f"Parsed tracking config: {self.tracking_config.model_id}")
# Parse main pipeline configuration
if 'pipeline' in data:
self.pipeline_config = PipelineConfig.from_dict(data['pipeline'])
self._collect_model_dependencies(self.pipeline_config)
logger.debug(f"Parsed pipeline config: {self.pipeline_config.model_id}")
logger.info(f"Successfully parsed pipeline configuration")
logger.debug(f"Model dependencies: {self._model_dependencies}")
return True
except KeyError as e:
logger.error(f"Missing required field in pipeline config: {str(e)}")
return False
except Exception as e:
logger.error(f"Failed to parse pipeline config: {str(e)}", exc_info=True)
return False
def _collect_model_dependencies(self, config: Any) -> None:
"""
Recursively collect all model file dependencies
Args:
config: Pipeline or branch configuration
"""
if hasattr(config, 'model_file'):
self._model_dependencies.add(config.model_file)
if hasattr(config, 'branches'):
for branch in config.branches:
self._collect_model_dependencies(branch)
def get_model_dependencies(self) -> Set[str]:
"""
Get all model file dependencies from the pipeline
Returns:
Set of model filenames required by the pipeline
"""
return self._model_dependencies.copy()
def validate(self) -> bool:
"""
Validate the parsed configuration
Returns:
True if configuration is valid, False otherwise
"""
if not self.pipeline_config:
logger.error("No pipeline configuration found")
return False
# Check that all required model files are specified
if not self.pipeline_config.model_file:
logger.error("Main pipeline model file not specified")
return False
# Validate action configurations
if not self._validate_actions(self.pipeline_config):
return False
# Validate parallel actions (PostgreSQL actions are skipped)
for action in self.pipeline_config.parallel_actions:
if action.type == ActionType.POSTGRESQL_UPDATE_COMBINED:
logger.warning(f"PostgreSQL parallel action {action.type.value} found but will be SKIPPED (PostgreSQL disabled)")
# Skip validation for PostgreSQL actions since they won't be executed
# wait_for = action.params.get('waitForBranches', [])
# if wait_for:
# # Check that referenced branches exist
# branch_ids = self._get_all_branch_ids(self.pipeline_config)
# for branch_id in wait_for:
# if branch_id not in branch_ids:
# logger.error(f"Referenced branch '{branch_id}' in waitForBranches not found")
# return False
logger.info("Pipeline configuration validated successfully")
return True
def _validate_actions(self, config: Any) -> bool:
"""
Validate actions in a pipeline or branch configuration
Args:
config: Pipeline or branch configuration
Returns:
True if valid, False otherwise
"""
if hasattr(config, 'actions'):
for action in config.actions:
# Validate Redis actions need Redis config
if action.type in [ActionType.REDIS_SAVE_IMAGE, ActionType.REDIS_PUBLISH]:
if not self.redis_config:
logger.error(f"Action {action.type} requires Redis configuration")
return False
# PostgreSQL actions are disabled - log warning instead of failing
# Kept for backward compatibility with existing pipeline.json files
if action.type in [ActionType.POSTGRESQL_UPDATE, ActionType.POSTGRESQL_UPDATE_COMBINED, ActionType.POSTGRESQL_INSERT]:
logger.warning(f"PostgreSQL action {action.type.value} found but will be SKIPPED (PostgreSQL disabled)")
# Do not fail validation - just skip these actions during execution
# if not self.postgresql_config:
# logger.error(f"Action {action.type} requires PostgreSQL configuration")
# return False
# Recursively validate branches
if hasattr(config, 'branches'):
for branch in config.branches:
if not self._validate_actions(branch):
return False
return True
def _get_all_branch_ids(self, config: Any, branch_ids: Set[str] = None) -> Set[str]:
"""
Recursively collect all branch model IDs
Args:
config: Pipeline or branch configuration
branch_ids: Set to collect IDs into
Returns:
Set of all branch model IDs
"""
if branch_ids is None:
branch_ids = set()
if hasattr(config, 'branches'):
for branch in config.branches:
branch_ids.add(branch.model_id)
self._get_all_branch_ids(branch, branch_ids)
return branch_ids
def get_redis_config(self) -> Optional[RedisConfig]:
"""Get the Redis configuration"""
return self.redis_config
def get_postgresql_config(self) -> Optional[PostgreSQLConfig]:
"""Get the PostgreSQL configuration"""
return self.postgresql_config
def get_tracking_config(self) -> Optional[TrackingConfig]:
"""Get the tracking configuration"""
return self.tracking_config
def get_pipeline_config(self) -> Optional[PipelineConfig]:
"""Get the main pipeline configuration"""
return self.pipeline_config
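# --- Illustrative sketch (not part of the original module) ---
# A minimal configuration dictionary in the shape parse_dict() expects, showing
# the camelCase keys that map onto the dataclasses above. The model IDs,
# filenames, Redis host and channel are hypothetical placeholders.
# Guarded so nothing runs on import.
if __name__ == '__main__':
    sample_config = {
        'redis': {'host': 'redis.local', 'port': 6379, 'db': 0},
        'tracking': {
            'modelId': 'front_rear_detection_v1',
            'modelFile': 'front_rear_detection_v1.pt',
            'triggerClasses': ['front_rear'],
            'minConfidence': 0.6,
        },
        'pipeline': {
            'modelId': 'car_frontal_detection_v1',
            'modelFile': 'car_frontal_detection_v1.pt',
            'triggerClasses': ['Car', 'Frontal'],
            'minConfidence': 0.5,
            'branches': [{
                'modelId': 'car_brand_cls_v1',
                'modelFile': 'car_brand_cls_v1.pt',
                'triggerClasses': ['Frontal'],
                'minConfidence': 0.65,
                'crop': True,
                'cropClass': 'Frontal',
                'parallel': True,
                'actions': [{
                    'type': 'redis_publish',
                    'channel': 'car.brand',
                    'message': '{"brand": "{car_brand_cls_v1}"}',
                }],
            }],
        },
    }

    parser = PipelineParser()
    if parser.parse_dict(sample_config) and parser.validate():
        print('model dependencies:', sorted(parser.get_model_dependencies()))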


@@ -1,18 +0,0 @@
"""
Comprehensive health monitoring system for detector worker.
Tracks stream health, thread responsiveness, and system performance.
"""
from .health import HealthMonitor, HealthStatus, HealthCheck
from .stream_health import StreamHealthTracker
from .thread_health import ThreadHealthMonitor
from .recovery import RecoveryManager
__all__ = [
'HealthMonitor',
'HealthStatus',
'HealthCheck',
'StreamHealthTracker',
'ThreadHealthMonitor',
'RecoveryManager'
]


@@ -1,456 +0,0 @@
"""
Core health monitoring system for comprehensive stream and system health tracking.
Provides centralized health status, alerting, and recovery coordination.
"""
import time
import threading
import logging
import psutil
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass, field
from enum import Enum
from collections import defaultdict, deque
logger = logging.getLogger(__name__)
class HealthStatus(Enum):
"""Health status levels."""
HEALTHY = "healthy"
WARNING = "warning"
CRITICAL = "critical"
UNKNOWN = "unknown"
@dataclass
class HealthCheck:
"""Individual health check result."""
name: str
status: HealthStatus
message: str
timestamp: float = field(default_factory=time.time)
details: Dict[str, Any] = field(default_factory=dict)
recovery_action: Optional[str] = None
@dataclass
class HealthMetrics:
"""Health metrics for a component."""
component_id: str
last_update: float
frame_count: int = 0
error_count: int = 0
warning_count: int = 0
restart_count: int = 0
avg_frame_interval: float = 0.0
last_frame_time: Optional[float] = None
thread_alive: bool = True
connection_healthy: bool = True
memory_usage_mb: float = 0.0
cpu_usage_percent: float = 0.0
class HealthMonitor:
"""Comprehensive health monitoring system."""
def __init__(self, check_interval: float = 30.0):
"""
Initialize health monitor.
Args:
check_interval: Interval between health checks in seconds
"""
self.check_interval = check_interval
self.running = False
self.monitor_thread = None
self._lock = threading.RLock()
# Health data storage
self.health_checks: Dict[str, HealthCheck] = {}
self.metrics: Dict[str, HealthMetrics] = {}
self.alert_history: deque = deque(maxlen=1000)
self.recovery_actions: deque = deque(maxlen=500)
# Thresholds (configurable)
self.thresholds = {
'frame_stale_warning_seconds': 120, # 2 minutes
'frame_stale_critical_seconds': 300, # 5 minutes
'thread_unresponsive_seconds': 60, # 1 minute
'memory_warning_mb': 500, # 500MB per stream
'memory_critical_mb': 1000, # 1GB per stream
'cpu_warning_percent': 80, # 80% CPU
'cpu_critical_percent': 95, # 95% CPU
'error_rate_warning': 0.1, # 10% error rate
'error_rate_critical': 0.3, # 30% error rate
'restart_threshold': 3 # Max restarts per hour
}
# Health check functions
self.health_checkers: List[Callable[[], List[HealthCheck]]] = []
self.recovery_callbacks: Dict[str, Callable[[str, HealthCheck], bool]] = {}
# System monitoring
self.process = psutil.Process()
self.system_start_time = time.time()
def start(self):
"""Start health monitoring."""
if self.running:
logger.warning("Health monitor already running")
return
self.running = True
self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
self.monitor_thread.start()
logger.info(f"Health monitor started (check interval: {self.check_interval}s)")
def stop(self):
"""Stop health monitoring."""
self.running = False
if self.monitor_thread:
self.monitor_thread.join(timeout=5.0)
logger.info("Health monitor stopped")
def register_health_checker(self, checker: Callable[[], List[HealthCheck]]):
"""Register a health check function."""
self.health_checkers.append(checker)
logger.debug(f"Registered health checker: {checker.__name__}")
def register_recovery_callback(self, component: str, callback: Callable[[str, HealthCheck], bool]):
"""Register a recovery callback for a component."""
self.recovery_callbacks[component] = callback
logger.debug(f"Registered recovery callback for {component}")
def update_metrics(self, component_id: str, **kwargs):
"""Update metrics for a component."""
with self._lock:
if component_id not in self.metrics:
self.metrics[component_id] = HealthMetrics(
component_id=component_id,
last_update=time.time()
)
metrics = self.metrics[component_id]
metrics.last_update = time.time()
# Update provided metrics
for key, value in kwargs.items():
if hasattr(metrics, key):
setattr(metrics, key, value)
def report_frame_received(self, component_id: str):
"""Report that a frame was received for a component."""
current_time = time.time()
with self._lock:
if component_id not in self.metrics:
self.metrics[component_id] = HealthMetrics(
component_id=component_id,
last_update=current_time
)
metrics = self.metrics[component_id]
# Update frame metrics
if metrics.last_frame_time:
interval = current_time - metrics.last_frame_time
# Moving average of frame intervals
if metrics.avg_frame_interval == 0:
metrics.avg_frame_interval = interval
else:
metrics.avg_frame_interval = (metrics.avg_frame_interval * 0.9) + (interval * 0.1)
metrics.last_frame_time = current_time
metrics.frame_count += 1
metrics.last_update = current_time
def report_error(self, component_id: str, error_type: str = "general"):
"""Report an error for a component."""
with self._lock:
if component_id not in self.metrics:
self.metrics[component_id] = HealthMetrics(
component_id=component_id,
last_update=time.time()
)
self.metrics[component_id].error_count += 1
self.metrics[component_id].last_update = time.time()
logger.debug(f"Error reported for {component_id}: {error_type}")
def report_warning(self, component_id: str, warning_type: str = "general"):
"""Report a warning for a component."""
with self._lock:
if component_id not in self.metrics:
self.metrics[component_id] = HealthMetrics(
component_id=component_id,
last_update=time.time()
)
self.metrics[component_id].warning_count += 1
self.metrics[component_id].last_update = time.time()
logger.debug(f"Warning reported for {component_id}: {warning_type}")
def report_restart(self, component_id: str):
"""Report that a component was restarted."""
with self._lock:
if component_id not in self.metrics:
self.metrics[component_id] = HealthMetrics(
component_id=component_id,
last_update=time.time()
)
self.metrics[component_id].restart_count += 1
self.metrics[component_id].last_update = time.time()
# Log recovery action
recovery_action = {
'timestamp': time.time(),
'component': component_id,
'action': 'restart',
'reason': 'manual_restart'
}
with self._lock:
self.recovery_actions.append(recovery_action)
logger.info(f"Restart reported for {component_id}")
def get_health_status(self, component_id: Optional[str] = None) -> Dict[str, Any]:
"""Get comprehensive health status."""
with self._lock:
if component_id:
# Get health for specific component
return self._get_component_health(component_id)
else:
# Get overall health status
return self._get_overall_health()
def _get_component_health(self, component_id: str) -> Dict[str, Any]:
"""Get health status for a specific component."""
if component_id not in self.metrics:
return {
'component_id': component_id,
'status': HealthStatus.UNKNOWN.value,
'message': 'No metrics available',
'metrics': {}
}
metrics = self.metrics[component_id]
current_time = time.time()
# Determine health status
status = HealthStatus.HEALTHY
issues = []
# Check frame freshness
if metrics.last_frame_time:
frame_age = current_time - metrics.last_frame_time
if frame_age > self.thresholds['frame_stale_critical_seconds']:
status = HealthStatus.CRITICAL
issues.append(f"Frames stale for {frame_age:.1f}s")
elif frame_age > self.thresholds['frame_stale_warning_seconds']:
if status == HealthStatus.HEALTHY:
status = HealthStatus.WARNING
issues.append(f"Frames aging ({frame_age:.1f}s)")
# Check error rates
if metrics.frame_count > 0:
error_rate = metrics.error_count / metrics.frame_count
if error_rate > self.thresholds['error_rate_critical']:
status = HealthStatus.CRITICAL
issues.append(f"High error rate ({error_rate:.1%})")
elif error_rate > self.thresholds['error_rate_warning']:
if status == HealthStatus.HEALTHY:
status = HealthStatus.WARNING
issues.append(f"Elevated error rate ({error_rate:.1%})")
# Check restart frequency
restart_rate = metrics.restart_count / max(1, (current_time - self.system_start_time) / 3600)
if restart_rate > self.thresholds['restart_threshold']:
status = HealthStatus.CRITICAL
issues.append(f"Frequent restarts ({restart_rate:.1f}/hour)")
# Check thread health
if not metrics.thread_alive:
status = HealthStatus.CRITICAL
issues.append("Thread not alive")
# Check connection health
if not metrics.connection_healthy:
if status == HealthStatus.HEALTHY:
status = HealthStatus.WARNING
issues.append("Connection unhealthy")
return {
'component_id': component_id,
'status': status.value,
'message': '; '.join(issues) if issues else 'All checks passing',
'metrics': {
'frame_count': metrics.frame_count,
'error_count': metrics.error_count,
'warning_count': metrics.warning_count,
'restart_count': metrics.restart_count,
'avg_frame_interval': metrics.avg_frame_interval,
'last_frame_age': current_time - metrics.last_frame_time if metrics.last_frame_time else None,
'thread_alive': metrics.thread_alive,
'connection_healthy': metrics.connection_healthy,
'memory_usage_mb': metrics.memory_usage_mb,
'cpu_usage_percent': metrics.cpu_usage_percent,
'uptime_seconds': current_time - self.system_start_time
},
'last_update': metrics.last_update
}
def _get_overall_health(self) -> Dict[str, Any]:
"""Get overall system health status."""
current_time = time.time()
components = {}
overall_status = HealthStatus.HEALTHY
# Get health for all components
for component_id in self.metrics.keys():
component_health = self._get_component_health(component_id)
components[component_id] = component_health
# Determine overall status
component_status = HealthStatus(component_health['status'])
if component_status == HealthStatus.CRITICAL:
overall_status = HealthStatus.CRITICAL
elif component_status == HealthStatus.WARNING and overall_status == HealthStatus.HEALTHY:
overall_status = HealthStatus.WARNING
# System metrics
try:
system_memory = self.process.memory_info()
system_cpu = self.process.cpu_percent()
except Exception:
system_memory = None
system_cpu = 0.0
return {
'overall_status': overall_status.value,
'timestamp': current_time,
'uptime_seconds': current_time - self.system_start_time,
'total_components': len(self.metrics),
'components': components,
'system_metrics': {
'memory_mb': system_memory.rss / (1024 * 1024) if system_memory else 0,
'cpu_percent': system_cpu,
'process_id': self.process.pid
},
'recent_alerts': list(self.alert_history)[-10:], # Last 10 alerts
'recent_recoveries': list(self.recovery_actions)[-10:] # Last 10 recovery actions
}
def _monitor_loop(self):
"""Main health monitoring loop."""
logger.info("Health monitor loop started")
while self.running:
try:
start_time = time.time()
# Run all registered health checks
all_checks = []
for checker in self.health_checkers:
try:
checks = checker()
all_checks.extend(checks)
except Exception as e:
logger.error(f"Error in health checker {checker.__name__}: {e}")
# Process health checks and trigger recovery if needed
for check in all_checks:
self._process_health_check(check)
# Update system metrics
self._update_system_metrics()
# Sleep until next check
elapsed = time.time() - start_time
sleep_time = max(0, self.check_interval - elapsed)
if sleep_time > 0:
time.sleep(sleep_time)
except Exception as e:
logger.error(f"Error in health monitor loop: {e}")
time.sleep(5.0) # Fallback sleep
logger.info("Health monitor loop ended")
def _process_health_check(self, check: HealthCheck):
"""Process a health check result and trigger recovery if needed."""
with self._lock:
# Store health check
self.health_checks[check.name] = check
# Log alerts for non-healthy status
if check.status != HealthStatus.HEALTHY:
alert = {
'timestamp': check.timestamp,
'component': check.name,
'status': check.status.value,
'message': check.message,
'details': check.details
}
self.alert_history.append(alert)
logger.warning(f"Health alert [{check.status.value.upper()}] {check.name}: {check.message}")
# Trigger recovery if critical and recovery action available
if check.status == HealthStatus.CRITICAL and check.recovery_action:
self._trigger_recovery(check.name, check)
def _trigger_recovery(self, component: str, check: HealthCheck):
"""Trigger recovery action for a component."""
if component in self.recovery_callbacks:
try:
logger.info(f"Triggering recovery for {component}: {check.recovery_action}")
success = self.recovery_callbacks[component](component, check)
recovery_action = {
'timestamp': time.time(),
'component': component,
'action': check.recovery_action,
'reason': check.message,
'success': success
}
with self._lock:
self.recovery_actions.append(recovery_action)
if success:
logger.info(f"Recovery successful for {component}")
else:
logger.error(f"Recovery failed for {component}")
except Exception as e:
logger.error(f"Error in recovery callback for {component}: {e}")
def _update_system_metrics(self):
"""Update system-level metrics."""
try:
# Update process metrics for all components
current_time = time.time()
with self._lock:
for component_id, metrics in self.metrics.items():
# Update CPU and memory if available
try:
# This is a simplified approach - in practice you'd want
# per-thread or per-component resource tracking
metrics.cpu_usage_percent = self.process.cpu_percent() / len(self.metrics)
memory_info = self.process.memory_info()
metrics.memory_usage_mb = memory_info.rss / (1024 * 1024) / len(self.metrics)
except Exception:
pass
except Exception as e:
logger.error(f"Error updating system metrics: {e}")
# Global health monitor instance
health_monitor = HealthMonitor()
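A minimal wiring sketch for the global health_monitor instance; the import path and the checker/recovery bodies are assumptions, while the method names and dataclasses are the ones defined above.

# Hypothetical wiring (import path depends on the package layout).
from health import health_monitor, HealthCheck, HealthStatus

def stream_checks():
    # A registered checker returns zero or more HealthCheck results per cycle.
    return [HealthCheck(name="camera_001", status=HealthStatus.HEALTHY, message="ok")]

def restart_component(component, check):
    # Recovery callback: return True when the recovery action succeeded.
    return True

health_monitor.register_health_checker(stream_checks)
health_monitor.register_recovery_callback("camera_001", restart_component)
health_monitor.start()

# Hot paths report activity; the monitor loop turns it into status and alerts.
health_monitor.report_frame_received("camera_001")
health_monitor.report_error("camera_001", "decode_error")
print(health_monitor.get_health_status("camera_001")["status"])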

View file

@ -1,385 +0,0 @@
"""
Recovery manager for automatic handling of health issues.
Provides circuit breaker patterns, automatic restarts, and graceful degradation.
"""
import time
import logging
import threading
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass
from enum import Enum
from collections import defaultdict, deque
from .health import HealthCheck, HealthStatus, health_monitor
logger = logging.getLogger(__name__)
class RecoveryAction(Enum):
"""Types of recovery actions."""
RESTART_STREAM = "restart_stream"
RESTART_THREAD = "restart_thread"
CLEAR_BUFFER = "clear_buffer"
RECONNECT = "reconnect"
THROTTLE = "throttle"
DISABLE = "disable"
@dataclass
class RecoveryAttempt:
"""Record of a recovery attempt."""
timestamp: float
component: str
action: RecoveryAction
reason: str
success: bool
details: Optional[Dict[str, Any]] = None
@dataclass
class RecoveryState:
"""Recovery state for a component - simplified without circuit breaker."""
failure_count: int = 0
success_count: int = 0
last_failure_time: Optional[float] = None
last_success_time: Optional[float] = None
class RecoveryManager:
"""Manages automatic recovery actions for health issues."""
def __init__(self):
self.recovery_handlers: Dict[str, Callable[[str, HealthCheck], bool]] = {}
self.recovery_states: Dict[str, RecoveryState] = {}
self.recovery_history: deque = deque(maxlen=1000)
self._lock = threading.RLock()
# Configuration - simplified without circuit breaker
self.recovery_cooldown = 30 # 30 seconds between recovery attempts
self.max_attempts_per_hour = 20 # Still limit to prevent spam, but much higher
# Track recovery attempts per component
self.recovery_attempts: Dict[str, deque] = defaultdict(lambda: deque(maxlen=50))
# Register with health monitor
health_monitor.register_recovery_callback("stream", self._handle_stream_recovery)
health_monitor.register_recovery_callback("thread", self._handle_thread_recovery)
health_monitor.register_recovery_callback("buffer", self._handle_buffer_recovery)
def register_recovery_handler(self, action: RecoveryAction, handler: Callable[[str, Dict[str, Any]], bool]):
"""
Register a recovery handler for a specific action.
Args:
action: Type of recovery action
handler: Function that performs the recovery
"""
self.recovery_handlers[action.value] = handler
logger.info(f"Registered recovery handler for {action.value}")
def can_attempt_recovery(self, component: str) -> bool:
"""
Check if recovery can be attempted for a component.
Args:
component: Component identifier
Returns:
True if recovery can be attempted (always allow with minimal throttling)
"""
with self._lock:
current_time = time.time()
# Check recovery attempt rate limiting (much more permissive)
recent_attempts = [
attempt for attempt in self.recovery_attempts[component]
if current_time - attempt <= 3600 # Last hour
]
# Only block if truly excessive attempts
if len(recent_attempts) >= self.max_attempts_per_hour:
logger.warning(f"Recovery rate limit exceeded for {component} "
f"({len(recent_attempts)} attempts in last hour)")
return False
# Check cooldown period (shorter cooldown)
if recent_attempts:
last_attempt = max(recent_attempts)
if current_time - last_attempt < self.recovery_cooldown:
logger.debug(f"Recovery cooldown active for {component} "
f"(last attempt {current_time - last_attempt:.1f}s ago)")
return False
return True
def attempt_recovery(self, component: str, action: RecoveryAction, reason: str,
details: Optional[Dict[str, Any]] = None) -> bool:
"""
Attempt recovery for a component.
Args:
component: Component identifier
action: Recovery action to perform
reason: Reason for recovery
details: Additional details
Returns:
True if recovery was successful
"""
if not self.can_attempt_recovery(component):
return False
current_time = time.time()
logger.info(f"Attempting recovery for {component}: {action.value} ({reason})")
try:
# Record recovery attempt
with self._lock:
self.recovery_attempts[component].append(current_time)
# Perform recovery action
success = self._execute_recovery_action(component, action, details or {})
# Record recovery result
attempt = RecoveryAttempt(
timestamp=current_time,
component=component,
action=action,
reason=reason,
success=success,
details=details
)
with self._lock:
self.recovery_history.append(attempt)
# Update recovery state
self._update_recovery_state(component, success)
if success:
logger.info(f"Recovery successful for {component}: {action.value}")
else:
logger.error(f"Recovery failed for {component}: {action.value}")
return success
except Exception as e:
logger.error(f"Error during recovery for {component}: {e}")
self._update_recovery_state(component, False)
return False
def _execute_recovery_action(self, component: str, action: RecoveryAction,
details: Dict[str, Any]) -> bool:
"""Execute a specific recovery action."""
handler_key = action.value
if handler_key not in self.recovery_handlers:
logger.error(f"No recovery handler registered for action: {handler_key}")
return False
try:
handler = self.recovery_handlers[handler_key]
return handler(component, details)
except Exception as e:
logger.error(f"Error executing recovery action {handler_key} for {component}: {e}")
return False
def _update_recovery_state(self, component: str, success: bool):
"""Update recovery state based on recovery result."""
current_time = time.time()
with self._lock:
if component not in self.recovery_states:
self.recovery_states[component] = RecoveryState()
state = self.recovery_states[component]
if success:
state.success_count += 1
state.last_success_time = current_time
# Reset failure count on success
state.failure_count = max(0, state.failure_count - 1)
logger.debug(f"Recovery success for {component} (total successes: {state.success_count})")
else:
state.failure_count += 1
state.last_failure_time = current_time
logger.debug(f"Recovery failure for {component} (total failures: {state.failure_count})")
def _handle_stream_recovery(self, component: str, health_check: HealthCheck) -> bool:
"""Handle recovery for stream-related issues."""
if "frames" in health_check.name:
# Frame-related issue - restart stream
return self.attempt_recovery(
component,
RecoveryAction.RESTART_STREAM,
health_check.message,
health_check.details
)
elif "connection" in health_check.name:
# Connection issue - reconnect
return self.attempt_recovery(
component,
RecoveryAction.RECONNECT,
health_check.message,
health_check.details
)
elif "errors" in health_check.name:
# High error rate - throttle or restart
return self.attempt_recovery(
component,
RecoveryAction.THROTTLE,
health_check.message,
health_check.details
)
else:
# Generic stream issue - restart
return self.attempt_recovery(
component,
RecoveryAction.RESTART_STREAM,
health_check.message,
health_check.details
)
def _handle_thread_recovery(self, component: str, health_check: HealthCheck) -> bool:
"""Handle recovery for thread-related issues."""
if "deadlock" in health_check.name:
# Deadlock detected - restart thread
return self.attempt_recovery(
component,
RecoveryAction.RESTART_THREAD,
health_check.message,
health_check.details
)
elif "responsive" in health_check.name:
# Thread unresponsive - restart
return self.attempt_recovery(
component,
RecoveryAction.RESTART_THREAD,
health_check.message,
health_check.details
)
else:
# Generic thread issue - restart
return self.attempt_recovery(
component,
RecoveryAction.RESTART_THREAD,
health_check.message,
health_check.details
)
def _handle_buffer_recovery(self, component: str, health_check: HealthCheck) -> bool:
"""Handle recovery for buffer-related issues."""
# Buffer issues - clear buffer
return self.attempt_recovery(
component,
RecoveryAction.CLEAR_BUFFER,
health_check.message,
health_check.details
)
def get_recovery_stats(self) -> Dict[str, Any]:
"""Get recovery statistics."""
current_time = time.time()
with self._lock:
# Calculate stats from history
recent_recoveries = [
attempt for attempt in self.recovery_history
if current_time - attempt.timestamp <= 3600 # Last hour
]
stats_by_component = defaultdict(lambda: {
'attempts': 0,
'successes': 0,
'failures': 0,
'last_attempt': None,
'last_success': None
})
for attempt in recent_recoveries:
stats = stats_by_component[attempt.component]
stats['attempts'] += 1
if attempt.success:
stats['successes'] += 1
if not stats['last_success'] or attempt.timestamp > stats['last_success']:
stats['last_success'] = attempt.timestamp
else:
stats['failures'] += 1
if not stats['last_attempt'] or attempt.timestamp > stats['last_attempt']:
stats['last_attempt'] = attempt.timestamp
return {
'total_recoveries_last_hour': len(recent_recoveries),
'recovery_by_component': dict(stats_by_component),
'recovery_states': {
component: {
'failure_count': state.failure_count,
'success_count': state.success_count,
'last_failure_time': state.last_failure_time,
'last_success_time': state.last_success_time
}
for component, state in self.recovery_states.items()
},
'recent_history': [
{
'timestamp': attempt.timestamp,
'component': attempt.component,
'action': attempt.action.value,
'reason': attempt.reason,
'success': attempt.success
}
for attempt in list(self.recovery_history)[-10:] # Last 10 attempts
]
}
def force_recovery(self, component: str, action: RecoveryAction, reason: str = "manual") -> bool:
"""
Force recovery for a component, bypassing rate limiting.
Args:
component: Component identifier
action: Recovery action to perform
reason: Reason for forced recovery
Returns:
True if recovery was successful
"""
logger.info(f"Forcing recovery for {component}: {action.value} ({reason})")
current_time = time.time()
try:
# Execute recovery action directly
success = self._execute_recovery_action(component, action, {})
# Record forced recovery
attempt = RecoveryAttempt(
timestamp=current_time,
component=component,
action=action,
reason=f"forced: {reason}",
success=success,
details={'forced': True}
)
with self._lock:
self.recovery_history.append(attempt)
self.recovery_attempts[component].append(current_time)
# Update recovery state
self._update_recovery_state(component, success)
return success
except Exception as e:
logger.error(f"Error during forced recovery for {component}: {e}")
return False
# Global recovery manager instance
recovery_manager = RecoveryManager()
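A minimal sketch of registering recovery handlers with the global recovery_manager; the handler bodies and component id are assumptions, and the signatures follow register_recovery_handler above.

# Hypothetical handlers (import path depends on the package layout).
from recovery import recovery_manager, RecoveryAction

def restart_stream(component, details):
    # Tear down and recreate the reader for `component`; return True on success.
    return True

recovery_manager.register_recovery_handler(RecoveryAction.RESTART_STREAM, restart_stream)
recovery_manager.register_recovery_handler(RecoveryAction.CLEAR_BUFFER, lambda c, d: True)

# attempt_recovery is rate-limited (cooldown + per-hour cap); force_recovery bypasses both.
recovery_manager.attempt_recovery("camera_001", RecoveryAction.RESTART_STREAM, reason="stale frames")
recovery_manager.force_recovery("camera_001", RecoveryAction.CLEAR_BUFFER, reason="manual flush")
print(recovery_manager.get_recovery_stats()["total_recoveries_last_hour"])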

View file

@ -1,351 +0,0 @@
"""
Stream-specific health monitoring for video streams.
Tracks frame production, connection health, and stream-specific metrics.
"""
import time
import logging
import threading
import requests
from typing import Dict, Optional, List, Any
from collections import deque
from dataclasses import dataclass
from .health import HealthCheck, HealthStatus, health_monitor
logger = logging.getLogger(__name__)
@dataclass
class StreamMetrics:
"""Metrics for an individual stream."""
camera_id: str
stream_type: str # 'rtsp', 'http_snapshot'
start_time: float
last_frame_time: Optional[float] = None
frame_count: int = 0
error_count: int = 0
reconnect_count: int = 0
bytes_received: int = 0
frames_per_second: float = 0.0
connection_attempts: int = 0
last_connection_test: Optional[float] = None
connection_healthy: bool = True
last_error: Optional[str] = None
last_error_time: Optional[float] = None
class StreamHealthTracker:
"""Tracks health for individual video streams."""
def __init__(self):
self.streams: Dict[str, StreamMetrics] = {}
self._lock = threading.RLock()
# Configuration
self.connection_test_interval = 300 # Test connection every 5 minutes
self.frame_timeout_warning = 120 # Warn if no frames for 2 minutes
self.frame_timeout_critical = 300 # Critical if no frames for 5 minutes
self.error_rate_threshold = 0.1 # 10% error rate threshold
# Register with health monitor
health_monitor.register_health_checker(self._perform_health_checks)
def register_stream(self, camera_id: str, stream_type: str, source_url: Optional[str] = None):
"""Register a new stream for monitoring."""
with self._lock:
if camera_id not in self.streams:
self.streams[camera_id] = StreamMetrics(
camera_id=camera_id,
stream_type=stream_type,
start_time=time.time()
)
logger.info(f"Registered stream for monitoring: {camera_id} ({stream_type})")
# Update health monitor metrics
health_monitor.update_metrics(
camera_id,
thread_alive=True,
connection_healthy=True
)
def unregister_stream(self, camera_id: str):
"""Unregister a stream from monitoring."""
with self._lock:
if camera_id in self.streams:
del self.streams[camera_id]
logger.info(f"Unregistered stream from monitoring: {camera_id}")
def report_frame_received(self, camera_id: str, frame_size_bytes: int = 0):
"""Report that a frame was received."""
current_time = time.time()
with self._lock:
if camera_id not in self.streams:
logger.warning(f"Frame received for unregistered stream: {camera_id}")
return
stream = self.streams[camera_id]
# Update frame metrics
if stream.last_frame_time:
interval = current_time - stream.last_frame_time
# Calculate FPS as moving average
if stream.frames_per_second == 0:
stream.frames_per_second = 1.0 / interval if interval > 0 else 0
else:
new_fps = 1.0 / interval if interval > 0 else 0
stream.frames_per_second = (stream.frames_per_second * 0.9) + (new_fps * 0.1)
stream.last_frame_time = current_time
stream.frame_count += 1
stream.bytes_received += frame_size_bytes
# Report to health monitor
health_monitor.report_frame_received(camera_id)
health_monitor.update_metrics(
camera_id,
frame_count=stream.frame_count,
avg_frame_interval=1.0 / stream.frames_per_second if stream.frames_per_second > 0 else 0,
last_frame_time=current_time
)
def report_error(self, camera_id: str, error_message: str):
"""Report an error for a stream."""
current_time = time.time()
with self._lock:
if camera_id not in self.streams:
logger.warning(f"Error reported for unregistered stream: {camera_id}")
return
stream = self.streams[camera_id]
stream.error_count += 1
stream.last_error = error_message
stream.last_error_time = current_time
# Report to health monitor
health_monitor.report_error(camera_id, "stream_error")
health_monitor.update_metrics(
camera_id,
error_count=stream.error_count
)
logger.debug(f"Error reported for stream {camera_id}: {error_message}")
def report_reconnect(self, camera_id: str, reason: str = "unknown"):
"""Report that a stream reconnected."""
current_time = time.time()
with self._lock:
if camera_id not in self.streams:
logger.warning(f"Reconnect reported for unregistered stream: {camera_id}")
return
stream = self.streams[camera_id]
stream.reconnect_count += 1
# Report to health monitor
health_monitor.report_restart(camera_id)
health_monitor.update_metrics(
camera_id,
restart_count=stream.reconnect_count
)
logger.info(f"Reconnect reported for stream {camera_id}: {reason}")
def report_connection_attempt(self, camera_id: str, success: bool):
"""Report a connection attempt."""
with self._lock:
if camera_id not in self.streams:
return
stream = self.streams[camera_id]
stream.connection_attempts += 1
stream.connection_healthy = success
# Report to health monitor
health_monitor.update_metrics(
camera_id,
connection_healthy=success
)
def test_http_connection(self, camera_id: str, url: str) -> bool:
"""Test HTTP connection health for snapshot streams."""
try:
# Quick HEAD request to test connectivity
response = requests.head(url, timeout=5, verify=False)
success = response.status_code in [200, 404] # 404 might be normal for some cameras
self.report_connection_attempt(camera_id, success)
if success:
logger.debug(f"Connection test passed for {camera_id}")
else:
logger.warning(f"Connection test failed for {camera_id}: HTTP {response.status_code}")
return success
except Exception as e:
logger.warning(f"Connection test failed for {camera_id}: {e}")
self.report_connection_attempt(camera_id, False)
return False
def get_stream_metrics(self, camera_id: str) -> Optional[Dict[str, Any]]:
"""Get metrics for a specific stream."""
with self._lock:
if camera_id not in self.streams:
return None
stream = self.streams[camera_id]
current_time = time.time()
# Calculate derived metrics
uptime = current_time - stream.start_time
frame_age = current_time - stream.last_frame_time if stream.last_frame_time else None
error_rate = stream.error_count / max(1, stream.frame_count)
return {
'camera_id': camera_id,
'stream_type': stream.stream_type,
'uptime_seconds': uptime,
'frame_count': stream.frame_count,
'frames_per_second': stream.frames_per_second,
'bytes_received': stream.bytes_received,
'error_count': stream.error_count,
'error_rate': error_rate,
'reconnect_count': stream.reconnect_count,
'connection_attempts': stream.connection_attempts,
'connection_healthy': stream.connection_healthy,
'last_frame_age_seconds': frame_age,
'last_error': stream.last_error,
'last_error_time': stream.last_error_time
}
def get_all_metrics(self) -> Dict[str, Dict[str, Any]]:
"""Get metrics for all streams."""
with self._lock:
return {
camera_id: self.get_stream_metrics(camera_id)
for camera_id in self.streams.keys()
}
def _perform_health_checks(self) -> List[HealthCheck]:
"""Perform health checks for all streams."""
checks = []
current_time = time.time()
with self._lock:
for camera_id, stream in self.streams.items():
checks.extend(self._check_stream_health(camera_id, stream, current_time))
return checks
def _check_stream_health(self, camera_id: str, stream: StreamMetrics, current_time: float) -> List[HealthCheck]:
"""Perform health checks for a single stream."""
checks = []
# Check frame freshness
if stream.last_frame_time:
frame_age = current_time - stream.last_frame_time
if frame_age > self.frame_timeout_critical:
checks.append(HealthCheck(
name=f"stream_{camera_id}_frames",
status=HealthStatus.CRITICAL,
message=f"No frames for {frame_age:.1f}s (critical threshold: {self.frame_timeout_critical}s)",
details={
'frame_age': frame_age,
'threshold': self.frame_timeout_critical,
'last_frame_time': stream.last_frame_time
},
recovery_action="restart_stream"
))
elif frame_age > self.frame_timeout_warning:
checks.append(HealthCheck(
name=f"stream_{camera_id}_frames",
status=HealthStatus.WARNING,
message=f"Frames aging: {frame_age:.1f}s (warning threshold: {self.frame_timeout_warning}s)",
details={
'frame_age': frame_age,
'threshold': self.frame_timeout_warning,
'last_frame_time': stream.last_frame_time
}
))
else:
# No frames received yet
startup_time = current_time - stream.start_time
if startup_time > 60: # Allow 1 minute for initial connection
checks.append(HealthCheck(
name=f"stream_{camera_id}_startup",
status=HealthStatus.CRITICAL,
message=f"No frames received since startup {startup_time:.1f}s ago",
details={
'startup_time': startup_time,
'start_time': stream.start_time
},
recovery_action="restart_stream"
))
# Check error rate
if stream.frame_count > 10: # Need sufficient samples
error_rate = stream.error_count / stream.frame_count
if error_rate > self.error_rate_threshold:
checks.append(HealthCheck(
name=f"stream_{camera_id}_errors",
status=HealthStatus.WARNING,
message=f"High error rate: {error_rate:.1%} ({stream.error_count}/{stream.frame_count})",
details={
'error_rate': error_rate,
'error_count': stream.error_count,
'frame_count': stream.frame_count,
'last_error': stream.last_error
}
))
# Check connection health
if not stream.connection_healthy:
checks.append(HealthCheck(
name=f"stream_{camera_id}_connection",
status=HealthStatus.WARNING,
message="Connection unhealthy (last test failed)",
details={
'connection_attempts': stream.connection_attempts,
'last_connection_test': stream.last_connection_test
}
))
# Check excessive reconnects
uptime_hours = (current_time - stream.start_time) / 3600
if uptime_hours > 1 and stream.reconnect_count > 5: # More than 5 reconnects per hour
reconnect_rate = stream.reconnect_count / uptime_hours
checks.append(HealthCheck(
name=f"stream_{camera_id}_stability",
status=HealthStatus.WARNING,
message=f"Frequent reconnects: {reconnect_rate:.1f}/hour ({stream.reconnect_count} total)",
details={
'reconnect_rate': reconnect_rate,
'reconnect_count': stream.reconnect_count,
'uptime_hours': uptime_hours
}
))
# Check frame rate health
if stream.last_frame_time and stream.frames_per_second > 0:
expected_fps = 6.0 # Expected FPS for streams
if stream.frames_per_second < expected_fps * 0.5: # Less than 50% of expected
checks.append(HealthCheck(
name=f"stream_{camera_id}_framerate",
status=HealthStatus.WARNING,
message=f"Low frame rate: {stream.frames_per_second:.1f} fps (expected: ~{expected_fps} fps)",
details={
'current_fps': stream.frames_per_second,
'expected_fps': expected_fps
}
))
return checks
# Global stream health tracker instance
stream_health_tracker = StreamHealthTracker()
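A minimal sketch of how a stream reader might feed the global stream_health_tracker; the camera ids, URL, and byte count are placeholders, while the method names are the ones defined above.

# Hypothetical reader-side reporting (import path depends on the package layout).
from stream_health import stream_health_tracker

stream_health_tracker.register_stream("camera_001", "rtsp")

# Inside the frame loop:
stream_health_tracker.report_frame_received("camera_001", frame_size_bytes=145_000)

# On failures and reconnects:
stream_health_tracker.report_error("camera_001", "RTSP read timeout")
stream_health_tracker.report_reconnect("camera_001", reason="read timeout")

# HTTP snapshot sources can be probed explicitly:
stream_health_tracker.test_http_connection("camera_002", "http://camera.local/snapshot.jpg")

print(stream_health_tracker.get_stream_metrics("camera_001")["frames_per_second"])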

View file

@ -1,381 +0,0 @@
"""
Thread health monitoring for detecting unresponsive and deadlocked threads.
Provides thread liveness detection and responsiveness testing.
"""
import time
import threading
import logging
import signal
import sys
import traceback
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass
from collections import defaultdict
from .health import HealthCheck, HealthStatus, health_monitor
logger = logging.getLogger(__name__)
@dataclass
class ThreadInfo:
"""Information about a monitored thread."""
thread_id: int
thread_name: str
start_time: float
last_heartbeat: float
heartbeat_count: int = 0
is_responsive: bool = True
last_activity: Optional[str] = None
stack_traces: Optional[List[str]] = None
class ThreadHealthMonitor:
"""Monitors thread health and responsiveness."""
def __init__(self):
self.monitored_threads: Dict[int, ThreadInfo] = {}
self.heartbeat_callbacks: Dict[int, Callable[[], bool]] = {}
self._lock = threading.RLock()
# Configuration
self.heartbeat_timeout = 60.0 # 1 minute without heartbeat = unresponsive
self.responsiveness_test_interval = 30.0 # Test responsiveness every 30 seconds
self.stack_trace_count = 5 # Keep last 5 stack traces for analysis
# Register with health monitor
health_monitor.register_health_checker(self._perform_health_checks)
# Enable periodic responsiveness testing
self.test_thread = threading.Thread(target=self._responsiveness_test_loop, daemon=True)
self.test_thread.start()
def register_thread(self, thread: threading.Thread, heartbeat_callback: Optional[Callable[[], bool]] = None):
"""
Register a thread for monitoring.
Args:
thread: Thread to monitor
heartbeat_callback: Optional callback to test thread responsiveness
"""
with self._lock:
thread_info = ThreadInfo(
thread_id=thread.ident,
thread_name=thread.name,
start_time=time.time(),
last_heartbeat=time.time()
)
self.monitored_threads[thread.ident] = thread_info
if heartbeat_callback:
self.heartbeat_callbacks[thread.ident] = heartbeat_callback
logger.info(f"Registered thread for monitoring: {thread.name} (ID: {thread.ident})")
def unregister_thread(self, thread_id: int):
"""Unregister a thread from monitoring."""
with self._lock:
if thread_id in self.monitored_threads:
thread_name = self.monitored_threads[thread_id].thread_name
del self.monitored_threads[thread_id]
if thread_id in self.heartbeat_callbacks:
del self.heartbeat_callbacks[thread_id]
logger.info(f"Unregistered thread from monitoring: {thread_name} (ID: {thread_id})")
def heartbeat(self, thread_id: Optional[int] = None, activity: Optional[str] = None):
"""
Report thread heartbeat.
Args:
thread_id: Thread ID (uses current thread if None)
activity: Description of current activity
"""
if thread_id is None:
thread_id = threading.current_thread().ident
current_time = time.time()
with self._lock:
if thread_id in self.monitored_threads:
thread_info = self.monitored_threads[thread_id]
thread_info.last_heartbeat = current_time
thread_info.heartbeat_count += 1
thread_info.is_responsive = True
if activity:
thread_info.last_activity = activity
# Report to health monitor
health_monitor.update_metrics(
f"thread_{thread_info.thread_name}",
thread_alive=True,
last_frame_time=current_time
)
def get_thread_info(self, thread_id: int) -> Optional[Dict[str, Any]]:
"""Get information about a monitored thread."""
with self._lock:
if thread_id not in self.monitored_threads:
return None
thread_info = self.monitored_threads[thread_id]
current_time = time.time()
return {
'thread_id': thread_id,
'thread_name': thread_info.thread_name,
'uptime_seconds': current_time - thread_info.start_time,
'last_heartbeat_age': current_time - thread_info.last_heartbeat,
'heartbeat_count': thread_info.heartbeat_count,
'is_responsive': thread_info.is_responsive,
'last_activity': thread_info.last_activity,
'stack_traces': thread_info.stack_traces or []
}
def get_all_thread_info(self) -> Dict[int, Dict[str, Any]]:
"""Get information about all monitored threads."""
with self._lock:
return {
thread_id: self.get_thread_info(thread_id)
for thread_id in self.monitored_threads.keys()
}
def test_thread_responsiveness(self, thread_id: int) -> bool:
"""
Test if a thread is responsive by calling its heartbeat callback.
Args:
thread_id: ID of thread to test
Returns:
True if thread responds within timeout
"""
if thread_id not in self.heartbeat_callbacks:
return True # Can't test if no callback provided
try:
# Call the heartbeat callback with a timeout
callback = self.heartbeat_callbacks[thread_id]
# This is a simple approach - in practice you might want to use
# threading.Timer or asyncio for more sophisticated timeout handling
start_time = time.time()
result = callback()
response_time = time.time() - start_time
with self._lock:
if thread_id in self.monitored_threads:
self.monitored_threads[thread_id].is_responsive = result
if response_time > 5.0: # Slow response
logger.warning(f"Thread {thread_id} slow response: {response_time:.1f}s")
return result
except Exception as e:
logger.error(f"Error testing thread {thread_id} responsiveness: {e}")
with self._lock:
if thread_id in self.monitored_threads:
self.monitored_threads[thread_id].is_responsive = False
return False
def capture_stack_trace(self, thread_id: int) -> Optional[str]:
"""
Capture stack trace for a thread.
Args:
thread_id: ID of thread to capture
Returns:
Stack trace string or None if not available
"""
try:
# Get all frames for all threads
frames = dict(sys._current_frames())
if thread_id not in frames:
return None
# Format stack trace
frame = frames[thread_id]
stack_trace = ''.join(traceback.format_stack(frame))
# Store in thread info
with self._lock:
if thread_id in self.monitored_threads:
thread_info = self.monitored_threads[thread_id]
if thread_info.stack_traces is None:
thread_info.stack_traces = []
thread_info.stack_traces.append(f"{time.time()}: {stack_trace}")
# Keep only last N stack traces
if len(thread_info.stack_traces) > self.stack_trace_count:
thread_info.stack_traces = thread_info.stack_traces[-self.stack_trace_count:]
return stack_trace
except Exception as e:
logger.error(f"Error capturing stack trace for thread {thread_id}: {e}")
return None
def detect_deadlocks(self) -> List[Dict[str, Any]]:
"""
Attempt to detect potential deadlocks by analyzing thread states.
Returns:
List of potential deadlock scenarios
"""
deadlocks = []
current_time = time.time()
with self._lock:
# Look for threads that haven't had heartbeats for a long time
# and are supposedly alive
for thread_id, thread_info in self.monitored_threads.items():
heartbeat_age = current_time - thread_info.last_heartbeat
if heartbeat_age > self.heartbeat_timeout * 2: # Double the timeout
# Check if thread still exists
thread_exists = any(
t.ident == thread_id and t.is_alive()
for t in threading.enumerate()
)
if thread_exists:
# Thread exists but not responding - potential deadlock
stack_trace = self.capture_stack_trace(thread_id)
deadlock_info = {
'thread_id': thread_id,
'thread_name': thread_info.thread_name,
'heartbeat_age': heartbeat_age,
'last_activity': thread_info.last_activity,
'stack_trace': stack_trace,
'detection_time': current_time
}
deadlocks.append(deadlock_info)
logger.warning(f"Potential deadlock detected in thread {thread_info.thread_name}")
return deadlocks
def _responsiveness_test_loop(self):
"""Background loop to test thread responsiveness."""
logger.info("Thread responsiveness testing started")
while True:
try:
time.sleep(self.responsiveness_test_interval)
with self._lock:
thread_ids = list(self.monitored_threads.keys())
for thread_id in thread_ids:
try:
self.test_thread_responsiveness(thread_id)
except Exception as e:
logger.error(f"Error testing thread {thread_id}: {e}")
except Exception as e:
logger.error(f"Error in responsiveness test loop: {e}")
time.sleep(10.0) # Fallback sleep
def _perform_health_checks(self) -> List[HealthCheck]:
"""Perform health checks for all monitored threads."""
checks = []
current_time = time.time()
with self._lock:
for thread_id, thread_info in self.monitored_threads.items():
checks.extend(self._check_thread_health(thread_id, thread_info, current_time))
# Check for deadlocks
deadlocks = self.detect_deadlocks()
for deadlock in deadlocks:
checks.append(HealthCheck(
name=f"deadlock_detection_{deadlock['thread_id']}",
status=HealthStatus.CRITICAL,
message=f"Potential deadlock in thread {deadlock['thread_name']} "
f"(unresponsive for {deadlock['heartbeat_age']:.1f}s)",
details=deadlock,
recovery_action="restart_thread"
))
return checks
def _check_thread_health(self, thread_id: int, thread_info: ThreadInfo, current_time: float) -> List[HealthCheck]:
"""Perform health checks for a single thread."""
checks = []
# Check if thread still exists
thread_exists = any(
t.ident == thread_id and t.is_alive()
for t in threading.enumerate()
)
if not thread_exists:
checks.append(HealthCheck(
name=f"thread_{thread_info.thread_name}_alive",
status=HealthStatus.CRITICAL,
message=f"Thread {thread_info.thread_name} is no longer alive",
details={
'thread_id': thread_id,
'uptime': current_time - thread_info.start_time,
'last_heartbeat': thread_info.last_heartbeat
},
recovery_action="restart_thread"
))
return checks
# Check heartbeat freshness
heartbeat_age = current_time - thread_info.last_heartbeat
if heartbeat_age > self.heartbeat_timeout:
checks.append(HealthCheck(
name=f"thread_{thread_info.thread_name}_responsive",
status=HealthStatus.CRITICAL,
message=f"Thread {thread_info.thread_name} unresponsive for {heartbeat_age:.1f}s",
details={
'thread_id': thread_id,
'heartbeat_age': heartbeat_age,
'heartbeat_count': thread_info.heartbeat_count,
'last_activity': thread_info.last_activity,
'is_responsive': thread_info.is_responsive
},
recovery_action="restart_thread"
))
elif heartbeat_age > self.heartbeat_timeout * 0.5: # Warning at 50% of timeout
checks.append(HealthCheck(
name=f"thread_{thread_info.thread_name}_responsive",
status=HealthStatus.WARNING,
message=f"Thread {thread_info.thread_name} slow heartbeat: {heartbeat_age:.1f}s",
details={
'thread_id': thread_id,
'heartbeat_age': heartbeat_age,
'heartbeat_count': thread_info.heartbeat_count,
'last_activity': thread_info.last_activity,
'is_responsive': thread_info.is_responsive
}
))
# Check responsiveness test results
if not thread_info.is_responsive:
checks.append(HealthCheck(
name=f"thread_{thread_info.thread_name}_callback",
status=HealthStatus.WARNING,
message=f"Thread {thread_info.thread_name} failed responsiveness test",
details={
'thread_id': thread_id,
'last_activity': thread_info.last_activity
}
))
return checks
# Global thread health monitor instance
thread_health_monitor = ThreadHealthMonitor()
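A minimal sketch of registering a worker thread with the global thread_health_monitor; the reader loop is a placeholder, and the registration/heartbeat calls are the ones defined above. The thread is started before registration so that thread.ident is populated.

# Hypothetical worker thread (import path depends on the package layout).
import threading
import time
from thread_health import thread_health_monitor

def reader_loop():
    while True:
        # ... read and process one frame ...
        thread_health_monitor.heartbeat(activity="read frame")
        time.sleep(0.1)

t = threading.Thread(target=reader_loop, name="camera_001_reader", daemon=True)
t.start()
# Optional callback lets the monitor actively probe responsiveness.
thread_health_monitor.register_thread(t, heartbeat_callback=lambda: t.is_alive())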

View file

@ -1,13 +0,0 @@
"""
Storage module for the Python Detector Worker.
This module provides Redis operations for data persistence
and caching in the detection pipeline.
Note: PostgreSQL operations have been disabled as database functionality
has been moved to a microservices architecture.
"""
from .redis import RedisManager
# from .database import DatabaseManager # Disabled - moved to microservices
__all__ = ['RedisManager'] # Removed 'DatabaseManager'

View file

@ -1,369 +0,0 @@
"""
Database Operations Module - DISABLED
NOTE: This module has been disabled as PostgreSQL database operations have been
moved to a microservices architecture. All database connections, reads, and writes
are now handled by separate backend services.
The detection worker now communicates results via:
- WebSocket imageDetection messages (primary data flow to backend)
- Redis image storage and pub/sub (temporary storage)
Original functionality: PostgreSQL operations for the detection pipeline.
Status: Commented out - DO NOT ENABLE without updating the architecture
"""
# All PostgreSQL functionality below has been commented out
# import psycopg2
# import psycopg2.extras
from typing import Optional, Dict, Any
import logging
# import uuid
logger = logging.getLogger(__name__)
# DatabaseManager class is disabled - all methods commented out
# class DatabaseManager:
# """
# Manages PostgreSQL connections and operations for the detection pipeline.
# Handles database operations and schema management.
# """
#
# def __init__(self, config: Dict[str, Any]):
# """
# Initialize database manager with configuration.
#
# Args:
# config: Database configuration dictionary
# """
# self.config = config
# self.connection: Optional[psycopg2.extensions.connection] = None
#
# def connect(self) -> bool:
# """
# Connect to PostgreSQL database.
#
# Returns:
# True if successful, False otherwise
# """
# try:
# self.connection = psycopg2.connect(
# host=self.config['host'],
# port=self.config['port'],
# database=self.config['database'],
# user=self.config['username'],
# password=self.config['password']
# )
# logger.info("PostgreSQL connection established successfully")
# return True
# except Exception as e:
# logger.error(f"Failed to connect to PostgreSQL: {e}")
# return False
#
# def disconnect(self):
# """Disconnect from PostgreSQL database."""
# if self.connection:
# self.connection.close()
# self.connection = None
# logger.info("PostgreSQL connection closed")
#
# def is_connected(self) -> bool:
# """
# Check if database connection is active.
#
# Returns:
# True if connected, False otherwise
# """
# try:
# if self.connection and not self.connection.closed:
# cur = self.connection.cursor()
# cur.execute("SELECT 1")
# cur.fetchone()
# cur.close()
# return True
# except:
# pass
# return False
#
# def update_car_info(self, session_id: str, brand: str, model: str, body_type: str) -> bool:
# """
# Update car information in the database.
#
# Args:
# session_id: Session identifier
# brand: Car brand
# model: Car model
# body_type: Car body type
#
# Returns:
# True if successful, False otherwise
# """
# if not self.is_connected():
# if not self.connect():
# return False
#
# try:
# cur = self.connection.cursor()
# query = """
# INSERT INTO car_frontal_info (session_id, car_brand, car_model, car_body_type, updated_at)
# VALUES (%s, %s, %s, %s, NOW())
# ON CONFLICT (session_id)
# DO UPDATE SET
# car_brand = EXCLUDED.car_brand,
# car_model = EXCLUDED.car_model,
# car_body_type = EXCLUDED.car_body_type,
# updated_at = NOW()
# """
# cur.execute(query, (session_id, brand, model, body_type))
# self.connection.commit()
# cur.close()
# logger.info(f"Updated car info for session {session_id}: {brand} {model} ({body_type})")
# return True
# except Exception as e:
# logger.error(f"Failed to update car info: {e}")
# if self.connection:
# self.connection.rollback()
# return False
#
# def execute_update(self, table: str, key_field: str, key_value: str, fields: Dict[str, str]) -> bool:
# """
# Execute a dynamic update query on the database.
#
# Args:
# table: Table name
# key_field: Primary key field name
# key_value: Primary key value
# fields: Dictionary of fields to update
#
# Returns:
# True if successful, False otherwise
# """
# if not self.is_connected():
# if not self.connect():
# return False
#
# try:
# cur = self.connection.cursor()
#
# # Build the UPDATE query dynamically
# set_clauses = []
# values = []
#
# for field, value in fields.items():
# if value == "NOW()":
# set_clauses.append(f"{field} = NOW()")
# else:
# set_clauses.append(f"{field} = %s")
# values.append(value)
#
# # Add schema prefix if table doesn't already have it
# full_table_name = table if '.' in table else f"gas_station_1.{table}"
#
# query = f"""
# INSERT INTO {full_table_name} ({key_field}, {', '.join(fields.keys())})
# VALUES (%s, {', '.join(['%s'] * len(fields))})
# ON CONFLICT ({key_field})
# DO UPDATE SET {', '.join(set_clauses)}
# """
#
# # Add key_value to the beginning of values list
# all_values = [key_value] + list(fields.values()) + values
#
# cur.execute(query, all_values)
# self.connection.commit()
# cur.close()
# logger.info(f"Updated {table} for {key_field}={key_value}")
# return True
# except Exception as e:
# logger.error(f"Failed to execute update on {table}: {e}")
# if self.connection:
# self.connection.rollback()
# return False
#
# def create_car_frontal_info_table(self) -> bool:
# """
# Create the car_frontal_info table in gas_station_1 schema if it doesn't exist.
#
# Returns:
# True if successful, False otherwise
# """
# if not self.is_connected():
# if not self.connect():
# return False
#
# try:
# # Since the database already exists, just verify connection
# cur = self.connection.cursor()
#
# # Simple verification that the table exists
# cur.execute("""
# SELECT EXISTS (
# SELECT FROM information_schema.tables
# WHERE table_schema = 'gas_station_1'
# AND table_name = 'car_frontal_info'
# )
# """)
#
# table_exists = cur.fetchone()[0]
# cur.close()
#
# if table_exists:
# logger.info("Verified car_frontal_info table exists")
# return True
# else:
# logger.error("car_frontal_info table does not exist in the database")
# return False
#
# except Exception as e:
# logger.error(f"Failed to create car_frontal_info table: {e}")
# if self.connection:
# self.connection.rollback()
# return False
#
# def insert_initial_detection(self, display_id: str, captured_timestamp: str, session_id: str = None) -> str:
# """
# Insert initial detection record and return the session_id.
#
# Args:
# display_id: Display identifier
# captured_timestamp: Timestamp of the detection
# session_id: Optional session ID, generates one if not provided
#
# Returns:
# Session ID string or None on error
# """
# if not self.is_connected():
# if not self.connect():
# return None
#
# # Generate session_id if not provided
# if not session_id:
# session_id = str(uuid.uuid4())
#
# try:
# # Ensure table exists
# if not self.create_car_frontal_info_table():
# logger.error("Failed to create/verify table before insertion")
# return None
#
# cur = self.connection.cursor()
# insert_query = """
# INSERT INTO gas_station_1.car_frontal_info
# (display_id, captured_timestamp, session_id, license_character, license_type, car_brand, car_model, car_body_type)
# VALUES (%s, %s, %s, NULL, 'No model available', NULL, NULL, NULL)
# ON CONFLICT (session_id) DO NOTHING
# """
#
# cur.execute(insert_query, (display_id, captured_timestamp, session_id))
# self.connection.commit()
# cur.close()
# logger.info(f"Inserted initial detection record with session_id: {session_id}")
# return session_id
#
# except Exception as e:
# logger.error(f"Failed to insert initial detection record: {e}")
# if self.connection:
# self.connection.rollback()
# return None
#
# def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
# """
# Get session information from the database.
#
# Args:
# session_id: Session identifier
#
# Returns:
# Dictionary with session data or None if not found
# """
# if not self.is_connected():
# if not self.connect():
# return None
#
# try:
# cur = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
# query = "SELECT * FROM gas_station_1.car_frontal_info WHERE session_id = %s"
# cur.execute(query, (session_id,))
# result = cur.fetchone()
# cur.close()
#
# if result:
# return dict(result)
# else:
# logger.debug(f"No session info found for session_id: {session_id}")
# return None
#
# except Exception as e:
# logger.error(f"Failed to get session info: {e}")
# return None
#
# def delete_session(self, session_id: str) -> bool:
# """
# Delete session record from the database.
#
# Args:
# session_id: Session identifier
#
# Returns:
# True if successful, False otherwise
# """
# if not self.is_connected():
# if not self.connect():
# return False
#
# try:
# cur = self.connection.cursor()
# query = "DELETE FROM gas_station_1.car_frontal_info WHERE session_id = %s"
# cur.execute(query, (session_id,))
# rows_affected = cur.rowcount
# self.connection.commit()
# cur.close()
#
# if rows_affected > 0:
# logger.info(f"Deleted session record: {session_id}")
# return True
# else:
# logger.warning(f"No session record found to delete: {session_id}")
# return False
#
# except Exception as e:
# logger.error(f"Failed to delete session: {e}")
# if self.connection:
# self.connection.rollback()
# return False
#
# def get_statistics(self) -> Dict[str, Any]:
# """
# Get database statistics.
#
# Returns:
# Dictionary with database statistics
# """
# stats = {
# 'connected': self.is_connected(),
# 'host': self.config.get('host', 'unknown'),
# 'port': self.config.get('port', 'unknown'),
# 'database': self.config.get('database', 'unknown')
# }
#
# if self.is_connected():
# try:
# cur = self.connection.cursor()
#
# # Get table record count
# cur.execute("SELECT COUNT(*) FROM gas_station_1.car_frontal_info")
# stats['total_records'] = cur.fetchone()[0]
#
# # Get recent records count (last hour)
# cur.execute("""
# SELECT COUNT(*) FROM gas_station_1.car_frontal_info
# WHERE created_at > NOW() - INTERVAL '1 hour'
# """)
# stats['recent_records'] = cur.fetchone()[0]
#
# cur.close()
# except Exception as e:
# logger.warning(f"Failed to get database statistics: {e}")
# stats['error'] = str(e)
#
# return stats

View file

@ -1,300 +0,0 @@
"""
License Plate Manager Module.
Handles Redis subscription to license plate results from the LPR service.
"""
import logging
import json
import asyncio
from typing import Dict, Optional, Any, Callable
import redis.asyncio as redis
logger = logging.getLogger(__name__)
class LicensePlateManager:
"""
Manages license plate result subscription from a Redis channel.
Subscribes to the 'license_results' channel for license plate data from the LPR service.
"""
def __init__(self, redis_config: Dict[str, Any]):
"""
Initialize license plate manager with Redis configuration.
Args:
redis_config: Redis configuration dictionary
"""
self.config = redis_config
self.redis_client: Optional[redis.Redis] = None
self.pubsub = None
self.subscription_task = None
self.callback = None
# Connection parameters
self.host = redis_config.get('host', 'localhost')
self.port = redis_config.get('port', 6379)
self.password = redis_config.get('password')
self.db = redis_config.get('db', 0)
# License plate data cache - store recent results by session_id
self.license_plate_cache: Dict[str, Dict[str, Any]] = {}
self.cache_ttl = 300 # 5 minutes TTL for cached results
logger.info(f"LicensePlateManager initialized for {self.host}:{self.port}")
async def initialize(self, callback: Optional[Callable] = None) -> bool:
"""
Initialize Redis connection and start subscription to license_results channel.
Args:
callback: Optional callback function for processing license plate results
Returns:
True if successful, False otherwise
"""
try:
# Create Redis connection
self.redis_client = redis.Redis(
host=self.host,
port=self.port,
password=self.password,
db=self.db,
decode_responses=True
)
# Test connection
await self.redis_client.ping()
logger.info(f"Connected to Redis for license plate subscription")
# Set callback
self.callback = callback
# Start subscription
await self._start_subscription()
return True
except Exception as e:
logger.error(f"Failed to initialize license plate manager: {e}", exc_info=True)
return False
async def _start_subscription(self):
"""Start Redis subscription to license_results channel."""
try:
if not self.redis_client:
logger.error("Redis client not initialized")
return
# Create pubsub and subscribe
self.pubsub = self.redis_client.pubsub()
await self.pubsub.subscribe('license_results')
logger.info("Subscribed to Redis channel: license_results")
# Start listening task
self.subscription_task = asyncio.create_task(self._listen_for_messages())
except Exception as e:
logger.error(f"Error starting license plate subscription: {e}", exc_info=True)
async def _listen_for_messages(self):
"""Listen for messages on the license_results channel."""
listen_generator = None
try:
if not self.pubsub:
return
listen_generator = self.pubsub.listen()
async for message in listen_generator:
if message['type'] == 'message':
try:
# Log the raw message from Redis channel
logger.info(f"[LICENSE PLATE RAW] Received from 'license_results' channel: {message['data']}")
# Parse the license plate result message
data = json.loads(message['data'])
logger.info(f"[LICENSE PLATE PARSED] Parsed JSON data: {data}")
await self._process_license_plate_result(data)
except json.JSONDecodeError as e:
logger.error(f"[LICENSE PLATE ERROR] Invalid JSON in license plate message: {e}")
logger.error(f"[LICENSE PLATE ERROR] Raw message was: {message['data']}")
except Exception as e:
logger.error(f"Error processing license plate message: {e}", exc_info=True)
except asyncio.CancelledError:
logger.info("License plate subscription task cancelled")
# Don't try to close generator here - let it be handled by the context
# The async generator will be properly closed by the cancellation mechanism
raise # Re-raise to maintain proper cancellation semantics
except Exception as e:
logger.error(f"Error in license plate message listener: {e}", exc_info=True)
# Only attempt cleanup if it's not a cancellation
finally:
# Safe cleanup of async generator
if listen_generator is not None:
try:
# Check if we can safely close without conflicting with ongoing operations
if hasattr(listen_generator, 'aclose') and not asyncio.current_task().cancelled():
await listen_generator.aclose()
except (RuntimeError, AttributeError) as e:
# Generator is already closing or in invalid state - safe to ignore
logger.debug(f"Generator cleanup skipped (safe): {e}")
except Exception as e:
logger.debug(f"Generator cleanup error (non-critical): {e}")
async def _process_license_plate_result(self, data: Dict[str, Any]):
"""
Process incoming license plate result from LPR service.
Expected message format (from actual LPR service):
{
"session_id": "511",
"license_character": "ข3184"
}
or
{
"session_id": "508",
"display_id": "test3",
"license_plate_text": "ABC-123",
"confidence": 0.95,
"timestamp": "2025-09-24T21:10:00Z"
}
Args:
data: License plate result data
"""
try:
session_id = data.get('session_id')
if not session_id:
logger.warning("License plate result missing session_id")
return
# Handle different message formats
# Format 1: {"session_id": "511", "license_character": "ข3184"}
# Format 2: {"session_id": "508", "license_plate_text": "ABC-123", "confidence": 0.95, ...}
license_plate_text = data.get('license_plate_text') or data.get('license_character')
confidence = data.get('confidence', 1.0) # Default confidence for LPR service results
display_id = data.get('display_id', '')
timestamp = data.get('timestamp', '')
logger.info(f"[LICENSE PLATE] Received result for session {session_id}: "
f"text='{license_plate_text}', confidence={confidence:.3f}")
# Store in cache
self.license_plate_cache[session_id] = {
'license_plate_text': license_plate_text,
'confidence': confidence,
'display_id': display_id,
'timestamp': timestamp,
'received_at': asyncio.get_event_loop().time()
}
# Call callback if provided
if self.callback:
await self.callback(session_id, {
'license_plate_text': license_plate_text,
'confidence': confidence,
'display_id': display_id,
'timestamp': timestamp
})
except Exception as e:
logger.error(f"Error processing license plate result: {e}", exc_info=True)
def get_license_plate_result(self, session_id: str) -> Optional[Dict[str, Any]]:
"""
Get cached license plate result for a session.
Args:
session_id: Session identifier
Returns:
License plate result dictionary or None if not found
"""
if session_id not in self.license_plate_cache:
return None
result = self.license_plate_cache[session_id]
# Check TTL
current_time = asyncio.get_event_loop().time()
if current_time - result.get('received_at', 0) > self.cache_ttl:
# Expired, remove from cache
del self.license_plate_cache[session_id]
return None
return {
'license_plate_text': result.get('license_plate_text'),
'confidence': result.get('confidence'),
'display_id': result.get('display_id'),
'timestamp': result.get('timestamp')
}
def cleanup_expired_results(self):
"""Remove expired license plate results from cache."""
try:
current_time = asyncio.get_event_loop().time()
expired_sessions = []
for session_id, result in self.license_plate_cache.items():
if current_time - result.get('received_at', 0) > self.cache_ttl:
expired_sessions.append(session_id)
for session_id in expired_sessions:
del self.license_plate_cache[session_id]
logger.debug(f"Removed expired license plate result for session {session_id}")
except Exception as e:
logger.error(f"Error cleaning up expired license plate results: {e}", exc_info=True)
async def close(self):
"""Close Redis connection and cleanup resources."""
try:
# Cancel subscription task first
if self.subscription_task and not self.subscription_task.done():
self.subscription_task.cancel()
try:
await self.subscription_task
except asyncio.CancelledError:
logger.debug("License plate subscription task cancelled successfully")
except Exception as e:
logger.warning(f"Error waiting for subscription task cancellation: {e}")
# Close pubsub connection properly
if self.pubsub:
try:
# First unsubscribe from channels
await self.pubsub.unsubscribe('license_results')
# Then close the pubsub connection
await self.pubsub.aclose()
except Exception as e:
logger.warning(f"Error closing pubsub connection: {e}")
finally:
self.pubsub = None
# Close Redis connection
if self.redis_client:
try:
await self.redis_client.aclose()
except Exception as e:
logger.warning(f"Error closing Redis connection: {e}")
finally:
self.redis_client = None
# Clear cache
self.license_plate_cache.clear()
logger.info("License plate manager closed successfully")
except Exception as e:
logger.error(f"Error closing license plate manager: {e}", exc_info=True)
def get_statistics(self) -> Dict[str, Any]:
"""Get license plate manager statistics."""
return {
'cached_results': len(self.license_plate_cache),
'connected': self.redis_client is not None,
'subscribed': self.pubsub is not None,
'host': self.host,
'port': self.port
}
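For context, a minimal usage sketch of the removed LicensePlateManager, assuming a reachable Redis instance; the callback body, config values, and sleep duration are illustrative only.

# Hedged usage sketch: subscribe to license plate results and print them as they arrive.
import asyncio

async def main():
    async def on_plate(session_id, result):
        # Called for each parsed result from the 'license_results' channel
        print(f"session {session_id}: {result['license_plate_text']} "
              f"(confidence {result['confidence']})")

    manager = LicensePlateManager({"host": "localhost", "port": 6379, "db": 0})
    if await manager.initialize(callback=on_plate):
        # An LPR service would publish e.g. {"session_id": "511",
        # "license_character": "ข3184"} to the 'license_results' channel.
        await asyncio.sleep(60)          # let the subscription run for a while
        print(manager.get_statistics())
        await manager.close()

asyncio.run(main())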

View file

@ -1,478 +0,0 @@
"""
Redis Operations Module.
Handles Redis connections, image storage, and pub/sub messaging.
"""
import logging
import json
import time
from typing import Optional, Dict, Any, Union
import asyncio
import cv2
import numpy as np
import redis.asyncio as redis
from redis.exceptions import ConnectionError, TimeoutError
logger = logging.getLogger(__name__)
class RedisManager:
"""
Manages Redis connections and operations for the detection pipeline.
Handles image storage with region cropping and pub/sub messaging.
"""
def __init__(self, redis_config: Dict[str, Any]):
"""
Initialize Redis manager with configuration.
Args:
redis_config: Redis configuration dictionary
"""
self.config = redis_config
self.redis_client: Optional[redis.Redis] = None
# Connection parameters
self.host = redis_config.get('host', 'localhost')
self.port = redis_config.get('port', 6379)
self.password = redis_config.get('password')
self.db = redis_config.get('db', 0)
self.decode_responses = redis_config.get('decode_responses', True)
# Connection pool settings
self.max_connections = redis_config.get('max_connections', 10)
self.socket_timeout = redis_config.get('socket_timeout', 5)
self.socket_connect_timeout = redis_config.get('socket_connect_timeout', 5)
self.health_check_interval = redis_config.get('health_check_interval', 30)
# Statistics
self.stats = {
'images_stored': 0,
'messages_published': 0,
'connection_errors': 0,
'operations_successful': 0,
'operations_failed': 0
}
logger.info(f"RedisManager initialized for {self.host}:{self.port}")
async def initialize(self) -> bool:
"""
Initialize Redis connection and test connectivity.
Returns:
True if successful, False otherwise
"""
try:
# Validate configuration
if not self._validate_config():
return False
# Create Redis connection
self.redis_client = redis.Redis(
host=self.host,
port=self.port,
password=self.password,
db=self.db,
decode_responses=self.decode_responses,
max_connections=self.max_connections,
socket_timeout=self.socket_timeout,
socket_connect_timeout=self.socket_connect_timeout,
health_check_interval=self.health_check_interval
)
# Test connection
await self.redis_client.ping()
logger.info(f"Successfully connected to Redis at {self.host}:{self.port}")
return True
except ConnectionError as e:
logger.error(f"Failed to connect to Redis: {e}")
self.stats['connection_errors'] += 1
return False
except Exception as e:
logger.error(f"Error initializing Redis connection: {e}", exc_info=True)
self.stats['connection_errors'] += 1
return False
def _validate_config(self) -> bool:
"""
Validate Redis configuration parameters.
Returns:
True if valid, False otherwise
"""
required_fields = ['host', 'port']
for field in required_fields:
if field not in self.config:
logger.error(f"Missing required Redis config field: {field}")
return False
if not isinstance(self.port, int) or self.port <= 0:
logger.error(f"Invalid Redis port: {self.port}")
return False
return True
async def is_connected(self) -> bool:
"""
Check if Redis connection is active.
Returns:
True if connected, False otherwise
"""
try:
if self.redis_client:
await self.redis_client.ping()
return True
except Exception:
pass
return False
async def save_image(self,
key: str,
image: np.ndarray,
expire_seconds: Optional[int] = None,
image_format: str = 'jpeg',
quality: int = 90) -> bool:
"""
Save image to Redis with optional expiration.
Args:
key: Redis key for the image
image: Image array to save
expire_seconds: Optional expiration time in seconds
image_format: Image format ('jpeg' or 'png')
quality: JPEG quality (1-100)
Returns:
True if successful, False otherwise
"""
try:
if not self.redis_client:
logger.error("Redis client not initialized")
self.stats['operations_failed'] += 1
return False
# Encode image
encoded_image = self._encode_image(image, image_format, quality)
if encoded_image is None:
logger.error("Failed to encode image")
self.stats['operations_failed'] += 1
return False
# Save to Redis
if expire_seconds:
await self.redis_client.setex(key, expire_seconds, encoded_image)
logger.debug(f"Saved image to Redis with key: {key} (expires in {expire_seconds}s)")
else:
await self.redis_client.set(key, encoded_image)
logger.debug(f"Saved image to Redis with key: {key}")
self.stats['images_stored'] += 1
self.stats['operations_successful'] += 1
return True
except Exception as e:
logger.error(f"Error saving image to Redis: {e}", exc_info=True)
self.stats['operations_failed'] += 1
return False
async def get_image(self, key: str) -> Optional[np.ndarray]:
"""
Retrieve image from Redis.
Args:
key: Redis key for the image
Returns:
Image array or None if not found
"""
try:
if not self.redis_client:
logger.error("Redis client not initialized")
self.stats['operations_failed'] += 1
return None
# Get image data from Redis
image_data = await self.redis_client.get(key)
if image_data is None:
logger.debug(f"Image not found for key: {key}")
return None
# Decode image
image_array = np.frombuffer(image_data, np.uint8)
image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
if image is not None:
logger.debug(f"Retrieved image from Redis with key: {key}")
self.stats['operations_successful'] += 1
return image
else:
logger.error(f"Failed to decode image for key: {key}")
self.stats['operations_failed'] += 1
return None
except Exception as e:
logger.error(f"Error retrieving image from Redis: {e}", exc_info=True)
self.stats['operations_failed'] += 1
return None
async def delete_image(self, key: str) -> bool:
"""
Delete image from Redis.
Args:
key: Redis key for the image
Returns:
True if successful, False otherwise
"""
try:
if not self.redis_client:
logger.error("Redis client not initialized")
self.stats['operations_failed'] += 1
return False
result = await self.redis_client.delete(key)
if result > 0:
logger.debug(f"Deleted image from Redis with key: {key}")
self.stats['operations_successful'] += 1
return True
else:
logger.debug(f"Image not found for deletion: {key}")
return False
except Exception as e:
logger.error(f"Error deleting image from Redis: {e}", exc_info=True)
self.stats['operations_failed'] += 1
return False
async def publish_message(self, channel: str, message: Union[str, Dict]) -> int:
"""
Publish message to Redis channel.
Args:
channel: Redis channel name
message: Message to publish (string or dict)
Returns:
Number of subscribers that received the message, -1 on error
"""
try:
if not self.redis_client:
logger.error("Redis client not initialized")
self.stats['operations_failed'] += 1
return -1
# Convert dict to JSON string if needed
if isinstance(message, dict):
message_str = json.dumps(message)
else:
message_str = str(message)
# Test connection before publishing
await self.redis_client.ping()
# Publish message
result = await self.redis_client.publish(channel, message_str)
logger.info(f"Published message to Redis channel '{channel}': {message_str}")
logger.info(f"Redis publish result (subscribers count): {result}")
if result == 0:
logger.warning(f"No subscribers listening to channel '{channel}'")
else:
logger.info(f"Message delivered to {result} subscriber(s)")
self.stats['messages_published'] += 1
self.stats['operations_successful'] += 1
return result
except Exception as e:
logger.error(f"Error publishing message to Redis: {e}", exc_info=True)
self.stats['operations_failed'] += 1
return -1
async def subscribe_to_channel(self, channel: str, callback=None):
"""
Subscribe to Redis channel (for future use).
Args:
channel: Redis channel name
callback: Optional callback function for messages
"""
try:
if not self.redis_client:
logger.error("Redis client not initialized")
return
pubsub = self.redis_client.pubsub()
await pubsub.subscribe(channel)
logger.info(f"Subscribed to Redis channel: {channel}")
if callback:
async for message in pubsub.listen():
if message['type'] == 'message':
try:
await callback(message['data'])
except Exception as e:
logger.error(f"Error in message callback: {e}")
except Exception as e:
logger.error(f"Error subscribing to Redis channel: {e}", exc_info=True)
async def set_key(self, key: str, value: Union[str, bytes], expire_seconds: Optional[int] = None) -> bool:
"""
Set a key-value pair in Redis.
Args:
key: Redis key
value: Value to store
expire_seconds: Optional expiration time in seconds
Returns:
True if successful, False otherwise
"""
try:
if not self.redis_client:
logger.error("Redis client not initialized")
self.stats['operations_failed'] += 1
return False
if expire_seconds:
await self.redis_client.setex(key, expire_seconds, value)
else:
await self.redis_client.set(key, value)
logger.debug(f"Set Redis key: {key}")
self.stats['operations_successful'] += 1
return True
except Exception as e:
logger.error(f"Error setting Redis key: {e}", exc_info=True)
self.stats['operations_failed'] += 1
return False
async def get_key(self, key: str) -> Optional[Union[str, bytes]]:
"""
Get value for a Redis key.
Args:
key: Redis key
Returns:
Value or None if not found
"""
try:
if not self.redis_client:
logger.error("Redis client not initialized")
self.stats['operations_failed'] += 1
return None
value = await self.redis_client.get(key)
if value is not None:
logger.debug(f"Retrieved Redis key: {key}")
self.stats['operations_successful'] += 1
return value
except Exception as e:
logger.error(f"Error getting Redis key: {e}", exc_info=True)
self.stats['operations_failed'] += 1
return None
async def delete_key(self, key: str) -> bool:
"""
Delete a Redis key.
Args:
key: Redis key
Returns:
True if successful, False otherwise
"""
try:
if not self.redis_client:
logger.error("Redis client not initialized")
self.stats['operations_failed'] += 1
return False
result = await self.redis_client.delete(key)
if result > 0:
logger.debug(f"Deleted Redis key: {key}")
self.stats['operations_successful'] += 1
return True
else:
logger.debug(f"Redis key not found: {key}")
return False
except Exception as e:
logger.error(f"Error deleting Redis key: {e}", exc_info=True)
self.stats['operations_failed'] += 1
return False
def _encode_image(self, image: np.ndarray, image_format: str, quality: int) -> Optional[bytes]:
"""
Encode image to bytes for Redis storage.
Args:
image: Image array
image_format: Image format ('jpeg' or 'png')
quality: JPEG quality (1-100)
Returns:
Encoded image bytes or None on error
"""
try:
format_lower = image_format.lower()
if format_lower == 'jpeg' or format_lower == 'jpg':
encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
success, buffer = cv2.imencode('.jpg', image, encode_params)
elif format_lower == 'png':
success, buffer = cv2.imencode('.png', image)
else:
logger.warning(f"Unknown image format '{image_format}', using JPEG")
encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
success, buffer = cv2.imencode('.jpg', image, encode_params)
if success:
return buffer.tobytes()
else:
logger.error(f"Failed to encode image as {image_format}")
return None
except Exception as e:
logger.error(f"Error encoding image: {e}", exc_info=True)
return None
def get_statistics(self) -> Dict[str, Any]:
"""
Get Redis manager statistics.
Returns:
Dictionary with statistics
"""
return {
**self.stats,
'connected': self.redis_client is not None,
'host': self.host,
'port': self.port,
'db': self.db
}
def cleanup(self):
"""Cleanup Redis connection."""
if self.redis_client:
# Note: redis.asyncio doesn't have a synchronous close method
# The connection will be closed when the event loop shuts down
self.redis_client = None
logger.info("Redis connection cleaned up")
async def aclose(self):
"""Async cleanup for Redis connection."""
if self.redis_client:
await self.redis_client.aclose()
self.redis_client = None
logger.info("Redis connection closed")

View file

@ -1,26 +0,0 @@
"""
Streaming system for RTSP and HTTP camera feeds.
Provides modular frame readers, buffers, and stream management.
"""
from .readers import HTTPSnapshotReader, FFmpegRTSPReader
from .buffers import FrameBuffer, CacheBuffer, shared_frame_buffer, shared_cache_buffer
from .manager import StreamManager, StreamConfig, SubscriptionInfo, shared_stream_manager, initialize_stream_manager
__all__ = [
# Readers
'HTTPSnapshotReader',
'FFmpegRTSPReader',
# Buffers
'FrameBuffer',
'CacheBuffer',
'shared_frame_buffer',
'shared_cache_buffer',
# Manager
'StreamManager',
'StreamConfig',
'SubscriptionInfo',
'shared_stream_manager',
'initialize_stream_manager'
]

View file

@ -1,295 +0,0 @@
"""
Frame buffering and caching system optimized for different stream formats.
Supports 1280x720 RTSP streams and 2560x1440 HTTP snapshots.
"""
import threading
import time
import cv2
import logging
import numpy as np
from typing import Optional, Dict, Any, Tuple
from collections import defaultdict
logger = logging.getLogger(__name__)
class FrameBuffer:
"""Thread-safe frame buffer for all camera streams."""
def __init__(self, max_age_seconds: int = 5):
self.max_age_seconds = max_age_seconds
self._frames: Dict[str, Dict[str, Any]] = {}
self._lock = threading.RLock()
def put_frame(self, camera_id: str, frame: np.ndarray):
"""Store a frame for the given camera ID."""
with self._lock:
# Validate frame
if not self._validate_frame(frame):
logger.warning(f"Frame validation failed for camera {camera_id}")
return
self._frames[camera_id] = {
'frame': frame.copy(),
'timestamp': time.time(),
'shape': frame.shape,
'dtype': str(frame.dtype),
'size_mb': frame.nbytes / (1024 * 1024)
}
def get_frame(self, camera_id: str) -> Optional[np.ndarray]:
"""Get the latest frame for the given camera ID."""
with self._lock:
if camera_id not in self._frames:
return None
frame_data = self._frames[camera_id]
# Return frame regardless of age - frames persist until replaced
return frame_data['frame'].copy()
def get_frame_info(self, camera_id: str) -> Optional[Dict[str, Any]]:
"""Get frame metadata without copying the frame data."""
with self._lock:
if camera_id not in self._frames:
return None
frame_data = self._frames[camera_id]
age = time.time() - frame_data['timestamp']
# Return frame info regardless of age - frames persist until replaced
return {
'timestamp': frame_data['timestamp'],
'age': age,
'shape': frame_data['shape'],
'dtype': frame_data['dtype'],
'size_mb': frame_data.get('size_mb', 0)
}
def has_frame(self, camera_id: str) -> bool:
"""Check if a valid frame exists for the camera."""
return self.get_frame_info(camera_id) is not None
def clear_camera(self, camera_id: str):
"""Remove all frames for a specific camera."""
with self._lock:
if camera_id in self._frames:
del self._frames[camera_id]
logger.debug(f"Cleared frames for camera {camera_id}")
def clear_all(self):
"""Clear all stored frames."""
with self._lock:
count = len(self._frames)
self._frames.clear()
logger.debug(f"Cleared all frames ({count} cameras)")
def get_camera_list(self) -> list:
"""Get list of cameras with frames - all frames persist until replaced."""
with self._lock:
# Return all cameras that have frames - no age-based filtering
return list(self._frames.keys())
def get_stats(self) -> Dict[str, Any]:
"""Get buffer statistics."""
with self._lock:
current_time = time.time()
stats = {
'total_cameras': len(self._frames),
'recent_cameras': 0,
'stale_cameras': 0,
'total_memory_mb': 0,
'cameras': {}
}
for camera_id, frame_data in self._frames.items():
age = current_time - frame_data['timestamp']
size_mb = frame_data.get('size_mb', 0)
# All frames are valid/available, but categorize by freshness for monitoring
if age <= self.max_age_seconds:
stats['recent_cameras'] += 1
else:
stats['stale_cameras'] += 1
stats['total_memory_mb'] += size_mb
stats['cameras'][camera_id] = {
'age': age,
'recent': age <= self.max_age_seconds, # Recent but all frames available
'shape': frame_data['shape'],
'dtype': frame_data['dtype'],
'size_mb': size_mb
}
return stats
def _validate_frame(self, frame: np.ndarray) -> bool:
"""Validate frame - basic validation for any stream type."""
if frame is None or frame.size == 0:
return False
h, w = frame.shape[:2]
size_mb = frame.nbytes / (1024 * 1024)
# Basic size validation - reject extremely large frames regardless of type
max_size_mb = 50 # Generous limit for any frame type
if size_mb > max_size_mb:
logger.warning(f"Frame too large: {size_mb:.2f}MB (max {max_size_mb}MB) for {w}x{h}")
return False
# Basic dimension validation
if w < 100 or h < 100:
logger.warning(f"Frame too small: {w}x{h}")
return False
return True
class CacheBuffer:
"""Enhanced frame cache with support for cropping."""
def __init__(self, max_age_seconds: int = 10):
self.frame_buffer = FrameBuffer(max_age_seconds)
self._crop_cache: Dict[str, Dict[str, Any]] = {}
self._cache_lock = threading.RLock()
self.jpeg_quality = 95 # High quality for all frames
def put_frame(self, camera_id: str, frame: np.ndarray):
"""Store a frame and clear any associated crop cache."""
self.frame_buffer.put_frame(camera_id, frame)
# Clear crop cache for this camera since we have a new frame
with self._cache_lock:
keys_to_remove = [key for key in self._crop_cache.keys() if key.startswith(f"{camera_id}_")]
for key in keys_to_remove:
del self._crop_cache[key]
def get_frame(self, camera_id: str, crop_coords: Optional[Tuple[int, int, int, int]] = None) -> Optional[np.ndarray]:
"""Get frame with optional cropping."""
if crop_coords is None:
return self.frame_buffer.get_frame(camera_id)
# Check crop cache first
crop_key = f"{camera_id}_{crop_coords}"
with self._cache_lock:
if crop_key in self._crop_cache:
cache_entry = self._crop_cache[crop_key]
age = time.time() - cache_entry['timestamp']
if age <= self.frame_buffer.max_age_seconds:
return cache_entry['cropped_frame'].copy()
else:
del self._crop_cache[crop_key]
# Get original frame and crop it
original_frame = self.frame_buffer.get_frame(camera_id)
if original_frame is None:
return None
try:
x1, y1, x2, y2 = crop_coords
# Ensure coordinates are within frame bounds
h, w = original_frame.shape[:2]
x1 = max(0, min(x1, w))
y1 = max(0, min(y1, h))
x2 = max(x1, min(x2, w))
y2 = max(y1, min(y2, h))
cropped_frame = original_frame[y1:y2, x1:x2]
# Cache the cropped frame
with self._cache_lock:
# Limit cache size to prevent memory issues
if len(self._crop_cache) > 100:
# Remove oldest entries
oldest_keys = sorted(self._crop_cache.keys(),
key=lambda k: self._crop_cache[k]['timestamp'])[:50]
for key in oldest_keys:
del self._crop_cache[key]
self._crop_cache[crop_key] = {
'cropped_frame': cropped_frame.copy(),
'timestamp': time.time(),
'crop_coords': (x1, y1, x2, y2)
}
return cropped_frame
except Exception as e:
logger.error(f"Error cropping frame for camera {camera_id}: {e}")
return original_frame
def get_frame_as_jpeg(self, camera_id: str, crop_coords: Optional[Tuple[int, int, int, int]] = None,
quality: Optional[int] = None) -> Optional[bytes]:
"""Get frame as JPEG bytes."""
frame = self.get_frame(camera_id, crop_coords)
if frame is None:
return None
try:
# Use specified quality or default
if quality is None:
quality = self.jpeg_quality
# Encode as JPEG with specified quality
encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
success, encoded_img = cv2.imencode('.jpg', frame, encode_params)
if success:
jpeg_bytes = encoded_img.tobytes()
logger.debug(f"Encoded JPEG for camera {camera_id}: quality={quality}, size={len(jpeg_bytes)} bytes")
return jpeg_bytes
return None
except Exception as e:
logger.error(f"Error encoding frame as JPEG for camera {camera_id}: {e}")
return None
def has_frame(self, camera_id: str) -> bool:
"""Check if a valid frame exists for the camera."""
return self.frame_buffer.has_frame(camera_id)
def clear_camera(self, camera_id: str):
"""Remove all frames and cache for a specific camera."""
self.frame_buffer.clear_camera(camera_id)
with self._cache_lock:
# Clear crop cache entries for this camera
keys_to_remove = [key for key in self._crop_cache.keys() if key.startswith(f"{camera_id}_")]
for key in keys_to_remove:
del self._crop_cache[key]
def clear_all(self):
"""Clear all stored frames and cache."""
self.frame_buffer.clear_all()
with self._cache_lock:
self._crop_cache.clear()
def get_stats(self) -> Dict[str, Any]:
"""Get comprehensive buffer and cache statistics."""
buffer_stats = self.frame_buffer.get_stats()
with self._cache_lock:
cache_stats = {
'crop_cache_entries': len(self._crop_cache),
'crop_cache_cameras': len(set(key.split('_')[0] for key in self._crop_cache.keys() if '_' in key)),
'crop_cache_memory_mb': sum(
entry['cropped_frame'].nbytes / (1024 * 1024)
for entry in self._crop_cache.values()
)
}
return {
'buffer': buffer_stats,
'cache': cache_stats,
'total_memory_mb': buffer_stats.get('total_memory_mb', 0) + cache_stats.get('crop_cache_memory_mb', 0)
}
# Global shared instances for application use
shared_frame_buffer = FrameBuffer(max_age_seconds=5)
shared_cache_buffer = CacheBuffer(max_age_seconds=10)
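A short sketch of the buffer API above; the camera id, dummy frame, and crop rectangle are made up for illustration.

# Hedged sketch: store a frame, then read it back whole, cropped, and as JPEG.
import numpy as np

frame = np.zeros((720, 1280, 3), dtype=np.uint8)           # dummy 1280x720 frame
shared_cache_buffer.put_frame("cam-001", frame)

full = shared_cache_buffer.get_frame("cam-001")             # latest full frame
crop = shared_cache_buffer.get_frame("cam-001", crop_coords=(100, 100, 400, 300))
jpeg = shared_cache_buffer.get_frame_as_jpeg("cam-001", quality=90)

print(full.shape, crop.shape, len(jpeg),
      shared_cache_buffer.get_stats()["total_memory_mb"])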

View file

@ -1,722 +0,0 @@
"""
Stream coordination and lifecycle management.
Optimized for 1280x720@6fps RTSP and 2560x1440 HTTP snapshots.
"""
import logging
import threading
import time
import queue
import asyncio
from typing import Dict, Set, Optional, List, Any
from dataclasses import dataclass
from collections import defaultdict
from .readers import HTTPSnapshotReader, FFmpegRTSPReader
from .buffers import shared_cache_buffer
from ..tracking.integration import TrackingPipelineIntegration
logger = logging.getLogger(__name__)
@dataclass
class StreamConfig:
"""Configuration for a stream."""
camera_id: str
rtsp_url: Optional[str] = None
snapshot_url: Optional[str] = None
snapshot_interval: int = 5000 # milliseconds
max_retries: int = 3
@dataclass
class SubscriptionInfo:
"""Information about a subscription."""
subscription_id: str
camera_id: str
stream_config: StreamConfig
created_at: float
crop_coords: Optional[tuple] = None
model_id: Optional[str] = None
model_url: Optional[str] = None
tracking_integration: Optional[TrackingPipelineIntegration] = None
class StreamManager:
"""Manages multiple camera streams with shared optimization."""
def __init__(self, max_streams: int = 10):
self.max_streams = max_streams
self._streams: Dict[str, Any] = {} # camera_id -> reader instance
self._subscriptions: Dict[str, SubscriptionInfo] = {} # subscription_id -> info
self._camera_subscribers: Dict[str, Set[str]] = defaultdict(set) # camera_id -> set of subscription_ids
self._lock = threading.RLock()
# Fair tracking queue system - per camera queues
self._tracking_queues: Dict[str, queue.Queue] = {} # camera_id -> queue
self._tracking_workers = []
self._stop_workers = threading.Event()
self._dropped_frame_counts: Dict[str, int] = {} # per-camera drop counts
# Round-robin scheduling state
self._camera_list = [] # Ordered list of active cameras
self._camera_round_robin_index = 0
self._round_robin_lock = threading.Lock()
# Start worker threads for tracking processing
num_workers = min(4, max_streams // 2 + 1) # Scale with streams
for i in range(num_workers):
worker = threading.Thread(
target=self._tracking_worker_loop,
name=f"TrackingWorker-{i}",
daemon=True
)
worker.start()
self._tracking_workers.append(worker)
logger.info(f"Started {num_workers} tracking worker threads")
def _ensure_camera_queue(self, camera_id: str):
"""Ensure a tracking queue exists for the camera."""
if camera_id not in self._tracking_queues:
self._tracking_queues[camera_id] = queue.Queue(maxsize=10) # 10 frames per camera
self._dropped_frame_counts[camera_id] = 0
with self._round_robin_lock:
if camera_id not in self._camera_list:
self._camera_list.append(camera_id)
logger.info(f"Created tracking queue for camera {camera_id}")
else:
logger.debug(f"Camera {camera_id} already has tracking queue")
def _remove_camera_queue(self, camera_id: str):
"""Remove tracking queue for a camera that's no longer active."""
if camera_id in self._tracking_queues:
# Clear any remaining items
while not self._tracking_queues[camera_id].empty():
try:
self._tracking_queues[camera_id].get_nowait()
except queue.Empty:
break
del self._tracking_queues[camera_id]
del self._dropped_frame_counts[camera_id]
with self._round_robin_lock:
if camera_id in self._camera_list:
self._camera_list.remove(camera_id)
# Reset index if needed
if self._camera_round_robin_index >= len(self._camera_list):
self._camera_round_robin_index = 0
logger.info(f"Removed tracking queue for camera {camera_id}")
def add_subscription(self, subscription_id: str, stream_config: StreamConfig,
crop_coords: Optional[tuple] = None,
model_id: Optional[str] = None,
model_url: Optional[str] = None,
tracking_integration: Optional[TrackingPipelineIntegration] = None) -> bool:
"""Add a new subscription. Returns True if successful."""
with self._lock:
if subscription_id in self._subscriptions:
logger.warning(f"Subscription {subscription_id} already exists")
return False
camera_id = stream_config.camera_id
# Create subscription info
subscription_info = SubscriptionInfo(
subscription_id=subscription_id,
camera_id=camera_id,
stream_config=stream_config,
created_at=time.time(),
crop_coords=crop_coords,
model_id=model_id,
model_url=model_url,
tracking_integration=tracking_integration
)
# Pass subscription info to tracking integration for snapshot access
if tracking_integration:
tracking_integration.set_subscription_info(subscription_info)
self._subscriptions[subscription_id] = subscription_info
self._camera_subscribers[camera_id].add(subscription_id)
# Start stream if not already running
if camera_id not in self._streams:
if len(self._streams) >= self.max_streams:
logger.error(f"Maximum streams ({self.max_streams}) reached, cannot add {camera_id}")
self._remove_subscription_internal(subscription_id)
return False
success = self._start_stream(camera_id, stream_config)
if not success:
self._remove_subscription_internal(subscription_id)
return False
else:
# Stream already exists, but ensure queue exists too
logger.info(f"Stream already exists for {camera_id}, ensuring queue exists")
self._ensure_camera_queue(camera_id)
logger.info(f"Added subscription {subscription_id} for camera {camera_id} "
f"({len(self._camera_subscribers[camera_id])} total subscribers)")
return True
def remove_subscription(self, subscription_id: str) -> bool:
"""Remove a subscription. Returns True if found and removed."""
with self._lock:
return self._remove_subscription_internal(subscription_id)
def _remove_subscription_internal(self, subscription_id: str) -> bool:
"""Internal method to remove subscription (assumes lock is held)."""
if subscription_id not in self._subscriptions:
logger.warning(f"Subscription {subscription_id} not found")
return False
subscription_info = self._subscriptions[subscription_id]
camera_id = subscription_info.camera_id
# Remove from tracking
del self._subscriptions[subscription_id]
self._camera_subscribers[camera_id].discard(subscription_id)
# Stop stream if no more subscribers
if not self._camera_subscribers[camera_id]:
self._stop_stream(camera_id)
del self._camera_subscribers[camera_id]
logger.info(f"Removed subscription {subscription_id} for camera {camera_id} "
f"({len(self._camera_subscribers[camera_id])} remaining subscribers)")
return True
def _start_stream(self, camera_id: str, stream_config: StreamConfig) -> bool:
"""Start a stream for the given camera."""
try:
if stream_config.rtsp_url:
# RTSP stream using FFmpeg subprocess with CUDA acceleration
logger.info(f"\033[94m[RTSP] Starting {camera_id}\033[0m")
reader = FFmpegRTSPReader(
camera_id=camera_id,
rtsp_url=stream_config.rtsp_url,
max_retries=stream_config.max_retries
)
reader.set_frame_callback(self._frame_callback)
reader.start()
self._streams[camera_id] = reader
self._ensure_camera_queue(camera_id) # Create tracking queue
logger.info(f"\033[92m[RTSP] {camera_id} connected\033[0m")
elif stream_config.snapshot_url:
# HTTP snapshot stream
logger.info(f"\033[95m[HTTP] Starting {camera_id}\033[0m")
reader = HTTPSnapshotReader(
camera_id=camera_id,
snapshot_url=stream_config.snapshot_url,
interval_ms=stream_config.snapshot_interval,
max_retries=stream_config.max_retries
)
reader.set_frame_callback(self._frame_callback)
reader.start()
self._streams[camera_id] = reader
self._ensure_camera_queue(camera_id) # Create tracking queue
logger.info(f"\033[92m[HTTP] {camera_id} connected\033[0m")
else:
logger.error(f"No valid URL provided for camera {camera_id}")
return False
return True
except Exception as e:
logger.error(f"Error starting stream for camera {camera_id}: {e}")
return False
def _stop_stream(self, camera_id: str):
"""Stop a stream for the given camera."""
if camera_id in self._streams:
try:
self._streams[camera_id].stop()
del self._streams[camera_id]
self._remove_camera_queue(camera_id) # Remove tracking queue
# DON'T clear frames - they should persist until replaced
# shared_cache_buffer.clear_camera(camera_id) # REMOVED - frames should persist
logger.info(f"Stopped stream for camera {camera_id} (frames preserved in buffer)")
except Exception as e:
logger.error(f"Error stopping stream for camera {camera_id}: {e}")
def _frame_callback(self, camera_id: str, frame):
"""Callback for when a new frame is available."""
try:
# Store frame in shared buffer
shared_cache_buffer.put_frame(camera_id, frame)
# Quieter frame callback logging - only log occasionally
if hasattr(self, '_frame_log_count'):
self._frame_log_count += 1
else:
self._frame_log_count = 1
# Log every 100 frames to avoid spam
if self._frame_log_count % 100 == 0:
available_cameras = shared_cache_buffer.frame_buffer.get_camera_list()
logger.info(f"\033[96m[BUFFER] {len(available_cameras)} active cameras: {', '.join(available_cameras)}\033[0m")
# Queue for tracking processing (non-blocking) - route to camera-specific queue
if camera_id in self._tracking_queues:
try:
self._tracking_queues[camera_id].put_nowait({
'frame': frame,
'timestamp': time.time()
})
except queue.Full:
# Drop frame if camera queue is full (maintain real-time)
self._dropped_frame_counts[camera_id] += 1
if self._dropped_frame_counts[camera_id] % 50 == 0:
logger.warning(f"Dropped {self._dropped_frame_counts[camera_id]} frames for camera {camera_id} due to full queue")
except Exception as e:
logger.error(f"Error in frame callback for camera {camera_id}: {e}")
def _process_tracking_for_camera(self, camera_id: str, frame):
"""Process tracking for all subscriptions of a camera."""
try:
with self._lock:
for subscription_id in self._camera_subscribers[camera_id]:
subscription_info = self._subscriptions[subscription_id]
# Skip if no tracking integration
if not subscription_info.tracking_integration:
continue
# Extract display_id from subscription_id
display_id = subscription_id.split(';')[0] if ';' in subscription_id else subscription_id
# Process frame through tracking asynchronously
# Note: This is synchronous for now, can be made async in future
try:
# Create a simple asyncio event loop for this frame
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(
subscription_info.tracking_integration.process_frame(
frame, display_id, subscription_id
)
)
# Log tracking results
if result:
tracked_count = len(result.get('tracked_vehicles', []))
validated_vehicle = result.get('validated_vehicle')
pipeline_result = result.get('pipeline_result')
if tracked_count > 0:
logger.info(f"[Tracking] {camera_id}: {tracked_count} vehicles tracked")
if validated_vehicle:
logger.info(f"[Tracking] {camera_id}: Vehicle {validated_vehicle['track_id']} "
f"validated as {validated_vehicle['state']} "
f"(confidence: {validated_vehicle['confidence']:.2f})")
if pipeline_result:
logger.info(f"[Pipeline] {camera_id}: {pipeline_result.get('status', 'unknown')} - "
f"{pipeline_result.get('message', 'no message')}")
finally:
loop.close()
except Exception as track_e:
logger.error(f"Error in tracking for {subscription_id}: {track_e}")
except Exception as e:
logger.error(f"Error processing tracking for camera {camera_id}: {e}")
def _tracking_worker_loop(self):
"""Worker thread loop for round-robin processing of camera queues."""
logger.info(f"Tracking worker {threading.current_thread().name} started")
consecutive_empty = 0
max_consecutive_empty = 10 # Sleep if all cameras empty this many times
while not self._stop_workers.is_set():
try:
# Get next camera in round-robin fashion
camera_id, item = self._get_next_camera_item()
if camera_id is None:
# No cameras have items, sleep briefly
consecutive_empty += 1
if consecutive_empty >= max_consecutive_empty:
time.sleep(0.1) # Sleep 100ms if nothing to process
consecutive_empty = 0
continue
consecutive_empty = 0 # Reset counter when we find work
frame = item['frame']
timestamp = item['timestamp']
# Check if frame is too old (drop if > 1 second old)
age = time.time() - timestamp
if age > 1.0:
logger.debug(f"Dropping old frame for {camera_id} (age: {age:.2f}s)")
continue
# Process tracking for this camera's frame
self._process_tracking_for_camera_sync(camera_id, frame)
except Exception as e:
logger.error(f"Error in tracking worker: {e}", exc_info=True)
logger.info(f"Tracking worker {threading.current_thread().name} stopped")
def _get_next_camera_item(self):
"""Get next item from camera queues using round-robin scheduling."""
with self._round_robin_lock:
# Get current list of cameras from actual tracking queues (central state)
camera_list = list(self._tracking_queues.keys())
if not camera_list:
return None, None
attempts = 0
max_attempts = len(camera_list)
while attempts < max_attempts:
# Get current camera using round-robin index
if self._camera_round_robin_index >= len(camera_list):
self._camera_round_robin_index = 0
camera_id = camera_list[self._camera_round_robin_index]
# Move to next camera for next call
self._camera_round_robin_index = (self._camera_round_robin_index + 1) % len(camera_list)
# Try to get item from this camera's queue
try:
item = self._tracking_queues[camera_id].get_nowait()
return camera_id, item
except queue.Empty:
pass # Try next camera
attempts += 1
return None, None # All cameras empty
def _process_tracking_for_camera_sync(self, camera_id: str, frame):
"""Synchronous version of tracking processing for worker threads."""
try:
with self._lock:
subscription_ids = list(self._camera_subscribers.get(camera_id, []))
for subscription_id in subscription_ids:
subscription_info = self._subscriptions.get(subscription_id)
if not subscription_info:
logger.warning(f"No subscription info found for {subscription_id}")
continue
if not subscription_info.tracking_integration:
logger.debug(f"No tracking integration for {subscription_id} (camera {camera_id}), skipping inference")
continue
display_id = subscription_id.split(';')[0] if ';' in subscription_id else subscription_id
try:
# Run async tracking in thread's event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(
subscription_info.tracking_integration.process_frame(
frame, display_id, subscription_id
)
)
# Log tracking results
if result:
tracked_count = len(result.get('tracked_vehicles', []))
validated_vehicle = result.get('validated_vehicle')
pipeline_result = result.get('pipeline_result')
if tracked_count > 0:
logger.info(f"[Tracking] {camera_id}: {tracked_count} vehicles tracked")
if validated_vehicle:
logger.info(f"[Tracking] {camera_id}: Vehicle {validated_vehicle['track_id']} "
f"validated as {validated_vehicle['state']} "
f"(confidence: {validated_vehicle['confidence']:.2f})")
if pipeline_result:
logger.info(f"[Pipeline] {camera_id}: {pipeline_result.get('status', 'unknown')} - "
f"{pipeline_result.get('message', 'no message')}")
finally:
loop.close()
except Exception as track_e:
logger.error(f"Error in tracking for {subscription_id}: {track_e}")
except Exception as e:
logger.error(f"Error processing tracking for camera {camera_id}: {e}")
def get_frame(self, camera_id: str, crop_coords: Optional[tuple] = None):
"""Get the latest frame for a camera with optional cropping."""
return shared_cache_buffer.get_frame(camera_id, crop_coords)
def get_frame_as_jpeg(self, camera_id: str, crop_coords: Optional[tuple] = None,
quality: int = 100) -> Optional[bytes]:
"""Get frame as JPEG bytes for HTTP responses with highest quality by default."""
return shared_cache_buffer.get_frame_as_jpeg(camera_id, crop_coords, quality)
def has_frame(self, camera_id: str) -> bool:
"""Check if a frame is available for the camera."""
return shared_cache_buffer.has_frame(camera_id)
def get_subscription_info(self, subscription_id: str) -> Optional[SubscriptionInfo]:
"""Get information about a subscription."""
with self._lock:
return self._subscriptions.get(subscription_id)
def get_camera_subscribers(self, camera_id: str) -> Set[str]:
"""Get all subscription IDs for a camera."""
with self._lock:
return self._camera_subscribers[camera_id].copy()
def get_active_cameras(self) -> List[str]:
"""Get list of cameras with active streams."""
with self._lock:
return list(self._streams.keys())
def get_all_subscriptions(self) -> List[SubscriptionInfo]:
"""Get all active subscriptions."""
with self._lock:
return list(self._subscriptions.values())
def reconcile_subscriptions(self, target_subscriptions: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Reconcile current subscriptions with target list.
Returns summary of changes made.
"""
with self._lock:
current_subscription_ids = set(self._subscriptions.keys())
target_subscription_ids = {sub['subscriptionIdentifier'] for sub in target_subscriptions}
# Find subscriptions to remove and add
to_remove = current_subscription_ids - target_subscription_ids
to_add = target_subscription_ids - current_subscription_ids
# Remove old subscriptions
removed_count = 0
for subscription_id in to_remove:
if self._remove_subscription_internal(subscription_id):
removed_count += 1
# Add new subscriptions
added_count = 0
failed_count = 0
for target_sub in target_subscriptions:
subscription_id = target_sub['subscriptionIdentifier']
if subscription_id in to_add:
success = self._add_subscription_from_payload(subscription_id, target_sub)
if success:
added_count += 1
else:
failed_count += 1
result = {
'removed': removed_count,
'added': added_count,
'failed': failed_count,
'total_active': len(self._subscriptions),
'active_streams': len(self._streams)
}
logger.info(f"Subscription reconciliation: {result}")
return result
def _add_subscription_from_payload(self, subscription_id: str, payload: Dict[str, Any]) -> bool:
"""Add subscription from WebSocket payload format."""
try:
# Extract camera ID from subscription identifier
# Format: "display-001;cam-001" -> camera_id = "cam-001"
camera_id = subscription_id.split(';')[-1]
# Extract crop coordinates if present
crop_coords = None
if all(key in payload for key in ['cropX1', 'cropY1', 'cropX2', 'cropY2']):
crop_coords = (
payload['cropX1'],
payload['cropY1'],
payload['cropX2'],
payload['cropY2']
)
# Create stream configuration
stream_config = StreamConfig(
camera_id=camera_id,
rtsp_url=payload.get('rtspUrl'),
snapshot_url=payload.get('snapshotUrl'),
snapshot_interval=payload.get('snapshotInterval', 5000),
max_retries=3,
)
return self.add_subscription(
subscription_id,
stream_config,
crop_coords,
model_id=payload.get('modelId'),
model_url=payload.get('modelUrl')
)
except Exception as e:
logger.error(f"Error adding subscription from payload {subscription_id}: {e}")
return False
def stop_all(self):
"""Stop all streams and clear all subscriptions."""
# Signal workers to stop
self._stop_workers.set()
# Clear all camera queues
for camera_id, camera_queue in list(self._tracking_queues.items()):
while not camera_queue.empty():
try:
camera_queue.get_nowait()
except queue.Empty:
break
# Wait for workers to finish
for worker in self._tracking_workers:
worker.join(timeout=2.0)
# Clear queue management structures
self._tracking_queues.clear()
self._dropped_frame_counts.clear()
with self._round_robin_lock:
self._camera_list.clear()
self._camera_round_robin_index = 0
logger.info("Stopped all tracking worker threads")
with self._lock:
# Stop all streams
for camera_id in list(self._streams.keys()):
self._stop_stream(camera_id)
# Clear all tracking
self._subscriptions.clear()
self._camera_subscribers.clear()
shared_cache_buffer.clear_all()
logger.info("Stopped all streams and cleared all subscriptions")
def set_session_id(self, display_id: str, session_id: str):
"""Set session ID for tracking integration."""
# Ensure session_id is always a string for consistent type handling
session_id = str(session_id) if session_id is not None else None
with self._lock:
for subscription_info in self._subscriptions.values():
# Check if this subscription matches the display_id
subscription_display_id = subscription_info.subscription_id.split(';')[0]
if subscription_display_id == display_id and subscription_info.tracking_integration:
# Pass the full subscription_id (displayId;cameraId) to the tracking integration
subscription_info.tracking_integration.set_session_id(
display_id,
session_id,
subscription_id=subscription_info.subscription_id
)
logger.debug(f"Set session {session_id} for display {display_id} with subscription {subscription_info.subscription_id}")
def clear_session_id(self, session_id: str):
"""Clear session ID from the specific tracking integration handling this session."""
with self._lock:
# Find the subscription that's handling this session
session_subscription = None
for subscription_info in self._subscriptions.values():
if subscription_info.tracking_integration:
# Check if this integration is handling the given session_id
integration = subscription_info.tracking_integration
if session_id in integration.session_vehicles:
session_subscription = subscription_info
break
if session_subscription and session_subscription.tracking_integration:
session_subscription.tracking_integration.clear_session_id(session_id)
logger.debug(f"Cleared session {session_id} from subscription {session_subscription.subscription_id}")
else:
logger.warning(f"No tracking integration found for session {session_id}, broadcasting to all subscriptions")
# Fallback: broadcast to all (original behavior)
for subscription_info in self._subscriptions.values():
if subscription_info.tracking_integration:
subscription_info.tracking_integration.clear_session_id(session_id)
def set_progression_stage(self, session_id: str, stage: str):
"""Set progression stage for the specific tracking integration handling this session."""
with self._lock:
# Find the subscription that's handling this session
session_subscription = None
for subscription_info in self._subscriptions.values():
if subscription_info.tracking_integration:
# Check if this integration is handling the given session_id
# We need to check the integration's active sessions
integration = subscription_info.tracking_integration
if session_id in integration.session_vehicles:
session_subscription = subscription_info
break
if session_subscription and session_subscription.tracking_integration:
session_subscription.tracking_integration.set_progression_stage(session_id, stage)
logger.debug(f"Set progression stage for session {session_id}: {stage} on subscription {session_subscription.subscription_id}")
else:
logger.warning(f"No tracking integration found for session {session_id}, broadcasting to all subscriptions")
# Fallback: broadcast to all (original behavior)
for subscription_info in self._subscriptions.values():
if subscription_info.tracking_integration:
subscription_info.tracking_integration.set_progression_stage(session_id, stage)
def get_tracking_stats(self) -> Dict[str, Any]:
"""Get tracking statistics from all subscriptions."""
stats = {}
with self._lock:
for subscription_id, subscription_info in self._subscriptions.items():
if subscription_info.tracking_integration:
stats[subscription_id] = subscription_info.tracking_integration.get_statistics()
return stats
def get_stats(self) -> Dict[str, Any]:
"""Get comprehensive streaming statistics."""
with self._lock:
buffer_stats = shared_cache_buffer.get_stats()
tracking_stats = self.get_tracking_stats()
return {
'active_subscriptions': len(self._subscriptions),
'active_streams': len(self._streams),
'cameras_with_subscribers': len(self._camera_subscribers),
'max_streams': self.max_streams,
'subscriptions_by_camera': {
camera_id: len(subscribers)
for camera_id, subscribers in self._camera_subscribers.items()
},
'buffer_stats': buffer_stats,
'tracking_stats': tracking_stats,
'memory_usage_mb': buffer_stats.get('total_memory_mb', 0)
}
# Global shared instance for application use
# Default initialization, will be updated with config value in app.py
shared_stream_manager = StreamManager(max_streams=20)
def initialize_stream_manager(max_streams: int = 10):
"""Re-initialize the global stream manager with config value."""
global shared_stream_manager
# Release old manager if exists
if shared_stream_manager:
try:
# Stop all existing streams gracefully
shared_stream_manager.stop_all()
except Exception as e:
logger.warning(f"Error stopping previous stream manager: {e}")
shared_stream_manager = StreamManager(max_streams=max_streams)
return shared_stream_manager
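A minimal sketch of driving the removed StreamManager through reconcile_subscriptions(); the URLs, identifiers, and crop values are placeholders shaped like the payload handled by _add_subscription_from_payload above.

# Hedged sketch: reconcile one target subscription, then shut everything down.
manager = initialize_stream_manager(max_streams=10)

target = [{
    "subscriptionIdentifier": "display-001;cam-001",
    "rtspUrl": "rtsp://user:pass@10.0.0.10:554/stream1",   # placeholder camera
    "snapshotInterval": 5000,
    "cropX1": 100, "cropY1": 100, "cropX2": 1180, "cropY2": 620,
    "modelId": "frontal-detection-v1",
    "modelUrl": "https://example.com/models/frontal-detection-v1.mpta",
}]

print(manager.reconcile_subscriptions(target))   # e.g. {'removed': 0, 'added': 1, ...}
print(manager.get_stats()["active_streams"])

manager.stop_all()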

View file

@ -1,18 +0,0 @@
"""
Stream readers for RTSP and HTTP camera feeds.
"""
from .base import VideoReader
from .ffmpeg_rtsp import FFmpegRTSPReader
from .http_snapshot import HTTPSnapshotReader
from .utils import log_success, log_warning, log_error, log_info, Colors
__all__ = [
'VideoReader',
'FFmpegRTSPReader',
'HTTPSnapshotReader',
'log_success',
'log_warning',
'log_error',
'log_info',
'Colors'
]

View file

@ -1,65 +0,0 @@
"""
Abstract base class for video stream readers.
"""
from abc import ABC, abstractmethod
from typing import Optional, Callable
import numpy as np
class VideoReader(ABC):
"""Abstract base class for video stream readers."""
def __init__(self, camera_id: str, source_url: str, max_retries: int = 3):
"""
Initialize the video reader.
Args:
camera_id: Unique identifier for the camera
source_url: URL or path to the video source
max_retries: Maximum number of retry attempts
"""
self.camera_id = camera_id
self.source_url = source_url
self.max_retries = max_retries
self.frame_callback: Optional[Callable[[str, np.ndarray], None]] = None
@abstractmethod
def start(self) -> None:
"""Start the video reader."""
pass
@abstractmethod
def stop(self) -> None:
"""Stop the video reader."""
pass
@abstractmethod
def set_frame_callback(self, callback: Callable[[str, np.ndarray], None]) -> None:
"""
Set callback function to handle captured frames.
Args:
callback: Function that takes (camera_id, frame) as arguments
"""
pass
@property
@abstractmethod
def is_running(self) -> bool:
"""Check if the reader is currently running."""
pass
@property
@abstractmethod
def reader_type(self) -> str:
"""Get the type of reader (e.g., 'rtsp', 'http_snapshot')."""
pass
def __enter__(self):
"""Context manager entry."""
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.stop()
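To illustrate the abstract interface, a hedged sketch of a minimal VideoReader subclass that replays a local video file with OpenCV; the class is invented for this example and is not part of the removed package.

# Hedged sketch: a file-replay reader implementing the VideoReader contract.
import threading
from typing import Callable, Optional

import cv2
import numpy as np

class FileReplayReader(VideoReader):
    """Replays frames from a local video file through the VideoReader API."""

    def __init__(self, camera_id: str, source_url: str, max_retries: int = 3):
        super().__init__(camera_id, source_url, max_retries)
        self._thread: Optional[threading.Thread] = None
        self._stop = threading.Event()

    def set_frame_callback(self, callback: Callable[[str, np.ndarray], None]) -> None:
        self.frame_callback = callback

    def start(self) -> None:
        self._stop.clear()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        self._stop.set()
        if self._thread:
            self._thread.join(timeout=5.0)

    def _loop(self) -> None:
        cap = cv2.VideoCapture(self.source_url)
        while not self._stop.is_set():
            ok, frame = cap.read()
            if not ok:
                break                       # end of file or read error
            if self.frame_callback:
                self.frame_callback(self.camera_id, frame)
        cap.release()

    @property
    def is_running(self) -> bool:
        return self._thread is not None and self._thread.is_alive()

    @property
    def reader_type(self) -> str:
        return "file_replay"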

View file

@ -1,436 +0,0 @@
"""
FFmpeg RTSP stream reader using subprocess piping frames directly to buffer.
Enhanced with comprehensive health monitoring and automatic recovery.
"""
import cv2
import time
import threading
import numpy as np
import subprocess
import struct
from typing import Optional, Callable, Dict, Any
from .base import VideoReader
from .utils import log_success, log_warning, log_error, log_info
from ...monitoring.stream_health import stream_health_tracker
from ...monitoring.thread_health import thread_health_monitor
from ...monitoring.recovery import recovery_manager, RecoveryAction
class FFmpegRTSPReader(VideoReader):
"""RTSP stream reader using subprocess FFmpeg piping frames directly to buffer."""
def __init__(self, camera_id: str, rtsp_url: str, max_retries: int = 3):
super().__init__(camera_id, rtsp_url, max_retries)
self.rtsp_url = rtsp_url
self.process = None
self.stop_event = threading.Event()
self.thread = None
self.stderr_thread = None
# Expected stream specs (for reference, actual dimensions read from PPM header)
self.width = 1280
self.height = 720
# Watchdog timers for stream reliability
self.process_start_time = None
self.last_frame_time = None
self.is_restart = False # Track if this is a restart (shorter timeout)
self.first_start_timeout = 30.0 # 30s timeout on first start
self.restart_timeout = 15.0 # 15s timeout after restart
# Health monitoring setup
self.last_heartbeat = time.time()
self.consecutive_errors = 0
self.ffmpeg_restart_count = 0
# Register recovery handlers
recovery_manager.register_recovery_handler(
RecoveryAction.RESTART_STREAM,
self._handle_restart_recovery
)
recovery_manager.register_recovery_handler(
RecoveryAction.RECONNECT,
self._handle_reconnect_recovery
)
@property
def is_running(self) -> bool:
"""Check if the reader is currently running."""
return self.thread is not None and self.thread.is_alive()
@property
def reader_type(self) -> str:
"""Get the type of reader."""
return "rtsp_ffmpeg"
def set_frame_callback(self, callback: Callable[[str, np.ndarray], None]):
"""Set callback function to handle captured frames."""
self.frame_callback = callback
def start(self):
"""Start the FFmpeg subprocess reader."""
if self.thread and self.thread.is_alive():
log_warning(self.camera_id, "FFmpeg reader already running")
return
self.stop_event.clear()
self.thread = threading.Thread(target=self._read_frames, daemon=True)
self.thread.start()
# Register with health monitoring
stream_health_tracker.register_stream(self.camera_id, "rtsp_ffmpeg", self.rtsp_url)
thread_health_monitor.register_thread(self.thread, self._heartbeat_callback)
log_success(self.camera_id, "Stream started with health monitoring")
def stop(self):
"""Stop the FFmpeg subprocess reader."""
self.stop_event.set()
# Unregister from health monitoring
if self.thread:
thread_health_monitor.unregister_thread(self.thread.ident)
if self.process:
self.process.terminate()
try:
self.process.wait(timeout=5)
except subprocess.TimeoutExpired:
self.process.kill()
if self.thread:
self.thread.join(timeout=5.0)
if self.stderr_thread:
self.stderr_thread.join(timeout=2.0)
stream_health_tracker.unregister_stream(self.camera_id)
log_info(self.camera_id, "Stream stopped")
def _start_ffmpeg_process(self):
"""Start FFmpeg subprocess outputting BMP frames to stdout pipe."""
cmd = [
'ffmpeg',
# DO NOT REMOVE
'-hwaccel', 'cuda',
'-hwaccel_device', '0',
# Real-time input flags
'-fflags', 'nobuffer+genpts',
'-flags', 'low_delay',
'-max_delay', '0', # No reordering delay
# RTSP configuration
'-rtsp_transport', 'tcp',
'-i', self.rtsp_url,
# Output configuration (keeping BMP)
'-f', 'image2pipe', # Output images to pipe
'-vcodec', 'bmp', # BMP format with header containing dimensions
'-vsync', 'passthrough', # Pass frames as-is
# Use native stream resolution and framerate
'-an', # No audio
'-' # Output to stdout
]
try:
# Start FFmpeg with stdout pipe to read frames directly
self.process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE, # Capture stdout for frame data
stderr=subprocess.PIPE, # Capture stderr for error logging
bufsize=0 # Unbuffered for real-time processing
)
# Start stderr reading thread
if self.stderr_thread and self.stderr_thread.is_alive():
# Stop previous stderr thread
try:
self.stderr_thread.join(timeout=1.0)
except:
pass
self.stderr_thread = threading.Thread(target=self._read_stderr, daemon=True)
self.stderr_thread.start()
# Set process start time for watchdog
self.process_start_time = time.time()
self.last_frame_time = None # Reset frame time
# After successful restart, next timeout will be back to 30s
if self.is_restart:
log_info(self.camera_id, f"FFmpeg restarted successfully, next timeout: {self.first_start_timeout}s")
self.is_restart = False
return True
except Exception as e:
log_error(self.camera_id, f"FFmpeg startup failed: {e}")
return False
def _read_bmp_frame(self, pipe):
"""Read BMP frame from pipe - BMP header contains dimensions."""
try:
# Read BMP header (14 bytes file header + 40 bytes info header = 54 bytes minimum)
header_data = b''
bytes_to_read = 54
while len(header_data) < bytes_to_read:
chunk = pipe.read(bytes_to_read - len(header_data))
if not chunk:
return None # Silent end of stream
header_data += chunk
# Parse BMP header
if header_data[:2] != b'BM':
return None # Invalid format, skip frame silently
# Extract file size from header (bytes 2-5)
file_size = struct.unpack('<L', header_data[2:6])[0]
# Extract width and height from info header (bytes 18-21 and 22-25)
width = struct.unpack('<L', header_data[18:22])[0]
height = struct.unpack('<L', header_data[22:26])[0]
# Read remaining file data
remaining_size = file_size - 54
remaining_data = b''
while len(remaining_data) < remaining_size:
chunk = pipe.read(remaining_size - len(remaining_data))
if not chunk:
return None # Stream ended silently
remaining_data += chunk
# Complete BMP data
bmp_data = header_data + remaining_data
# Use OpenCV to decode BMP directly from memory
frame_array = np.frombuffer(bmp_data, dtype=np.uint8)
frame = cv2.imdecode(frame_array, cv2.IMREAD_COLOR)
if frame is None:
return None # Decode failed silently
return frame
except Exception:
return None # Error reading frame silently
def _read_stderr(self):
"""Read and log FFmpeg stderr output in background thread."""
if not self.process or not self.process.stderr:
return
try:
while self.process and self.process.poll() is None:
try:
line = self.process.stderr.readline()
if line:
error_msg = line.decode('utf-8', errors='ignore').strip()
if error_msg and not self.stop_event.is_set():
# Filter out common noise but log actual errors
if any(keyword in error_msg.lower() for keyword in ['error', 'failed', 'cannot', 'invalid']):
log_error(self.camera_id, f"FFmpeg: {error_msg}")
elif 'warning' in error_msg.lower():
log_warning(self.camera_id, f"FFmpeg: {error_msg}")
except Exception:
break
except Exception:
pass
def _check_watchdog_timeout(self) -> bool:
"""Check if watchdog timeout has been exceeded."""
if not self.process_start_time:
return False
current_time = time.time()
time_since_start = current_time - self.process_start_time
# Determine timeout based on whether this is a restart
timeout = self.restart_timeout if self.is_restart else self.first_start_timeout
# If no frames received yet, check against process start time
if not self.last_frame_time:
if time_since_start > timeout:
log_warning(self.camera_id, f"Watchdog timeout: No frames for {time_since_start:.1f}s (limit: {timeout}s)")
return True
else:
# Check time since last frame
time_since_frame = current_time - self.last_frame_time
if time_since_frame > timeout:
log_warning(self.camera_id, f"Watchdog timeout: No frames for {time_since_frame:.1f}s (limit: {timeout}s)")
return True
return False
def _restart_ffmpeg_process(self):
"""Restart FFmpeg process due to watchdog timeout."""
log_warning(self.camera_id, "Watchdog triggered FFmpeg restart")
# Terminate current process
if self.process:
try:
self.process.terminate()
self.process.wait(timeout=3)
except subprocess.TimeoutExpired:
self.process.kill()
except Exception:
pass
self.process = None
# Mark as restart for shorter timeout
self.is_restart = True
# Small delay before restart
time.sleep(1.0)
def _read_frames(self):
"""Read frames directly from FFmpeg stdout pipe."""
frame_count = 0
last_log_time = time.time()
while not self.stop_event.is_set():
try:
# Send heartbeat for thread health monitoring
self._send_heartbeat("reading_frames")
# Check watchdog timeout if process is running
if self.process and self.process.poll() is None:
if self._check_watchdog_timeout():
self._restart_ffmpeg_process()
continue
# Start FFmpeg if not running
if not self.process or self.process.poll() is not None:
if self.process and self.process.poll() is not None:
log_warning(self.camera_id, "Stream disconnected, reconnecting...")
stream_health_tracker.report_error(
self.camera_id,
"FFmpeg process disconnected"
)
if not self._start_ffmpeg_process():
self.consecutive_errors += 1
stream_health_tracker.report_error(
self.camera_id,
"Failed to start FFmpeg process"
)
time.sleep(5.0)
continue
# Read frames directly from FFmpeg stdout
try:
if self.process and self.process.stdout:
# Read BMP frame data
frame = self._read_bmp_frame(self.process.stdout)
if frame is None:
continue
# Update watchdog - we got a frame
self.last_frame_time = time.time()
# Reset error counter on successful frame
self.consecutive_errors = 0
# Report successful frame to health monitoring
frame_size = frame.nbytes
stream_health_tracker.report_frame_received(self.camera_id, frame_size)
# Call frame callback
if self.frame_callback:
try:
self.frame_callback(self.camera_id, frame)
except Exception as e:
stream_health_tracker.report_error(
self.camera_id,
f"Frame callback error: {e}"
)
frame_count += 1
# Log progress every 60 seconds (quieter)
current_time = time.time()
if current_time - last_log_time >= 60:
log_success(self.camera_id, f"{frame_count} frames captured ({frame.shape[1]}x{frame.shape[0]})")
last_log_time = current_time
except Exception as e:
# Process might have died, let it restart on next iteration
stream_health_tracker.report_error(
self.camera_id,
f"Frame reading error: {e}"
)
if self.process:
self.process.terminate()
self.process = None
time.sleep(1.0)
except Exception as e:
stream_health_tracker.report_error(
self.camera_id,
f"Main loop error: {e}"
)
time.sleep(1.0)
# Cleanup
if self.process:
self.process.terminate()
# Health monitoring methods
def _send_heartbeat(self, activity: str = "running"):
"""Send heartbeat to thread health monitor."""
self.last_heartbeat = time.time()
thread_health_monitor.heartbeat(activity=activity)
def _heartbeat_callback(self) -> bool:
"""Heartbeat callback for thread responsiveness testing."""
try:
# Check if thread is responsive by checking recent heartbeat
current_time = time.time()
age = current_time - self.last_heartbeat
# Thread is responsive if heartbeat is recent
return age < 30.0 # 30 second responsiveness threshold
except Exception:
return False
def _handle_restart_recovery(self, component: str, details: Dict[str, Any]) -> bool:
"""Handle restart recovery action."""
try:
log_info(self.camera_id, "Restarting FFmpeg RTSP reader for health recovery")
# Stop current instance
self.stop()
# Small delay
time.sleep(2.0)
# Restart
self.start()
# Report successful restart
stream_health_tracker.report_reconnect(self.camera_id, "health_recovery_restart")
self.ffmpeg_restart_count += 1
return True
except Exception as e:
log_error(self.camera_id, f"Failed to restart FFmpeg RTSP reader: {e}")
return False
def _handle_reconnect_recovery(self, component: str, details: Dict[str, Any]) -> bool:
"""Handle reconnect recovery action."""
try:
log_info(self.camera_id, "Reconnecting FFmpeg RTSP reader for health recovery")
# Force restart FFmpeg process
self._restart_ffmpeg_process()
# Reset error counters
self.consecutive_errors = 0
stream_health_tracker.report_reconnect(self.camera_id, "health_recovery_reconnect")
return True
except Exception as e:
log_error(self.camera_id, f"Failed to reconnect FFmpeg RTSP reader: {e}")
return False
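A minimal usage sketch of the FFmpeg-based reader above. The camera ID, RTSP URL, callback body, and import path are illustrative assumptions; only the constructor, set_frame_callback(), start(), and stop() calls mirror the class as defined in this diff.

# from core.streaming.readers import FFmpegRTSPReader  # import path is an assumption
import time
import numpy as np

def on_frame(camera_id: str, frame: np.ndarray):
    # Hypothetical downstream handler - just report the decoded frame size.
    print(f"[{camera_id}] frame {frame.shape[1]}x{frame.shape[0]}")

reader = FFmpegRTSPReader(camera_id="camera-01", rtsp_url="rtsp://user:pass@10.0.0.5/stream1")  # hypothetical URL
reader.set_frame_callback(on_frame)
reader.start()            # spawns the FFmpeg subprocess and the background reader thread
try:
    time.sleep(30)        # let the reader thread deliver frames via the callback
finally:
    reader.stop()         # terminates FFmpeg and joins the reader/stderr threads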


@@ -1,378 +0,0 @@
"""
HTTP snapshot reader optimized for 2560x1440 (2K) high quality images.
Enhanced with comprehensive health monitoring and automatic recovery.
"""
import cv2
import logging
import time
import threading
import requests
import numpy as np
from typing import Optional, Callable, Dict, Any
from .base import VideoReader
from .utils import log_success, log_warning, log_error, log_info
from ...monitoring.stream_health import stream_health_tracker
from ...monitoring.thread_health import thread_health_monitor
from ...monitoring.recovery import recovery_manager, RecoveryAction
logger = logging.getLogger(__name__)
class HTTPSnapshotReader(VideoReader):
"""HTTP snapshot reader optimized for 2560x1440 (2K) high quality images."""
def __init__(self, camera_id: str, snapshot_url: str, interval_ms: int = 5000, max_retries: int = 3):
super().__init__(camera_id, snapshot_url, max_retries)
self.snapshot_url = snapshot_url
self.interval_ms = interval_ms
self.stop_event = threading.Event()
self.thread = None
# Expected snapshot specifications
self.expected_width = 2560
self.expected_height = 1440
self.max_file_size = 10 * 1024 * 1024 # 10MB max for 2K image
# Health monitoring setup
self.last_heartbeat = time.time()
self.consecutive_errors = 0
self.connection_test_interval = 300 # Test connection every 5 minutes
self.last_connection_test = None
# Register recovery handlers
recovery_manager.register_recovery_handler(
RecoveryAction.RESTART_STREAM,
self._handle_restart_recovery
)
recovery_manager.register_recovery_handler(
RecoveryAction.RECONNECT,
self._handle_reconnect_recovery
)
@property
def is_running(self) -> bool:
"""Check if the reader is currently running."""
return self.thread is not None and self.thread.is_alive()
@property
def reader_type(self) -> str:
"""Get the type of reader."""
return "http_snapshot"
def set_frame_callback(self, callback: Callable[[str, np.ndarray], None]):
"""Set callback function to handle captured frames."""
self.frame_callback = callback
def start(self):
"""Start the snapshot reader thread."""
if self.thread and self.thread.is_alive():
logger.warning(f"Snapshot reader for {self.camera_id} already running")
return
self.stop_event.clear()
self.thread = threading.Thread(target=self._read_snapshots, daemon=True)
self.thread.start()
# Register with health monitoring
stream_health_tracker.register_stream(self.camera_id, "http_snapshot", self.snapshot_url)
thread_health_monitor.register_thread(self.thread, self._heartbeat_callback)
logger.info(f"Started snapshot reader for camera {self.camera_id} with health monitoring")
def stop(self):
"""Stop the snapshot reader thread."""
self.stop_event.set()
# Unregister from health monitoring
if self.thread:
thread_health_monitor.unregister_thread(self.thread.ident)
self.thread.join(timeout=5.0)
stream_health_tracker.unregister_stream(self.camera_id)
logger.info(f"Stopped snapshot reader for camera {self.camera_id}")
def _read_snapshots(self):
"""Main snapshot reading loop for high quality 2K images."""
retries = 0
frame_count = 0
last_log_time = time.time()
last_connection_test = time.time()
interval_seconds = self.interval_ms / 1000.0
logger.info(f"Snapshot interval for camera {self.camera_id}: {interval_seconds}s")
while not self.stop_event.is_set():
try:
# Send heartbeat for thread health monitoring
self._send_heartbeat("fetching_snapshot")
start_time = time.time()
frame = self._fetch_snapshot()
if frame is None:
retries += 1
self.consecutive_errors += 1
# Report error to health monitoring
stream_health_tracker.report_error(
self.camera_id,
f"Failed to fetch snapshot (retry {retries}/{self.max_retries})"
)
logger.warning(f"Failed to fetch snapshot for camera {self.camera_id}, retry {retries}/{self.max_retries}")
if self.max_retries != -1 and retries > self.max_retries:
logger.error(f"Max retries reached for snapshot camera {self.camera_id}")
break
time.sleep(min(2.0, interval_seconds))
continue
# Accept any valid image dimensions - don't force specific resolution
if frame.shape[1] <= 0 or frame.shape[0] <= 0:
logger.warning(f"Camera {self.camera_id}: Invalid frame dimensions {frame.shape[1]}x{frame.shape[0]}")
stream_health_tracker.report_error(
self.camera_id,
f"Invalid frame dimensions: {frame.shape[1]}x{frame.shape[0]}"
)
continue
# Reset retry counter on successful fetch
retries = 0
self.consecutive_errors = 0
frame_count += 1
# Report successful frame to health monitoring
frame_size = frame.nbytes
stream_health_tracker.report_frame_received(self.camera_id, frame_size)
# Call frame callback
if self.frame_callback:
try:
self.frame_callback(self.camera_id, frame)
except Exception as e:
logger.error(f"Camera {self.camera_id}: Frame callback error: {e}")
stream_health_tracker.report_error(self.camera_id, f"Frame callback error: {e}")
# Periodic connection health test
current_time = time.time()
if current_time - last_connection_test >= self.connection_test_interval:
self._test_connection_health()
last_connection_test = current_time
# Log progress every 30 seconds
if current_time - last_log_time >= 30:
logger.info(f"Camera {self.camera_id}: {frame_count} snapshots processed")
last_log_time = current_time
# Wait for next interval
elapsed = time.time() - start_time
sleep_time = max(0, interval_seconds - elapsed)
if sleep_time > 0:
self.stop_event.wait(sleep_time)
except Exception as e:
logger.error(f"Error in snapshot loop for camera {self.camera_id}: {e}")
stream_health_tracker.report_error(self.camera_id, f"Snapshot loop error: {e}")
retries += 1
if self.max_retries != -1 and retries > self.max_retries:
break
time.sleep(min(2.0, interval_seconds))
logger.info(f"Snapshot reader thread ended for camera {self.camera_id}")
def _fetch_snapshot(self) -> Optional[np.ndarray]:
"""Fetch a single high quality snapshot from HTTP URL."""
try:
# Parse URL for authentication
from urllib.parse import urlparse
parsed_url = urlparse(self.snapshot_url)
headers = {
'User-Agent': 'Python-Detector-Worker/1.0',
'Accept': 'image/jpeg, image/png, image/*'
}
auth = None
if parsed_url.username and parsed_url.password:
from requests.auth import HTTPBasicAuth, HTTPDigestAuth
auth = HTTPBasicAuth(parsed_url.username, parsed_url.password)
# Reconstruct URL without credentials
clean_url = f"{parsed_url.scheme}://{parsed_url.hostname}"
if parsed_url.port:
clean_url += f":{parsed_url.port}"
clean_url += parsed_url.path
if parsed_url.query:
clean_url += f"?{parsed_url.query}"
# Try Basic Auth first
response = requests.get(clean_url, auth=auth, timeout=15, headers=headers,
stream=True, verify=False)
# If Basic Auth fails, try Digest Auth
if response.status_code == 401:
auth = HTTPDigestAuth(parsed_url.username, parsed_url.password)
response = requests.get(clean_url, auth=auth, timeout=15, headers=headers,
stream=True, verify=False)
else:
response = requests.get(self.snapshot_url, timeout=15, headers=headers,
stream=True, verify=False)
if response.status_code == 200:
# Check content size
content_length = int(response.headers.get('content-length', 0))
if content_length > self.max_file_size:
logger.warning(f"Snapshot too large for camera {self.camera_id}: {content_length} bytes")
return None
# Read content
content = response.content
# Convert to numpy array
image_array = np.frombuffer(content, np.uint8)
# Decode as high quality image
frame = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
if frame is None:
logger.error(f"Failed to decode snapshot for camera {self.camera_id}")
return None
logger.debug(f"Fetched snapshot for camera {self.camera_id}: {frame.shape[1]}x{frame.shape[0]}")
return frame
else:
logger.warning(f"HTTP {response.status_code} from {self.camera_id}")
return None
except requests.RequestException as e:
logger.error(f"Request error fetching snapshot for {self.camera_id}: {e}")
return None
except Exception as e:
logger.error(f"Error decoding snapshot for {self.camera_id}: {e}")
return None
def fetch_single_snapshot(self) -> Optional[np.ndarray]:
"""
Fetch a single high-quality snapshot on demand for pipeline processing.
This method performs a single fetch from the HTTP URL rather than continuous streaming.
Returns:
High quality 2K snapshot frame or None if failed
"""
logger.info(f"[SNAPSHOT] Fetching snapshot for {self.camera_id} from {self.snapshot_url}")
# Try to fetch snapshot with retries
for attempt in range(self.max_retries):
frame = self._fetch_snapshot()
if frame is not None:
logger.info(f"[SNAPSHOT] Successfully fetched {frame.shape[1]}x{frame.shape[0]} snapshot for {self.camera_id}")
return frame
if attempt < self.max_retries - 1:
logger.warning(f"[SNAPSHOT] Attempt {attempt + 1}/{self.max_retries} failed for {self.camera_id}, retrying...")
time.sleep(0.5)
logger.error(f"[SNAPSHOT] Failed to fetch snapshot for {self.camera_id} after {self.max_retries} attempts")
return None
def _resize_maintain_aspect(self, frame: np.ndarray, target_width: int, target_height: int) -> np.ndarray:
"""Resize image while maintaining aspect ratio for high quality."""
h, w = frame.shape[:2]
aspect = w / h
target_aspect = target_width / target_height
if aspect > target_aspect:
# Image is wider
new_width = target_width
new_height = int(target_width / aspect)
else:
# Image is taller
new_height = target_height
new_width = int(target_height * aspect)
# Use INTER_LANCZOS4 for high quality downsampling
resized = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
# Pad to target size if needed
if new_width < target_width or new_height < target_height:
top = (target_height - new_height) // 2
bottom = target_height - new_height - top
left = (target_width - new_width) // 2
right = target_width - new_width - left
resized = cv2.copyMakeBorder(resized, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0])
return resized
# Health monitoring methods
def _send_heartbeat(self, activity: str = "running"):
"""Send heartbeat to thread health monitor."""
self.last_heartbeat = time.time()
thread_health_monitor.heartbeat(activity=activity)
def _heartbeat_callback(self) -> bool:
"""Heartbeat callback for thread responsiveness testing."""
try:
# Check if thread is responsive by checking recent heartbeat
current_time = time.time()
age = current_time - self.last_heartbeat
# Thread is responsive if heartbeat is recent
return age < 30.0 # 30 second responsiveness threshold
except Exception:
return False
def _test_connection_health(self):
"""Test HTTP connection health."""
try:
stream_health_tracker.test_http_connection(self.camera_id, self.snapshot_url)
except Exception as e:
logger.error(f"Error testing connection health for {self.camera_id}: {e}")
def _handle_restart_recovery(self, component: str, details: Dict[str, Any]) -> bool:
"""Handle restart recovery action."""
try:
logger.info(f"Restarting HTTP snapshot reader for {self.camera_id}")
# Stop current instance
self.stop()
# Small delay
time.sleep(2.0)
# Restart
self.start()
# Report successful restart
stream_health_tracker.report_reconnect(self.camera_id, "health_recovery_restart")
return True
except Exception as e:
logger.error(f"Failed to restart HTTP snapshot reader for {self.camera_id}: {e}")
return False
def _handle_reconnect_recovery(self, component: str, details: Dict[str, Any]) -> bool:
"""Handle reconnect recovery action."""
try:
logger.info(f"Reconnecting HTTP snapshot reader for {self.camera_id}")
# Test connection first
success = stream_health_tracker.test_http_connection(self.camera_id, self.snapshot_url)
if success:
# Reset error counters
self.consecutive_errors = 0
stream_health_tracker.report_reconnect(self.camera_id, "health_recovery_reconnect")
return True
else:
logger.warning(f"Connection test failed during recovery for {self.camera_id}")
return False
except Exception as e:
logger.error(f"Failed to reconnect HTTP snapshot reader for {self.camera_id}: {e}")
return False
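A short sketch of fetching one high-quality frame with the snapshot reader above, as the detection phase later in this diff does. The URL, camera ID, and import path are illustrative assumptions.

# from core.streaming.readers import HTTPSnapshotReader  # import path is an assumption
reader = HTTPSnapshotReader(
    camera_id="camera-01",
    snapshot_url="http://admin:pass@10.0.0.5/snapshot.jpg",  # hypothetical camera endpoint
    interval_ms=5000,
)
frame = reader.fetch_single_snapshot()   # retries up to max_retries, returns None on failure
if frame is not None:
    print(f"snapshot {frame.shape[1]}x{frame.shape[0]}")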


@@ -1,38 +0,0 @@
"""
Utility functions for stream readers.
"""
import logging
import os
# Keep OpenCV errors visible but allow FFmpeg stderr logging
os.environ["OPENCV_LOG_LEVEL"] = "ERROR"
logger = logging.getLogger(__name__)
# Color codes for pretty logging
class Colors:
GREEN = '\033[92m'
YELLOW = '\033[93m'
RED = '\033[91m'
BLUE = '\033[94m'
PURPLE = '\033[95m'
CYAN = '\033[96m'
WHITE = '\033[97m'
BOLD = '\033[1m'
END = '\033[0m'
def log_success(camera_id: str, message: str):
"""Log success messages in green"""
logger.info(f"{Colors.GREEN}[{camera_id}] {message}{Colors.END}")
def log_warning(camera_id: str, message: str):
"""Log warnings in yellow"""
logger.warning(f"{Colors.YELLOW}[{camera_id}] {message}{Colors.END}")
def log_error(camera_id: str, message: str):
"""Log errors in red"""
logger.error(f"{Colors.RED}[{camera_id}] {message}{Colors.END}")
def log_info(camera_id: str, message: str):
"""Log info in cyan"""
logger.info(f"{Colors.CYAN}[{camera_id}] {message}{Colors.END}")


@@ -1,14 +0,0 @@
# Tracking module for vehicle tracking and validation
from .tracker import VehicleTracker, TrackedVehicle
from .validator import StableCarValidator, ValidationResult, VehicleState
from .integration import TrackingPipelineIntegration
__all__ = [
'VehicleTracker',
'TrackedVehicle',
'StableCarValidator',
'ValidationResult',
'VehicleState',
'TrackingPipelineIntegration'
]
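A minimal sketch of constructing the exports above the way TrackingPipelineIntegration does later in this diff (VehicleTracker takes a config dict, StableCarValidator takes no arguments); the specific config keys are assumptions.

config = {"trigger_classes": ["frontal"], "min_confidence": 0.6}   # hypothetical keys
tracker = VehicleTracker(config)
validator = StableCarValidator()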


@@ -1,408 +0,0 @@
"""
BoT-SORT Multi-Object Tracker with Camera Isolation
Based on BoT-SORT: Robust Associations Multi-Pedestrian Tracking
"""
import logging
import time
import numpy as np
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from scipy.optimize import linear_sum_assignment
from filterpy.kalman import KalmanFilter
import cv2
logger = logging.getLogger(__name__)
@dataclass
class TrackState:
"""Track state enumeration"""
TENTATIVE = "tentative" # New track, not confirmed yet
CONFIRMED = "confirmed" # Confirmed track
DELETED = "deleted" # Track to be deleted
class Track:
"""
Individual track representation with Kalman filter for motion prediction
"""
def __init__(self, detection, track_id: int, camera_id: str):
"""
Initialize a new track
Args:
detection: Initial detection (bbox, confidence, class)
track_id: Unique track identifier within camera
camera_id: Camera identifier
"""
self.track_id = track_id
self.camera_id = camera_id
self.state = TrackState.TENTATIVE
# Time tracking
self.start_time = time.time()
self.last_update_time = time.time()
# Appearance and motion
self.bbox = detection.bbox # [x1, y1, x2, y2]
self.confidence = detection.confidence
self.class_name = detection.class_name
# Track management
self.hit_streak = 1
self.time_since_update = 0
self.age = 1
# Kalman filter for motion prediction
self.kf = self._create_kalman_filter()
self._update_kalman_filter(detection.bbox)
# Track history
self.history = [detection.bbox]
self.max_history = 10
def _create_kalman_filter(self) -> KalmanFilter:
"""Create Kalman filter for bbox tracking (x, y, w, h, vx, vy, vw, vh)"""
kf = KalmanFilter(dim_x=8, dim_z=4)
# State transition matrix (constant velocity model)
kf.F = np.array([
[1, 0, 0, 0, 1, 0, 0, 0],
[0, 1, 0, 0, 0, 1, 0, 0],
[0, 0, 1, 0, 0, 0, 1, 0],
[0, 0, 0, 1, 0, 0, 0, 1],
[0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1]
])
# Measurement matrix (observe x, y, w, h)
kf.H = np.array([
[1, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0]
])
# Process noise
kf.Q *= 0.01
# Measurement noise
kf.R *= 10
# Initial covariance
kf.P *= 100
return kf
def _update_kalman_filter(self, bbox: List[float]):
"""Update Kalman filter with new bbox"""
# Convert [x1, y1, x2, y2] to [cx, cy, w, h]
x1, y1, x2, y2 = bbox
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
w = x2 - x1
h = y2 - y1
# Properly assign to column vector
self.kf.x[:4, 0] = [cx, cy, w, h]
def predict(self) -> np.ndarray:
"""Predict next position using Kalman filter"""
self.kf.predict()
# Convert back to [x1, y1, x2, y2] format
cx, cy, w, h = self.kf.x[:4, 0] # Extract from column vector
x1 = cx - w/2
y1 = cy - h/2
x2 = cx + w/2
y2 = cy + h/2
return np.array([x1, y1, x2, y2])
def update(self, detection):
"""Update track with new detection"""
self.last_update_time = time.time()
self.time_since_update = 0
self.hit_streak += 1
self.age += 1
# Update track properties
self.bbox = detection.bbox
self.confidence = detection.confidence
# Update Kalman filter
x1, y1, x2, y2 = detection.bbox
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
w = x2 - x1
h = y2 - y1
self.kf.update([cx, cy, w, h])
# Update history
self.history.append(detection.bbox)
if len(self.history) > self.max_history:
self.history.pop(0)
# Update state
if self.state == TrackState.TENTATIVE and self.hit_streak >= 3:
self.state = TrackState.CONFIRMED
def mark_missed(self):
"""Mark track as missed in this frame"""
self.time_since_update += 1
self.age += 1
if self.time_since_update > 5: # Delete after 5 missed frames
self.state = TrackState.DELETED
def is_confirmed(self) -> bool:
"""Check if track is confirmed"""
return self.state == TrackState.CONFIRMED
def is_deleted(self) -> bool:
"""Check if track should be deleted"""
return self.state == TrackState.DELETED
class CameraTracker:
"""
BoT-SORT tracker for a single camera
"""
def __init__(self, camera_id: str, max_disappeared: int = 10):
"""
Initialize camera tracker
Args:
camera_id: Unique camera identifier
max_disappeared: Maximum frames a track can be missed before deletion
"""
self.camera_id = camera_id
self.max_disappeared = max_disappeared
# Track management
self.tracks: Dict[int, Track] = {}
self.next_id = 1
self.frame_count = 0
logger.info(f"Initialized BoT-SORT tracker for camera {camera_id}")
def update(self, detections: List) -> List[Track]:
"""
Update tracker with new detections
Args:
detections: List of Detection objects
Returns:
List of active confirmed tracks
"""
self.frame_count += 1
# Predict all existing tracks
for track in self.tracks.values():
track.predict()
# Associate detections to tracks
matched_tracks, unmatched_detections, unmatched_tracks = self._associate(detections)
# Update matched tracks
for track_id, detection in matched_tracks:
self.tracks[track_id].update(detection)
# Mark unmatched tracks as missed
for track_id in unmatched_tracks:
self.tracks[track_id].mark_missed()
# Create new tracks for unmatched detections
for detection in unmatched_detections:
track = Track(detection, self.next_id, self.camera_id)
self.tracks[self.next_id] = track
self.next_id += 1
# Remove deleted tracks
tracks_to_remove = [tid for tid, track in self.tracks.items() if track.is_deleted()]
for tid in tracks_to_remove:
del self.tracks[tid]
# Return confirmed tracks
confirmed_tracks = [track for track in self.tracks.values() if track.is_confirmed()]
return confirmed_tracks
def _associate(self, detections: List) -> Tuple[List[Tuple[int, Any]], List[Any], List[int]]:
"""
Associate detections to existing tracks using IoU distance
Returns:
(matched_tracks, unmatched_detections, unmatched_tracks)
"""
if not detections or not self.tracks:
return [], detections, list(self.tracks.keys())
# Calculate IoU distance matrix
track_ids = list(self.tracks.keys())
cost_matrix = np.zeros((len(track_ids), len(detections)))
for i, track_id in enumerate(track_ids):
track = self.tracks[track_id]
predicted_bbox = track.predict()
for j, detection in enumerate(detections):
iou = self._calculate_iou(predicted_bbox, detection.bbox)
cost_matrix[i, j] = 1 - iou # Convert IoU to distance
# Solve assignment problem
row_indices, col_indices = linear_sum_assignment(cost_matrix)
# Filter matches by IoU threshold
iou_threshold = 0.3
matched_tracks = []
matched_detection_indices = set()
matched_track_indices = set()
for row, col in zip(row_indices, col_indices):
if cost_matrix[row, col] <= (1 - iou_threshold):
track_id = track_ids[row]
detection = detections[col]
matched_tracks.append((track_id, detection))
matched_detection_indices.add(col)
matched_track_indices.add(row)
# Find unmatched detections and tracks
unmatched_detections = [detections[i] for i in range(len(detections))
if i not in matched_detection_indices]
unmatched_tracks = [track_ids[i] for i in range(len(track_ids))
if i not in matched_track_indices]
return matched_tracks, unmatched_detections, unmatched_tracks
def _calculate_iou(self, bbox1: np.ndarray, bbox2: List[float]) -> float:
"""Calculate IoU between two bounding boxes"""
x1_1, y1_1, x2_1, y2_1 = bbox1
x1_2, y1_2, x2_2, y2_2 = bbox2
# Calculate intersection area
x1_i = max(x1_1, x1_2)
y1_i = max(y1_1, y1_2)
x2_i = min(x2_1, x2_2)
y2_i = min(y2_1, y2_2)
if x2_i <= x1_i or y2_i <= y1_i:
return 0.0
intersection = (x2_i - x1_i) * (y2_i - y1_i)
# Calculate union area
area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
union = area1 + area2 - intersection
return intersection / union if union > 0 else 0.0
class MultiCameraBoTSORT:
"""
Multi-camera BoT-SORT tracker with complete camera isolation
"""
def __init__(self, trigger_classes: List[str], min_confidence: float = 0.6):
"""
Initialize multi-camera tracker
Args:
trigger_classes: List of class names to track
min_confidence: Minimum detection confidence threshold
"""
self.trigger_classes = trigger_classes
self.min_confidence = min_confidence
# Camera-specific trackers
self.camera_trackers: Dict[str, CameraTracker] = {}
logger.info(f"Initialized MultiCameraBoTSORT with classes={trigger_classes}, "
f"min_confidence={min_confidence}")
def get_or_create_tracker(self, camera_id: str) -> CameraTracker:
"""Get or create tracker for specific camera"""
if camera_id not in self.camera_trackers:
self.camera_trackers[camera_id] = CameraTracker(camera_id)
logger.info(f"Created new tracker for camera {camera_id}")
return self.camera_trackers[camera_id]
def update(self, camera_id: str, inference_result) -> List[Dict]:
"""
Update tracker for specific camera with detections
Args:
camera_id: Camera identifier
inference_result: InferenceResult with detections
Returns:
List of track information dictionaries
"""
# Filter detections by confidence and trigger classes
filtered_detections = []
if hasattr(inference_result, 'detections') and inference_result.detections:
for detection in inference_result.detections:
if (detection.confidence >= self.min_confidence and
detection.class_name in self.trigger_classes):
filtered_detections.append(detection)
# Get camera tracker and update
tracker = self.get_or_create_tracker(camera_id)
confirmed_tracks = tracker.update(filtered_detections)
# Convert tracks to output format
track_results = []
for track in confirmed_tracks:
track_results.append({
'track_id': track.track_id,
'camera_id': track.camera_id,
'bbox': track.bbox,
'confidence': track.confidence,
'class_name': track.class_name,
'hit_streak': track.hit_streak,
'age': track.age
})
return track_results
def get_statistics(self) -> Dict[str, Any]:
"""Get tracking statistics across all cameras"""
stats = {}
total_tracks = 0
for camera_id, tracker in self.camera_trackers.items():
camera_stats = {
'active_tracks': len([t for t in tracker.tracks.values() if t.is_confirmed()]),
'total_tracks': len(tracker.tracks),
'frame_count': tracker.frame_count
}
stats[camera_id] = camera_stats
total_tracks += camera_stats['active_tracks']
stats['summary'] = {
'total_cameras': len(self.camera_trackers),
'total_active_tracks': total_tracks
}
return stats
def reset_camera(self, camera_id: str):
"""Reset tracking for specific camera"""
if camera_id in self.camera_trackers:
del self.camera_trackers[camera_id]
logger.info(f"Reset tracking for camera {camera_id}")
def reset_all(self):
"""Reset all camera trackers"""
self.camera_trackers.clear()
logger.info("Reset all camera trackers")


@@ -1,883 +0,0 @@
"""
Tracking-Pipeline Integration Module.
Connects the tracking system with the main detection pipeline and manages the flow.
"""
import logging
import time
import uuid
from typing import Dict, Optional, Any, List, Tuple
from concurrent.futures import ThreadPoolExecutor
import asyncio
import numpy as np
from .tracker import VehicleTracker, TrackedVehicle
from .validator import StableCarValidator
from ..models.inference import YOLOWrapper
from ..models.pipeline import PipelineParser
from ..detection.pipeline import DetectionPipeline
logger = logging.getLogger(__name__)
class TrackingPipelineIntegration:
"""
Integrates vehicle tracking with the detection pipeline.
Manages tracking state transitions and pipeline execution triggers.
"""
def __init__(self, pipeline_parser: PipelineParser, model_manager: Any, model_id: int, message_sender=None):
"""
Initialize tracking-pipeline integration.
Args:
pipeline_parser: Pipeline parser with loaded configuration
model_manager: Model manager for loading models
model_id: The model ID to use for loading models
message_sender: Optional callback function for sending WebSocket messages
"""
self.pipeline_parser = pipeline_parser
self.model_manager = model_manager
self.model_id = model_id
self.message_sender = message_sender
# Store subscription info for snapshot access
self.subscription_info = None
# Initialize tracking components
tracking_config = pipeline_parser.tracking_config.__dict__ if pipeline_parser.tracking_config else {}
self.tracker = VehicleTracker(tracking_config)
self.validator = StableCarValidator()
# Tracking model
self.tracking_model: Optional[YOLOWrapper] = None
self.tracking_model_id = None
# Detection pipeline (Phase 5)
self.detection_pipeline: Optional[DetectionPipeline] = None
# Session management
self.active_sessions: Dict[str, str] = {} # display_id -> session_id
self.session_vehicles: Dict[str, int] = {} # session_id -> track_id
self.cleared_sessions: Dict[str, float] = {} # session_id -> clear_time
self.pending_vehicles: Dict[str, int] = {} # display_id -> track_id (waiting for session ID)
self.pending_processing_data: Dict[str, Dict] = {} # display_id -> processing data (waiting for session ID)
self.display_to_subscription: Dict[str, str] = {} # display_id -> subscription_id (for fallback)
# Additional validators for enhanced flow control
self.permanently_processed: Dict[str, float] = {} # "camera_id:track_id" -> process_time (never process again)
self.progression_stages: Dict[str, str] = {} # session_id -> current_stage
self.last_detection_time: Dict[str, float] = {} # display_id -> last_detection_timestamp
self.abandonment_timeout = 3.0 # seconds to wait before declaring car abandoned
# Thread pool for pipeline execution
# Increased to 8 workers to handle 8 concurrent cameras without queuing
self.executor = ThreadPoolExecutor(max_workers=8)
# Min bbox filtering configuration
# TODO: Make this configurable via pipeline.json in the future
self.min_bbox_area_percentage = 3.5 # 3.5% of frame area minimum
# Statistics
self.stats = {
'frames_processed': 0,
'vehicles_detected': 0,
'vehicles_validated': 0,
'pipelines_executed': 0,
'frontals_filtered_small': 0 # Track filtered detections
}
logger.info("TrackingPipelineIntegration initialized")
async def initialize_tracking_model(self) -> bool:
"""
Load and initialize the tracking model.
Returns:
True if successful, False otherwise
"""
try:
if not self.pipeline_parser.tracking_config:
logger.warning("No tracking configuration found in pipeline")
return False
model_file = self.pipeline_parser.tracking_config.model_file
model_id = self.pipeline_parser.tracking_config.model_id
if not model_file:
logger.warning("No tracking model file specified")
return False
# Load tracking model
logger.info(f"Loading tracking model: {model_id} ({model_file})")
self.tracking_model = self.model_manager.get_yolo_model(self.model_id, model_file)
if not self.tracking_model:
logger.error(f"Failed to load tracking model {model_file} from model {self.model_id}")
return False
self.tracking_model_id = model_id
if self.tracking_model:
logger.info(f"Tracking model {model_id} loaded successfully")
# Initialize detection pipeline (Phase 5)
await self._initialize_detection_pipeline()
return True
else:
logger.error(f"Failed to load tracking model {model_id}")
return False
except Exception as e:
logger.error(f"Error initializing tracking model: {e}", exc_info=True)
return False
async def _initialize_detection_pipeline(self) -> bool:
"""
Initialize the detection pipeline for main detection processing.
Returns:
True if successful, False otherwise
"""
try:
if not self.pipeline_parser:
logger.warning("No pipeline parser available for detection pipeline")
return False
# Create detection pipeline with message sender capability
self.detection_pipeline = DetectionPipeline(self.pipeline_parser, self.model_manager, self.model_id, self.message_sender)
# Initialize detection pipeline
if await self.detection_pipeline.initialize():
logger.info("Detection pipeline initialized successfully")
return True
else:
logger.error("Failed to initialize detection pipeline")
return False
except Exception as e:
logger.error(f"Error initializing detection pipeline: {e}", exc_info=True)
return False
async def process_frame(self,
frame: np.ndarray,
display_id: str,
subscription_id: str,
session_id: Optional[str] = None) -> Dict[str, Any]:
"""
Process a frame through tracking and potentially the detection pipeline.
Args:
frame: Input frame to process
display_id: Display identifier
subscription_id: Full subscription identifier
session_id: Optional session ID from backend
Returns:
Dictionary with processing results
"""
start_time = time.time()
result = {
'tracked_vehicles': [],
'validated_vehicle': None,
'pipeline_result': None,
'session_id': session_id,
'processing_time': 0.0
}
try:
# Update stats
self.stats['frames_processed'] += 1
# Run tracking model
if self.tracking_model:
# Run detection-only (tracking handled by our own tracker)
tracking_results = self.tracking_model.track(
frame,
confidence_threshold=self.tracker.min_confidence,
trigger_classes=self.tracker.trigger_classes,
persist=True
)
# Debug: Log raw detection results
if tracking_results and hasattr(tracking_results, 'detections'):
raw_detections = len(tracking_results.detections)
if raw_detections > 0:
class_names = [detection.class_name for detection in tracking_results.detections]
logger.debug(f"Raw detections: {raw_detections}, classes: {class_names}")
else:
logger.debug(f"No raw detections found")
else:
logger.debug(f"No tracking results or detections attribute")
# Filter out small frontal detections (neighboring pumps/distant cars)
if tracking_results and hasattr(tracking_results, 'detections'):
tracking_results = self._filter_small_frontals(tracking_results, frame)
# Process tracking results
tracked_vehicles = self.tracker.process_detections(
tracking_results,
display_id,
frame
)
# Update last detection time for abandonment detection
# Updated only while vehicles ARE detected, so once they leave the timestamp ages out
if tracked_vehicles:
self.last_detection_time[display_id] = time.time()
logger.debug(f"Updated last_detection_time for {display_id}: {len(tracked_vehicles)} vehicles")
# Check for car abandonment (vehicle left after getting car_wait_staff stage)
await self._check_car_abandonment(display_id, subscription_id)
result['tracked_vehicles'] = [
{
'track_id': v.track_id,
'bbox': v.bbox,
'confidence': v.confidence,
'is_stable': v.is_stable,
'session_id': v.session_id
}
for v in tracked_vehicles
]
# Log tracking info periodically
if self.stats['frames_processed'] % 30 == 0: # Every 30 frames
logger.debug(f"Tracking: {len(tracked_vehicles)} vehicles, "
f"display={display_id}")
# Get stable vehicles for validation
stable_vehicles = self.tracker.get_stable_vehicles(display_id)
# Validate and potentially process stable vehicles
for vehicle in stable_vehicles:
# Check if vehicle is already processed or has session
if vehicle.processed_pipeline:
continue
# Check for session cleared (post-fueling)
if session_id and vehicle.session_id == session_id:
# Same vehicle with same session, skip
continue
# Check if this was a recently cleared session
session_cleared = False
if vehicle.session_id in self.cleared_sessions:
clear_time = self.cleared_sessions[vehicle.session_id]
if (time.time() - clear_time) < 30: # 30 second cooldown
session_cleared = True
# Skip same car after session clear or if permanently processed
if self.validator.should_skip_same_car(vehicle, session_cleared, self.permanently_processed):
continue
# Validate vehicle
validation_result = self.validator.validate_vehicle(vehicle, frame.shape)
if validation_result.is_valid and validation_result.should_process:
logger.info(f"Vehicle {vehicle.track_id} validated for processing: "
f"{validation_result.reason}")
result['validated_vehicle'] = {
'track_id': vehicle.track_id,
'state': validation_result.state.value,
'confidence': validation_result.confidence
}
# Execute detection pipeline - this will send real imageDetection when detection is found
# Mark vehicle as pending session ID assignment
self.pending_vehicles[display_id] = vehicle.track_id
logger.info(f"Vehicle {vehicle.track_id} waiting for session ID from backend")
# Execute detection pipeline (placeholder for Phase 5)
pipeline_result = await self._execute_pipeline(
frame,
vehicle,
display_id,
None, # No session ID yet
subscription_id
)
result['pipeline_result'] = pipeline_result
# No session_id in result yet - backend will provide it
self.stats['pipelines_executed'] += 1
# Only process one vehicle per frame
break
self.stats['vehicles_detected'] = len(tracked_vehicles)
self.stats['vehicles_validated'] = len(stable_vehicles)
else:
logger.warning("No tracking model available")
except Exception as e:
logger.error(f"Error in tracking pipeline: {e}", exc_info=True)
result['processing_time'] = time.time() - start_time
return result
async def _execute_pipeline(self,
frame: np.ndarray,
vehicle: TrackedVehicle,
display_id: str,
session_id: str,
subscription_id: str) -> Dict[str, Any]:
"""
Execute the main detection pipeline for a validated vehicle.
Args:
frame: Input frame
vehicle: Validated tracked vehicle
display_id: Display identifier
session_id: Session identifier
subscription_id: Full subscription identifier
Returns:
Pipeline execution results
"""
logger.info(f"Executing detection pipeline for vehicle {vehicle.track_id}, "
f"session={session_id}, display={display_id}")
try:
# Check if detection pipeline is available
if not self.detection_pipeline:
logger.warning("Detection pipeline not initialized, using fallback")
return {
'status': 'error',
'message': 'Detection pipeline not available',
'vehicle_id': vehicle.track_id,
'session_id': session_id
}
# Fetch high-quality 2K snapshot for detection phase (not RTSP frame)
# This ensures bbox coordinates match the frame used in processing phase
logger.info(f"[DETECTION PHASE] Fetching 2K snapshot for vehicle {vehicle.track_id}")
snapshot_frame = self._fetch_snapshot()
if snapshot_frame is None:
logger.warning(f"[DETECTION PHASE] Failed to fetch snapshot, falling back to RTSP frame")
snapshot_frame = frame # Fallback to RTSP if snapshot fails
else:
logger.info(f"[DETECTION PHASE] Using {snapshot_frame.shape[1]}x{snapshot_frame.shape[0]} snapshot for detection")
# Execute only the detection phase (first phase)
# This will run detection and send imageDetection message to backend
detection_result = await self.detection_pipeline.execute_detection_phase(
frame=snapshot_frame, # Use 2K snapshot instead of RTSP frame
display_id=display_id,
subscription_id=subscription_id
)
# Add vehicle information to result
detection_result['vehicle_id'] = vehicle.track_id
detection_result['vehicle_bbox'] = vehicle.bbox
detection_result['vehicle_confidence'] = vehicle.confidence
detection_result['phase'] = 'detection'
logger.info(f"Detection phase executed for vehicle {vehicle.track_id}: "
f"status={detection_result.get('status', 'unknown')}, "
f"message_sent={detection_result.get('message_sent', False)}, "
f"processing_time={detection_result.get('processing_time', 0):.3f}s")
# Store frame and detection results for processing phase
if detection_result['message_sent']:
# Store for later processing when sessionId is received
self.pending_processing_data[display_id] = {
'frame': snapshot_frame.copy(), # Store copy of 2K snapshot (not RTSP frame!)
'vehicle': vehicle,
'subscription_id': subscription_id,
'detection_result': detection_result,
'timestamp': time.time()
}
logger.info(f"Stored processing data ({snapshot_frame.shape[1]}x{snapshot_frame.shape[0]} frame) for {display_id}, waiting for sessionId from backend")
return detection_result
except Exception as e:
logger.error(f"Error executing detection pipeline: {e}", exc_info=True)
return {
'status': 'error',
'message': str(e),
'vehicle_id': vehicle.track_id,
'session_id': session_id,
'processing_time': 0.0
}
async def _execute_processing_phase(self,
processing_data: Dict[str, Any],
session_id: str,
display_id: str) -> None:
"""
Execute the processing phase after receiving sessionId from backend.
This includes branch processing and database operations.
Args:
processing_data: Stored processing data from detection phase
session_id: Session ID from backend
display_id: Display identifier
"""
try:
vehicle = processing_data['vehicle']
subscription_id = processing_data['subscription_id']
detection_result = processing_data['detection_result']
logger.info(f"Executing processing phase for session {session_id}, vehicle {vehicle.track_id}")
# Reuse the snapshot from detection phase OR fetch fresh one if detection used RTSP fallback
detection_frame = processing_data['frame']
frame_height = detection_frame.shape[0]
# Check if detection phase used 2K snapshot (height > 1000) or RTSP fallback (height = 720)
if frame_height >= 1000:
# Detection used 2K snapshot - reuse it for consistent coordinates
logger.info(f"[PROCESSING PHASE] Reusing 2K snapshot from detection phase ({detection_frame.shape[1]}x{detection_frame.shape[0]})")
frame = detection_frame
else:
# Detection used RTSP fallback - need to fetch fresh 2K snapshot
logger.warning(f"[PROCESSING PHASE] Detection used RTSP fallback ({detection_frame.shape[1]}x{detection_frame.shape[0]}), fetching fresh 2K snapshot")
frame = self._fetch_snapshot()
if frame is None:
logger.error(f"[PROCESSING PHASE] Failed to fetch snapshot and detection used RTSP - coordinate mismatch will occur!")
logger.error(f"[PROCESSING PHASE] Cannot proceed with mismatched coordinates. Aborting processing phase.")
return # Cannot process safely - bbox coordinates won't match frame resolution
else:
logger.warning(f"[PROCESSING PHASE] Fetched fresh 2K snapshot ({frame.shape[1]}x{frame.shape[0]}), but coordinates may not match exactly")
logger.warning(f"[PROCESSING PHASE] Re-running detection on fresh snapshot is recommended but not implemented yet")
# Extract detected regions from detection phase result if available
detected_regions = detection_result.get('detected_regions', {})
logger.info(f"[INTEGRATION] Passing detected_regions to processing phase: {list(detected_regions.keys())}")
# Execute processing phase with detection pipeline
if self.detection_pipeline:
processing_result = await self.detection_pipeline.execute_processing_phase(
frame=frame,
display_id=display_id,
session_id=session_id,
subscription_id=subscription_id,
detected_regions=detected_regions
)
logger.info(f"Processing phase completed for session {session_id}: "
f"status={processing_result.get('status', 'unknown')}, "
f"branches={len(processing_result.get('branch_results', {}))}, "
f"actions={len(processing_result.get('actions_executed', []))}, "
f"processing_time={processing_result.get('processing_time', 0):.3f}s")
# Update stats
self.stats['pipelines_executed'] += 1
else:
logger.error("Detection pipeline not available for processing phase")
except Exception as e:
logger.error(f"Error in processing phase for session {session_id}: {e}", exc_info=True)
def set_subscription_info(self, subscription_info):
"""
Set subscription info to access snapshot URL and other stream details.
Args:
subscription_info: SubscriptionInfo object containing stream config
"""
self.subscription_info = subscription_info
logger.debug(f"Set subscription info with snapshot_url: {subscription_info.stream_config.snapshot_url if subscription_info else None}")
def set_session_id(self, display_id: str, session_id: str, subscription_id: str = None):
"""
Set session ID for a display (from backend).
This is called when backend sends setSessionId after receiving imageDetection.
Args:
display_id: Display identifier
session_id: Session identifier
subscription_id: Subscription identifier (displayId;cameraId) - needed for fallback
"""
# Ensure session_id is always a string for consistent type handling
session_id = str(session_id) if session_id is not None else None
self.active_sessions[display_id] = session_id
# Store subscription_id for fallback usage
if subscription_id:
self.display_to_subscription[display_id] = subscription_id
logger.info(f"Set session {session_id} for display {display_id} with subscription {subscription_id}")
else:
logger.info(f"Set session {session_id} for display {display_id}")
# Check if we have a pending vehicle for this display
if display_id in self.pending_vehicles:
track_id = self.pending_vehicles[display_id]
# Mark vehicle as processed with the session ID
self.tracker.mark_processed(track_id, session_id)
self.session_vehicles[session_id] = track_id
# Mark vehicle as permanently processed (won't process again even after session clear)
# Use composite key to distinguish same track IDs across different cameras
camera_id = display_id # Using display_id as camera_id for isolation
permanent_key = f"{camera_id}:{track_id}"
self.permanently_processed[permanent_key] = time.time()
# Remove from pending
del self.pending_vehicles[display_id]
logger.info(f"Assigned session {session_id} to vehicle {track_id}, marked as permanently processed")
else:
logger.warning(f"No pending vehicle found for display {display_id} when setting session {session_id}")
# Check if we have pending processing data for this display
if display_id in self.pending_processing_data:
processing_data = self.pending_processing_data[display_id]
# Trigger the processing phase asynchronously
asyncio.create_task(self._execute_processing_phase(
processing_data=processing_data,
session_id=session_id,
display_id=display_id
))
# Remove from pending processing
del self.pending_processing_data[display_id]
logger.info(f"Triggered processing phase for session {session_id} on display {display_id}")
else:
logger.warning(f"No pending processing data found for display {display_id} when setting session {session_id}")
# FALLBACK: Execute pipeline for POS-initiated sessions
# Skip if session_id is None (no car present or car has left)
if session_id is not None:
# Use stored subscription_id instead of creating fake one
stored_subscription_id = self.display_to_subscription.get(display_id)
if stored_subscription_id:
logger.info(f"[FALLBACK] Triggering fallback pipeline for session {session_id} on display {display_id} with subscription {stored_subscription_id}")
# Trigger the fallback pipeline asynchronously with real subscription_id
asyncio.create_task(self._execute_fallback_pipeline(
display_id=display_id,
session_id=session_id,
subscription_id=stored_subscription_id
))
else:
logger.error(f"[FALLBACK] No subscription_id stored for display {display_id}, cannot execute fallback pipeline")
else:
logger.debug(f"[FALLBACK] Skipping pipeline execution for session_id=None on display {display_id}")
def clear_session_id(self, session_id: str):
"""
Clear session ID (post-fueling).
Args:
session_id: Session identifier to clear
"""
# Mark session as cleared
self.cleared_sessions[session_id] = time.time()
# Clear from tracker
self.tracker.clear_session(session_id)
# Remove from active sessions
display_to_remove = None
for display_id, sess_id in self.active_sessions.items():
if sess_id == session_id:
display_to_remove = display_id
break
if display_to_remove:
del self.active_sessions[display_to_remove]
if session_id in self.session_vehicles:
del self.session_vehicles[session_id]
logger.info(f"Cleared session {session_id}")
# Clean old cleared sessions (older than 5 minutes)
current_time = time.time()
old_sessions = [
sid for sid, clear_time in self.cleared_sessions.items()
if (current_time - clear_time) > 300
]
for sid in old_sessions:
del self.cleared_sessions[sid]
def get_session_for_display(self, display_id: str) -> Optional[str]:
"""Get active session for a display."""
return self.active_sessions.get(display_id)
def reset_tracking(self):
"""Reset all tracking state."""
self.tracker.reset_tracking()
self.active_sessions.clear()
self.session_vehicles.clear()
self.cleared_sessions.clear()
self.pending_vehicles.clear()
self.pending_processing_data.clear()
self.display_to_subscription.clear()
self.permanently_processed.clear()
self.progression_stages.clear()
self.last_detection_time.clear()
logger.info("Tracking pipeline integration reset")
def get_statistics(self) -> Dict[str, Any]:
"""Get comprehensive statistics."""
tracker_stats = self.tracker.get_statistics()
validator_stats = self.validator.get_statistics()
return {
'integration': self.stats,
'tracker': tracker_stats,
'validator': validator_stats,
'active_sessions': len(self.active_sessions),
'cleared_sessions': len(self.cleared_sessions)
}
async def _check_car_abandonment(self, display_id: str, subscription_id: str):
"""
Check if a car has abandoned the fueling process (left after getting car_wait_staff stage).
Args:
display_id: Display identifier
subscription_id: Subscription identifier
"""
current_time = time.time()
# Check all sessions in car_wait_staff stage
abandoned_sessions = []
for session_id, stage in self.progression_stages.items():
if stage == "car_wait_staff":
# Check if we have recent detections for this session's display
session_display = None
for disp_id, sess_id in self.active_sessions.items():
if sess_id == session_id:
session_display = disp_id
break
if session_display:
last_detection = self.last_detection_time.get(session_display, 0)
time_since_detection = current_time - last_detection
logger.info(f"[ABANDON CHECK] Session {session_id} (display: {session_display}): "
f"time_since_detection={time_since_detection:.1f}s, "
f"timeout={self.abandonment_timeout}s")
if time_since_detection > self.abandonment_timeout:
logger.warning(f"🚨 Car abandonment detected: session {session_id}, "
f"no detection for {time_since_detection:.1f}s")
abandoned_sessions.append(session_id)
else:
logger.debug(f"[ABANDON CHECK] Session {session_id} has no associated display")
# Send abandonment detection for each abandoned session
for session_id in abandoned_sessions:
await self._send_abandonment_detection(subscription_id, session_id)
# Remove from progression stages to avoid repeated detection
if session_id in self.progression_stages:
del self.progression_stages[session_id]
logger.info(f"[ABANDON] Removed session {session_id} from progression_stages after notification")
async def _send_abandonment_detection(self, subscription_id: str, session_id: str):
"""
Send imageDetection with null detection to indicate car abandonment.
Args:
subscription_id: Subscription identifier
session_id: Session ID of the abandoned car
"""
try:
# Import here to avoid circular imports
from ..communication.messages import create_image_detection
# Create abandonment detection message with null detection
detection_message = create_image_detection(
subscription_identifier=subscription_id,
detection_data=None, # Null detection indicates abandonment
model_id=self.model_id,
model_name=self.pipeline_parser.tracking_config.model_id if self.pipeline_parser.tracking_config else "tracking_model"
)
# Send to backend via WebSocket if sender is available
if self.message_sender:
await self.message_sender(detection_message)
logger.info(f"[CAR ABANDONMENT] Sent null detection for session {session_id}")
else:
logger.info(f"[CAR ABANDONMENT] No message sender available, would send: {detection_message}")
except Exception as e:
logger.error(f"Error sending abandonment detection: {e}", exc_info=True)
def set_progression_stage(self, session_id: str, stage: str):
"""
Set progression stage for a session (from backend setProgessionStage message).
Args:
session_id: Session identifier
stage: Progression stage (e.g., "car_wait_staff")
"""
self.progression_stages[session_id] = stage
logger.info(f"Set progression stage for session {session_id}: {stage}")
# If car reaches car_wait_staff, start monitoring for abandonment
if stage == "car_wait_staff":
logger.info(f"Started monitoring session {session_id} for car abandonment")
def _fetch_snapshot(self) -> Optional[np.ndarray]:
"""
Fetch high-quality snapshot from camera's snapshot URL.
Reusable method for both processing phase and fallback pipeline.
Returns:
Snapshot frame or None if unavailable
"""
if not (self.subscription_info and self.subscription_info.stream_config.snapshot_url):
logger.warning("[SNAPSHOT] No subscription info or snapshot URL available")
return None
try:
from ..streaming.readers import HTTPSnapshotReader
logger.info(f"[SNAPSHOT] Fetching snapshot for {self.subscription_info.camera_id}")
snapshot_reader = HTTPSnapshotReader(
camera_id=self.subscription_info.camera_id,
snapshot_url=self.subscription_info.stream_config.snapshot_url,
max_retries=3
)
frame = snapshot_reader.fetch_single_snapshot()
if frame is not None:
logger.info(f"[SNAPSHOT] Successfully fetched {frame.shape[1]}x{frame.shape[0]} snapshot")
return frame
else:
logger.warning("[SNAPSHOT] Failed to fetch snapshot")
return None
except Exception as e:
logger.error(f"[SNAPSHOT] Error fetching snapshot: {e}", exc_info=True)
return None
async def _execute_fallback_pipeline(self, display_id: str, session_id: str, subscription_id: str):
"""
Execute fallback pipeline when sessionId is received without prior detection.
This handles POS-initiated sessions where backend starts transaction before car detection.
Args:
display_id: Display identifier
session_id: Session ID from backend
subscription_id: Subscription identifier for pipeline execution
"""
try:
logger.info(f"[FALLBACK PIPELINE] Executing for session {session_id}, display {display_id}")
# Fetch fresh snapshot from camera
frame = self._fetch_snapshot()
if frame is None:
logger.error(f"[FALLBACK] Failed to fetch snapshot for session {session_id}, cannot execute pipeline")
return
logger.info(f"[FALLBACK] Using snapshot frame {frame.shape[1]}x{frame.shape[0]} for session {session_id}")
# Check if detection pipeline is available
if not self.detection_pipeline:
logger.error(f"[FALLBACK] Detection pipeline not available for session {session_id}")
return
# Execute detection phase to get detected regions
detection_result = await self.detection_pipeline.execute_detection_phase(
frame=frame,
display_id=display_id,
subscription_id=subscription_id
)
logger.info(f"[FALLBACK] Detection phase completed for session {session_id}: "
f"status={detection_result.get('status', 'unknown')}, "
f"regions={list(detection_result.get('detected_regions', {}).keys())}")
# If detection found regions, execute processing phase
detected_regions = detection_result.get('detected_regions', {})
if detected_regions:
processing_result = await self.detection_pipeline.execute_processing_phase(
frame=frame,
display_id=display_id,
session_id=session_id,
subscription_id=subscription_id,
detected_regions=detected_regions
)
logger.info(f"[FALLBACK] Processing phase completed for session {session_id}: "
f"status={processing_result.get('status', 'unknown')}, "
f"branches={len(processing_result.get('branch_results', {}))}, "
f"actions={len(processing_result.get('actions_executed', []))}")
# Update statistics
self.stats['pipelines_executed'] += 1
else:
logger.warning(f"[FALLBACK] No detections found in snapshot for session {session_id}")
except Exception as e:
logger.error(f"[FALLBACK] Error executing fallback pipeline for session {session_id}: {e}", exc_info=True)
def _filter_small_frontals(self, tracking_results, frame):
"""
Filter out frontal detections that are smaller than minimum bbox area percentage.
This prevents processing of cars from neighboring pumps that appear in camera view.
Args:
tracking_results: YOLO tracking results with detections
frame: Input frame for calculating frame area
Returns:
Modified tracking_results with small frontals removed
"""
if not hasattr(tracking_results, 'detections') or not tracking_results.detections:
return tracking_results
# Calculate frame area and minimum bbox area threshold
frame_area = frame.shape[0] * frame.shape[1] # height * width
min_bbox_area = frame_area * (self.min_bbox_area_percentage / 100.0)
# Filter detections
filtered_detections = []
filtered_count = 0
for detection in tracking_results.detections:
# Calculate detection bbox area
bbox = detection.bbox # Assuming bbox is [x1, y1, x2, y2]
bbox_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
if bbox_area >= min_bbox_area:
# Keep detection - bbox is large enough
filtered_detections.append(detection)
else:
# Filter out small detection
filtered_count += 1
area_percentage = (bbox_area / frame_area) * 100
logger.debug(f"Filtered small frontal: area={bbox_area:.0f}px² ({area_percentage:.1f}% of frame, "
f"min required: {self.min_bbox_area_percentage}%)")
# Update tracking results with filtered detections
tracking_results.detections = filtered_detections
# Update statistics
if filtered_count > 0:
self.stats['frontals_filtered_small'] += filtered_count
logger.info(f"Filtered {filtered_count} small frontal detections, "
f"{len(filtered_detections)} remaining (total filtered: {self.stats['frontals_filtered_small']})")
return tracking_results
def cleanup(self):
"""Cleanup resources."""
self.executor.shutdown(wait=False)
self.reset_tracking()
# Cleanup detection pipeline
if self.detection_pipeline:
self.detection_pipeline.cleanup()
logger.info("Tracking pipeline integration cleaned up")


@ -1,293 +0,0 @@
"""
Vehicle Tracking Module - BoT-SORT based tracking with camera isolation
Implements vehicle identification, persistence, and motion analysis using external tracker.
"""
import logging
import time
import uuid
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
import numpy as np
from threading import Lock
from .bot_sort_tracker import MultiCameraBoTSORT
logger = logging.getLogger(__name__)
@dataclass
class TrackedVehicle:
"""Represents a tracked vehicle with all its state information."""
track_id: int
camera_id: str
first_seen: float
last_seen: float
session_id: Optional[str] = None
display_id: Optional[str] = None
confidence: float = 0.0
bbox: Tuple[int, int, int, int] = (0, 0, 0, 0) # x1, y1, x2, y2
center: Tuple[float, float] = (0.0, 0.0)
stable_frames: int = 0
total_frames: int = 0
is_stable: bool = False
processed_pipeline: bool = False
last_position_history: List[Tuple[float, float]] = field(default_factory=list)
avg_confidence: float = 0.0
hit_streak: int = 0
age: int = 0
def update_position(self, bbox: Tuple[int, int, int, int], confidence: float):
"""Update vehicle position and confidence."""
self.bbox = bbox
self.center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
self.last_seen = time.time()
self.confidence = confidence
self.total_frames += 1
# Update confidence average
self.avg_confidence = ((self.avg_confidence * (self.total_frames - 1)) + confidence) / self.total_frames
# Maintain position history (last 10 positions)
self.last_position_history.append(self.center)
if len(self.last_position_history) > 10:
self.last_position_history.pop(0)
def calculate_stability(self) -> float:
"""Calculate stability score based on position history."""
if len(self.last_position_history) < 2:
return 0.0
# Calculate movement variance
positions = np.array(self.last_position_history)
if len(positions) < 2:
return 0.0
# Calculate standard deviation of positions
std_x = np.std(positions[:, 0])
std_y = np.std(positions[:, 1])
# Lower variance means more stable (inverse relationship)
# Normalize to 0-1 range (assuming max reasonable std is 50 pixels)
stability = max(0, 1 - (std_x + std_y) / 100)
return stability
def is_expired(self, timeout_seconds: float = 2.0) -> bool:
"""Check if vehicle tracking has expired."""
return (time.time() - self.last_seen) > timeout_seconds
class VehicleTracker:
"""
Main vehicle tracking implementation using BoT-SORT with camera isolation.
Manages continuous tracking, vehicle identification, and state persistence.
"""
def __init__(self, tracking_config: Optional[Dict] = None):
"""
Initialize the vehicle tracker.
Args:
tracking_config: Configuration from pipeline.json tracking section
"""
self.config = tracking_config or {}
self.trigger_classes = self.config.get('trigger_classes', self.config.get('triggerClasses', ['frontal']))
self.min_confidence = self.config.get('minConfidence', 0.6)
# BoT-SORT multi-camera tracker
self.bot_sort = MultiCameraBoTSORT(self.trigger_classes, self.min_confidence)
# Tracking state - maintain compatibility with existing code
self.tracked_vehicles: Dict[str, Dict[int, TrackedVehicle]] = {} # camera_id -> {track_id: vehicle}
self.lock = Lock()
# Tracking parameters
self.stability_threshold = 0.7
self.min_stable_frames = 5
self.timeout_seconds = 2.0
logger.info(f"VehicleTracker initialized with BoT-SORT: trigger_classes={self.trigger_classes}, "
f"min_confidence={self.min_confidence}")
def process_detections(self,
results: Any,
display_id: str,
frame: np.ndarray) -> List[TrackedVehicle]:
"""
Process detection results using BoT-SORT tracking.
Args:
results: Detection results (InferenceResult)
display_id: Display identifier for this stream
frame: Current frame being processed
Returns:
List of currently tracked vehicles
"""
current_time = time.time()
# Extract camera_id from display_id for tracking isolation
camera_id = display_id # Using display_id as camera_id for isolation
with self.lock:
# Update BoT-SORT tracker
track_results = self.bot_sort.update(camera_id, results)
# Ensure camera tracking dict exists
if camera_id not in self.tracked_vehicles:
self.tracked_vehicles[camera_id] = {}
# Update tracked vehicles based on BoT-SORT results
current_tracks = {}
active_tracks = []
for track_result in track_results:
track_id = track_result['track_id']
# Create or update TrackedVehicle
if track_id in self.tracked_vehicles[camera_id]:
# Update existing vehicle
vehicle = self.tracked_vehicles[camera_id][track_id]
vehicle.update_position(track_result['bbox'], track_result['confidence'])
vehicle.hit_streak = track_result['hit_streak']
vehicle.age = track_result['age']
# Update stability based on hit_streak
if vehicle.hit_streak >= self.min_stable_frames:
vehicle.is_stable = True
vehicle.stable_frames = vehicle.hit_streak
logger.debug(f"Updated track {track_id}: conf={vehicle.confidence:.2f}, "
f"stable={vehicle.is_stable}, hit_streak={vehicle.hit_streak}")
else:
# Create new vehicle
x1, y1, x2, y2 = track_result['bbox']
vehicle = TrackedVehicle(
track_id=track_id,
camera_id=camera_id,
first_seen=current_time,
last_seen=current_time,
display_id=display_id,
confidence=track_result['confidence'],
bbox=tuple(track_result['bbox']),
center=((x1 + x2) / 2, (y1 + y2) / 2),
total_frames=1,
hit_streak=track_result['hit_streak'],
age=track_result['age']
)
vehicle.last_position_history.append(vehicle.center)
logger.info(f"New vehicle tracked: ID={track_id}, camera={camera_id}, display={display_id}")
current_tracks[track_id] = vehicle
active_tracks.append(vehicle)
# Update the camera's tracked vehicles
self.tracked_vehicles[camera_id] = current_tracks
return active_tracks
def get_stable_vehicles(self, display_id: Optional[str] = None) -> List[TrackedVehicle]:
"""
Get all stable vehicles, optionally filtered by display.
Args:
display_id: Optional display ID to filter by
Returns:
List of stable tracked vehicles
"""
with self.lock:
stable = []
camera_id = display_id # Using display_id as camera_id
if camera_id in self.tracked_vehicles:
for vehicle in self.tracked_vehicles[camera_id].values():
if (vehicle.is_stable and not vehicle.is_expired(self.timeout_seconds) and
(display_id is None or vehicle.display_id == display_id)):
stable.append(vehicle)
return stable
def get_vehicle_by_session(self, session_id: str) -> Optional[TrackedVehicle]:
"""
Get a tracked vehicle by its session ID.
Args:
session_id: Session ID to look up
Returns:
Tracked vehicle if found, None otherwise
"""
with self.lock:
# Search across all cameras
for camera_vehicles in self.tracked_vehicles.values():
for vehicle in camera_vehicles.values():
if vehicle.session_id == session_id:
return vehicle
return None
def mark_processed(self, track_id: int, session_id: str):
"""
Mark a vehicle as processed through the pipeline.
Args:
track_id: Track ID of the vehicle
session_id: Session ID assigned to this vehicle
"""
with self.lock:
# Search across all cameras for the track_id
for camera_vehicles in self.tracked_vehicles.values():
if track_id in camera_vehicles:
vehicle = camera_vehicles[track_id]
vehicle.processed_pipeline = True
vehicle.session_id = session_id
logger.info(f"Marked vehicle {track_id} as processed with session {session_id}")
return
def clear_session(self, session_id: str):
"""
Clear session ID from a tracked vehicle (post-fueling).
Args:
session_id: Session ID to clear
"""
with self.lock:
# Search across all cameras
for camera_vehicles in self.tracked_vehicles.values():
for vehicle in camera_vehicles.values():
if vehicle.session_id == session_id:
logger.info(f"Clearing session {session_id} from vehicle {vehicle.track_id}")
vehicle.session_id = None
# Keep processed_pipeline=True to prevent re-processing
def reset_tracking(self):
"""Reset all tracking state."""
with self.lock:
self.tracked_vehicles.clear()
self.bot_sort.reset_all()
logger.info("Vehicle tracking state reset")
def get_statistics(self) -> Dict:
"""Get tracking statistics."""
with self.lock:
total = 0
stable = 0
processed = 0
all_confidences = []
# Aggregate stats across all cameras
for camera_vehicles in self.tracked_vehicles.values():
total += len(camera_vehicles)
for vehicle in camera_vehicles.values():
if vehicle.is_stable:
stable += 1
if vehicle.processed_pipeline:
processed += 1
all_confidences.append(vehicle.avg_confidence)
return {
'total_tracked': total,
'stable_vehicles': stable,
'processed_vehicles': processed,
'avg_confidence': np.mean(all_confidences) if all_confidences else 0.0,
'bot_sort_stats': self.bot_sort.get_statistics()
}
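For reference, the stability score this module computed is simply an inverse measure of positional spread over the last few centre points; a small reproduction of the formula, with example values assumed:

import numpy as np

def stability_score(history):
    """Mirror of TrackedVehicle.calculate_stability: 1.0 when the centre barely
    moves, falling to 0.0 once std_x + std_y reaches about 100 pixels."""
    if len(history) < 2:
        return 0.0
    positions = np.array(history)
    std_x = np.std(positions[:, 0])
    std_y = np.std(positions[:, 1])
    return max(0.0, 1.0 - (std_x + std_y) / 100)

print(stability_score([(640, 360), (642, 361), (641, 359)]))  # ~0.98, effectively parked
print(stability_score([(100, 360), (300, 362), (500, 361)]))  # 0.0, clearly moving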


@ -1,407 +0,0 @@
"""
Vehicle Validation Module - Stable car detection and validation logic.
Differentiates between stable (fueling) cars and passing-by vehicles.
"""
import logging
import time
import numpy as np
from typing import List, Optional, Tuple, Dict, Any
from dataclasses import dataclass
from enum import Enum
from .tracker import TrackedVehicle
logger = logging.getLogger(__name__)
class VehicleState(Enum):
"""Vehicle state classification."""
UNKNOWN = "unknown"
ENTERING = "entering"
STABLE = "stable"
LEAVING = "leaving"
PASSING_BY = "passing_by"
@dataclass
class ValidationResult:
"""Result of vehicle validation."""
is_valid: bool
state: VehicleState
confidence: float
reason: str
should_process: bool = False
track_id: Optional[int] = None
class StableCarValidator:
"""
Validates whether a tracked vehicle should be processed through the pipeline.
Updated for BoT-SORT integration: Trusts the sophisticated BoT-SORT tracking algorithm
for stability determination and focuses on business logic validation:
- Duration requirements for processing
- Confidence thresholds
- Session management and cooldowns
- Camera isolation with composite keys
"""
def __init__(self, config: Optional[Dict] = None):
"""
Initialize the validator with configuration.
Args:
config: Optional configuration dictionary
"""
self.config = config or {}
# Validation thresholds
# Optimized for 6 FPS RTSP source with 8 concurrent cameras on GPU
# GPU contention reduces effective FPS to ~3-5 per camera
# Reduced from 3.0s to 1.5s to achieve ~2.75s total validation time (was ~4.25s)
self.min_stable_duration = self.config.get('min_stable_duration', 1.5) # seconds
# Reduced from 10 to 5 to align with tracker requirement and reduce validation time
self.min_stable_frames = self.config.get('min_stable_frames', 5)
self.position_variance_threshold = self.config.get('position_variance_threshold', 25.0) # pixels
# Reduced from 0.7 to 0.45 to be more permissive under GPU load
self.min_confidence = self.config.get('min_confidence', 0.45)
self.velocity_threshold = self.config.get('velocity_threshold', 5.0) # pixels/frame
self.entering_zone_ratio = self.config.get('entering_zone_ratio', 0.3) # 30% of frame
self.leaving_zone_ratio = self.config.get('leaving_zone_ratio', 0.3)
# Frame dimensions (will be updated on first frame)
self.frame_width = 1920
self.frame_height = 1080
# History for validation
self.validation_history: Dict[int, List[VehicleState]] = {}
self.last_processed_vehicles: Dict[int, float] = {} # track_id -> last_process_time
logger.info(f"StableCarValidator initialized with min_duration={self.min_stable_duration}s, "
f"min_frames={self.min_stable_frames}, position_variance={self.position_variance_threshold}")
def update_frame_dimensions(self, width: int, height: int):
"""Update frame dimensions for zone calculations."""
self.frame_width = width
self.frame_height = height
# Commented out verbose frame dimension logging
# logger.debug(f"Updated frame dimensions: {width}x{height}")
def validate_vehicle(self, vehicle: TrackedVehicle, frame_shape: Optional[Tuple] = None) -> ValidationResult:
"""
Validate whether a tracked vehicle is stable and should be processed.
Args:
vehicle: The tracked vehicle to validate
frame_shape: Optional frame shape (height, width, channels)
Returns:
ValidationResult with validation status and reasoning
"""
# Update frame dimensions if provided
if frame_shape:
self.update_frame_dimensions(frame_shape[1], frame_shape[0])
# Initialize validation history for new vehicles
if vehicle.track_id not in self.validation_history:
self.validation_history[vehicle.track_id] = []
# Check if already processed
if vehicle.processed_pipeline:
return ValidationResult(
is_valid=False,
state=VehicleState.STABLE,
confidence=1.0,
reason="Already processed through pipeline",
should_process=False,
track_id=vehicle.track_id
)
# Check if recently processed (cooldown period)
if vehicle.track_id in self.last_processed_vehicles:
time_since_process = time.time() - self.last_processed_vehicles[vehicle.track_id]
if time_since_process < 10.0: # 10 second cooldown
return ValidationResult(
is_valid=False,
state=VehicleState.STABLE,
confidence=1.0,
reason=f"Recently processed ({time_since_process:.1f}s ago)",
should_process=False,
track_id=vehicle.track_id
)
# Determine vehicle state
state = self._determine_vehicle_state(vehicle)
# Update history
self.validation_history[vehicle.track_id].append(state)
if len(self.validation_history[vehicle.track_id]) > 20:
self.validation_history[vehicle.track_id].pop(0)
# Validate based on state
if state == VehicleState.STABLE:
return self._validate_stable_vehicle(vehicle)
elif state == VehicleState.PASSING_BY:
return ValidationResult(
is_valid=False,
state=state,
confidence=0.8,
reason="Vehicle is passing by",
should_process=False,
track_id=vehicle.track_id
)
elif state == VehicleState.ENTERING:
return ValidationResult(
is_valid=False,
state=state,
confidence=0.5,
reason="Vehicle is entering, waiting for stability",
should_process=False,
track_id=vehicle.track_id
)
elif state == VehicleState.LEAVING:
return ValidationResult(
is_valid=False,
state=state,
confidence=0.5,
reason="Vehicle is leaving",
should_process=False,
track_id=vehicle.track_id
)
else:
return ValidationResult(
is_valid=False,
state=state,
confidence=0.0,
reason="Unknown vehicle state",
should_process=False,
track_id=vehicle.track_id
)
def _determine_vehicle_state(self, vehicle: TrackedVehicle) -> VehicleState:
"""
Determine the current state of the vehicle based on BoT-SORT tracking results.
BoT-SORT provides sophisticated tracking, so we trust its stability determination
and focus on business logic validation.
Args:
vehicle: The tracked vehicle
Returns:
Current vehicle state
"""
# Trust BoT-SORT's stability determination
if vehicle.is_stable:
# Check if it's been stable long enough for processing
duration = time.time() - vehicle.first_seen
if duration >= self.min_stable_duration:
return VehicleState.STABLE
else:
return VehicleState.ENTERING
# For non-stable vehicles, use simplified state determination
if len(vehicle.last_position_history) < 2:
return VehicleState.UNKNOWN
# Calculate velocity for movement classification
velocity = self._calculate_velocity(vehicle)
# Basic movement classification
if velocity > self.velocity_threshold:
# Vehicle is moving - classify as passing by or entering/leaving
x_position = vehicle.center[0] / self.frame_width
# Simple heuristic: vehicles near edges are entering/leaving, center vehicles are passing
if x_position < 0.2 or x_position > 0.8:
return VehicleState.ENTERING
else:
return VehicleState.PASSING_BY
# Low velocity but not marked stable by tracker - likely entering
return VehicleState.ENTERING
def _validate_stable_vehicle(self, vehicle: TrackedVehicle) -> ValidationResult:
"""
Perform business logic validation of a stable vehicle.
Since BoT-SORT already determined the vehicle is stable, we focus on:
- Duration requirements for processing
- Confidence thresholds
- Business logic constraints
Args:
vehicle: The stable vehicle to validate
Returns:
Detailed validation result
"""
# Check duration (business requirement)
duration = time.time() - vehicle.first_seen
if duration < self.min_stable_duration:
return ValidationResult(
is_valid=False,
state=VehicleState.STABLE,
confidence=0.6,
reason=f"Not stable long enough ({duration:.1f}s < {self.min_stable_duration}s)",
should_process=False,
track_id=vehicle.track_id
)
# Check confidence (business requirement)
if vehicle.avg_confidence < self.min_confidence:
return ValidationResult(
is_valid=False,
state=VehicleState.STABLE,
confidence=vehicle.avg_confidence,
reason=f"Confidence too low ({vehicle.avg_confidence:.2f} < {self.min_confidence})",
should_process=False,
track_id=vehicle.track_id
)
# Trust BoT-SORT's stability determination - skip position variance check
# BoT-SORT's sophisticated tracking already ensures consistent positioning
# Simplified state history check - just ensure recent stability
if vehicle.track_id in self.validation_history:
history = self.validation_history[vehicle.track_id][-3:] # Last 3 states
stable_count = sum(1 for s in history if s == VehicleState.STABLE)
if len(history) >= 2 and stable_count == 0: # Only fail if clear instability
return ValidationResult(
is_valid=False,
state=VehicleState.STABLE,
confidence=0.7,
reason="Recent state history shows instability",
should_process=False,
track_id=vehicle.track_id
)
# All checks passed - vehicle is valid for processing
self.last_processed_vehicles[vehicle.track_id] = time.time()
return ValidationResult(
is_valid=True,
state=VehicleState.STABLE,
confidence=vehicle.avg_confidence,
reason="Vehicle is stable and ready for processing (BoT-SORT validated)",
should_process=True,
track_id=vehicle.track_id
)
def _calculate_velocity(self, vehicle: TrackedVehicle) -> float:
"""
Calculate the velocity of the vehicle based on position history.
Args:
vehicle: The tracked vehicle
Returns:
Velocity in pixels per frame
"""
if len(vehicle.last_position_history) < 2:
return 0.0
positions = np.array(vehicle.last_position_history)
if len(positions) < 2:
return 0.0
# Calculate velocity over last 3 frames
recent_positions = positions[-min(3, len(positions)):]
velocities = []
for i in range(1, len(recent_positions)):
dx = recent_positions[i][0] - recent_positions[i-1][0]
dy = recent_positions[i][1] - recent_positions[i-1][1]
velocity = np.sqrt(dx**2 + dy**2)
velocities.append(velocity)
return np.mean(velocities) if velocities else 0.0
def _calculate_position_variance(self, vehicle: TrackedVehicle) -> float:
"""
Calculate the position variance of the vehicle.
Args:
vehicle: The tracked vehicle
Returns:
Position variance in pixels
"""
if len(vehicle.last_position_history) < 2:
return 0.0
positions = np.array(vehicle.last_position_history)
variance_x = np.var(positions[:, 0])
variance_y = np.var(positions[:, 1])
return np.sqrt(variance_x + variance_y)
def should_skip_same_car(self,
vehicle: TrackedVehicle,
session_cleared: bool = False,
permanently_processed: Dict[str, float] = None) -> bool:
"""
Determine if we should skip processing for the same car after session clear.
Args:
vehicle: The tracked vehicle
session_cleared: Whether the session was recently cleared
permanently_processed: Dict of permanently processed vehicles (camera_id:track_id -> time)
Returns:
True if we should skip this vehicle
"""
# Check if this vehicle was permanently processed (never process again)
if permanently_processed:
# Create composite key using camera_id and track_id
permanent_key = f"{vehicle.camera_id}:{vehicle.track_id}"
if permanent_key in permanently_processed:
process_time = permanently_processed[permanent_key]
time_since = time.time() - process_time
logger.debug(f"Skipping permanently processed vehicle {vehicle.track_id} on camera {vehicle.camera_id} "
f"(processed {time_since:.1f}s ago)")
return True
# If vehicle has a session_id but it was cleared, skip for a period
if vehicle.session_id is None and vehicle.processed_pipeline and session_cleared:
# Check if enough time has passed since processing
if vehicle.track_id in self.last_processed_vehicles:
time_since = time.time() - self.last_processed_vehicles[vehicle.track_id]
if time_since < 30.0: # 30 second cooldown after session clear
logger.debug(f"Skipping same car {vehicle.track_id} after session clear "
f"({time_since:.1f}s since processing)")
return True
return False
def reset_vehicle(self, track_id: int):
"""
Reset validation state for a specific vehicle.
Args:
track_id: Track ID of the vehicle to reset
"""
if track_id in self.validation_history:
del self.validation_history[track_id]
if track_id in self.last_processed_vehicles:
del self.last_processed_vehicles[track_id]
logger.debug(f"Reset validation state for vehicle {track_id}")
def get_statistics(self) -> Dict:
"""Get validation statistics."""
return {
'vehicles_in_history': len(self.validation_history),
'recently_processed': len(self.last_processed_vehicles),
'state_distribution': self._get_state_distribution()
}
def _get_state_distribution(self) -> Dict[str, int]:
"""Get distribution of current vehicle states."""
distribution = {state.value: 0 for state in VehicleState}
for history in self.validation_history.values():
if history:
current_state = history[-1]
distribution[current_state.value] += 1
return distribution


@ -1,214 +0,0 @@
"""
FFmpeg hardware acceleration detection and configuration
"""
import subprocess
import logging
import re
from typing import Dict, List, Optional
logger = logging.getLogger("detector_worker")
class FFmpegCapabilities:
"""Detect and configure FFmpeg hardware acceleration capabilities."""
def __init__(self):
"""Initialize FFmpeg capabilities detector."""
self.hwaccels = []
self.codecs = {}
self.nvidia_support = False
self.vaapi_support = False
self.qsv_support = False
self._detect_capabilities()
def _detect_capabilities(self):
"""Detect available hardware acceleration methods."""
try:
# Get hardware accelerators
result = subprocess.run(
['ffmpeg', '-hide_banner', '-hwaccels'],
capture_output=True, text=True, timeout=10
)
if result.returncode == 0:
self.hwaccels = [line.strip() for line in result.stdout.strip().split('\n')[1:] if line.strip()]
logger.info(f"Available FFmpeg hardware accelerators: {', '.join(self.hwaccels)}")
# Check for NVIDIA support
self.nvidia_support = any(hw in self.hwaccels for hw in ['cuda', 'cuvid', 'nvdec'])
self.vaapi_support = 'vaapi' in self.hwaccels
self.qsv_support = 'qsv' in self.hwaccels
# Get decoder information
self._detect_decoders()
# Log capabilities
if self.nvidia_support:
logger.info("NVIDIA hardware acceleration available (CUDA/CUVID/NVDEC)")
logger.info(f"Detected hardware codecs: {self.codecs}")
if self.vaapi_support:
logger.info("VAAPI hardware acceleration available")
if self.qsv_support:
logger.info("Intel QuickSync hardware acceleration available")
except Exception as e:
logger.warning(f"Failed to detect FFmpeg capabilities: {e}")
def _detect_decoders(self):
"""Detect available hardware decoders."""
try:
result = subprocess.run(
['ffmpeg', '-hide_banner', '-decoders'],
capture_output=True, text=True, timeout=10
)
if result.returncode == 0:
# Parse decoder output to find hardware decoders
for line in result.stdout.split('\n'):
if 'cuvid' in line or 'nvdec' in line:
match = re.search(r'(\w+)\s+.*?(\w+(?:_cuvid|_nvdec))', line)
if match:
codec_type, decoder = match.groups()
if 'h264' in decoder:
self.codecs['h264_hw'] = decoder
elif 'hevc' in decoder or 'h265' in decoder:
self.codecs['h265_hw'] = decoder
elif 'vaapi' in line:
match = re.search(r'(\w+)\s+.*?(\w+_vaapi)', line)
if match:
codec_type, decoder = match.groups()
if 'h264' in decoder:
self.codecs['h264_vaapi'] = decoder
except Exception as e:
logger.debug(f"Failed to detect decoders: {e}")
def get_optimal_capture_options(self, codec: str = 'h264') -> Dict[str, str]:
"""
Get optimal FFmpeg capture options for the given codec.
Args:
codec: Video codec (h264, h265, etc.)
Returns:
Dictionary of FFmpeg options
"""
options = {
'rtsp_transport': 'tcp',
'buffer_size': '1024k',
'max_delay': '500000', # 500ms
'fflags': '+genpts',
'flags': '+low_delay',
'probesize': '32',
'analyzeduration': '0'
}
# Add hardware acceleration if available
if self.nvidia_support:
# Force enable CUDA hardware acceleration for H.264 if CUDA is available
if codec == 'h264':
options.update({
'hwaccel': 'cuda',
'hwaccel_device': '0'
})
logger.info("Using NVIDIA NVDEC hardware acceleration for H.264")
elif codec == 'h265':
options.update({
'hwaccel': 'cuda',
'hwaccel_device': '0',
'video_codec': 'hevc_cuvid',
'hwaccel_output_format': 'cuda'
})
logger.info("Using NVIDIA CUVID hardware acceleration for H.265")
elif self.vaapi_support:
if codec == 'h264':
options.update({
'hwaccel': 'vaapi',
'hwaccel_device': '/dev/dri/renderD128',
'video_codec': 'h264_vaapi'
})
logger.debug("Using VAAPI hardware acceleration")
return options
def format_opencv_options(self, options: Dict[str, str]) -> str:
"""
Format options for OpenCV FFmpeg backend.
Args:
options: Dictionary of FFmpeg options
Returns:
Formatted options string for OpenCV
"""
return '|'.join(f"{key};{value}" for key, value in options.items())
def get_hardware_encoder_options(self, codec: str = 'h264', quality: str = 'fast') -> Dict[str, str]:
"""
Get optimal hardware encoding options.
Args:
codec: Video codec for encoding
quality: Quality preset (fast, medium, slow)
Returns:
Dictionary of encoding options
"""
options = {}
if self.nvidia_support:
if codec == 'h264':
options.update({
'video_codec': 'h264_nvenc',
'preset': quality,
'tune': 'zerolatency',
'gpu': '0',
'rc': 'cbr_hq',
'surfaces': '64'
})
elif codec == 'h265':
options.update({
'video_codec': 'hevc_nvenc',
'preset': quality,
'tune': 'zerolatency',
'gpu': '0'
})
elif self.vaapi_support:
if codec == 'h264':
options.update({
'video_codec': 'h264_vaapi',
'vaapi_device': '/dev/dri/renderD128'
})
return options
# Global instance
_ffmpeg_caps = None
def get_ffmpeg_capabilities() -> FFmpegCapabilities:
"""Get or create the global FFmpeg capabilities instance."""
global _ffmpeg_caps
if _ffmpeg_caps is None:
_ffmpeg_caps = FFmpegCapabilities()
return _ffmpeg_caps
def get_optimal_rtsp_options(rtsp_url: str) -> str:
"""
Get optimal OpenCV FFmpeg options for RTSP streaming.
Args:
rtsp_url: RTSP stream URL
Returns:
Formatted options string for cv2.VideoCapture
"""
caps = get_ffmpeg_capabilities()
# Detect codec from URL or assume H.264
codec = 'h265' if any(x in rtsp_url.lower() for x in ['h265', 'hevc']) else 'h264'
options = caps.get_optimal_capture_options(codec)
return caps.format_opencv_options(options)
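The 'key;value|key;value' string built by format_opencv_options matches the format OpenCV's FFmpeg backend reads from the OPENCV_FFMPEG_CAPTURE_OPTIONS environment variable, so a caller would wire it up roughly as below (module name and RTSP URL are placeholders, and the variable must be set before the capture is created):

import os
import cv2
from ffmpeg_utils import get_optimal_rtsp_options  # this module; name assumed

rtsp_url = "rtsp://user:pass@camera.example/stream1"  # placeholder URL
os.environ["OPENCV_FFMPEG_CAPTURE_OPTIONS"] = get_optimal_rtsp_options(rtsp_url)

cap = cv2.VideoCapture(rtsp_url, cv2.CAP_FFMPEG)
ok, frame = cap.read()
print("opened:", cap.isOpened(), "got frame:", ok)
cap.release()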


@ -1,173 +0,0 @@
"""
Hardware-accelerated image encoding using NVIDIA NVENC or Intel QuickSync
"""
import cv2
import numpy as np
import logging
from typing import Optional, Tuple
import os
logger = logging.getLogger("detector_worker")
class HardwareEncoder:
"""Hardware-accelerated JPEG encoder using GPU."""
def __init__(self):
"""Initialize hardware encoder."""
self.nvenc_available = False
self.vaapi_available = False
self.turbojpeg_available = False
# Check for TurboJPEG (fastest CPU-based option)
try:
from turbojpeg import TurboJPEG
self.turbojpeg = TurboJPEG()
self.turbojpeg_available = True
logger.info("TurboJPEG accelerated encoding available")
except ImportError:
logger.debug("TurboJPEG not available")
# Check for NVIDIA NVENC support
try:
# Test if we can create an NVENC encoder
test_frame = np.zeros((720, 1280, 3), dtype=np.uint8)
fourcc = cv2.VideoWriter_fourcc(*'H264')
test_writer = cv2.VideoWriter(
"test.mp4",
fourcc,
30,
(1280, 720),
[cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY]
)
if test_writer.isOpened():
self.nvenc_available = True
logger.info("NVENC hardware encoding available")
test_writer.release()
if os.path.exists("test.mp4"):
os.remove("test.mp4")
except Exception as e:
logger.debug(f"NVENC not available: {e}")
def encode_jpeg(self, frame: np.ndarray, quality: int = 85) -> Optional[bytes]:
"""
Encode frame to JPEG using the fastest available method.
Args:
frame: BGR image frame
quality: JPEG quality (1-100)
Returns:
Encoded JPEG bytes or None on failure
"""
try:
# Method 1: TurboJPEG (3-5x faster than cv2.imencode)
if self.turbojpeg_available:
# Convert BGR to RGB for TurboJPEG
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
encoded = self.turbojpeg.encode(rgb_frame, quality=quality)
return encoded
# Method 2: Hardware-accelerated encoding via GStreamer (if available)
if self.nvenc_available:
return self._encode_with_nvenc(frame, quality)
# Fallback: Standard OpenCV encoding
encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
success, encoded = cv2.imencode('.jpg', frame, encode_params)
if success:
return encoded.tobytes()
return None
except Exception as e:
logger.error(f"Failed to encode frame: {e}")
return None
def _encode_with_nvenc(self, frame: np.ndarray, quality: int) -> Optional[bytes]:
"""
Encode using NVIDIA NVENC hardware encoder.
This is complex to implement directly, so we'll use a GStreamer pipeline
if available.
"""
try:
# Create a GStreamer pipeline for hardware encoding
height, width = frame.shape[:2]
gst_pipeline = (
f"appsrc ! "
f"video/x-raw,format=BGR,width={width},height={height},framerate=30/1 ! "
f"videoconvert ! "
f"nvvideoconvert ! " # GPU color conversion
f"nvjpegenc quality={quality} ! " # Hardware JPEG encoder
f"appsink"
)
# This would require GStreamer Python bindings
# For now, fall back to TurboJPEG or standard encoding
logger.debug("NVENC JPEG encoding not fully implemented, using fallback")
encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
success, encoded = cv2.imencode('.jpg', frame, encode_params)
if success:
return encoded.tobytes()
return None
except Exception as e:
logger.error(f"NVENC encoding failed: {e}")
return None
def encode_batch(self, frames: list, quality: int = 85) -> list:
"""
Batch encode multiple frames for better GPU utilization.
Args:
frames: List of BGR frames
quality: JPEG quality
Returns:
List of encoded JPEG bytes
"""
encoded_frames = []
if self.turbojpeg_available:
# TurboJPEG can handle batch encoding efficiently
for frame in frames:
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
encoded = self.turbojpeg.encode(rgb_frame, quality=quality)
encoded_frames.append(encoded)
else:
# Fallback to sequential encoding
for frame in frames:
encoded = self.encode_jpeg(frame, quality)
encoded_frames.append(encoded)
return encoded_frames
# Global encoder instance
_hardware_encoder = None
def get_hardware_encoder() -> HardwareEncoder:
"""Get or create the global hardware encoder instance."""
global _hardware_encoder
if _hardware_encoder is None:
_hardware_encoder = HardwareEncoder()
return _hardware_encoder
def encode_frame_hardware(frame: np.ndarray, quality: int = 85) -> Optional[bytes]:
"""
Convenience function to encode a frame using hardware acceleration.
Args:
frame: BGR image frame
quality: JPEG quality (1-100)
Returns:
Encoded JPEG bytes or None on failure
"""
encoder = get_hardware_encoder()
return encoder.encode_jpeg(frame, quality)
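Typical call site for the encoder above, assuming the file is importable as hw_encoder (name assumed) and a BGR frame from OpenCV:

import cv2
from hw_encoder import encode_frame_hardware  # module shown above; name assumed

frame = cv2.imread("sample.jpg")  # any BGR image; path is a placeholder
if frame is not None:
    jpeg_bytes = encode_frame_hardware(frame, quality=85)
    print("encoded", len(jpeg_bytes) if jpeg_bytes else 0, "bytes")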

4
debug/cuda.py Normal file

@ -0,0 +1,4 @@
import torch
print(torch.cuda.is_available()) # True if CUDA is available
print(torch.cuda.get_device_name(0)) # GPU name
print(torch.version.cuda) # CUDA version PyTorch was compiled with

1
feeder/note.txt Normal file

@ -0,0 +1 @@
python simple_track.py --source video/sample.mp4 --show-vid --save-vid --enable-json-log


21
feeder/sender/base.py Normal file

@ -0,0 +1,21 @@
import numpy as np
import json
class NumpyArrayEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
else:
return super(NumpyArrayEncoder, self).default(obj)
class BasSender:
def __init__(self) -> None:
pass
def send(self, messages):
raise NotImplementedError()
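The custom encoder exists so tracking payloads that still contain numpy scalars and arrays can be dumped straight to JSON, for example:

import json
import numpy as np
from sender.base import NumpyArrayEncoder

payload = {"bbox": np.array([10.5, 20.0, 110.5, 220.0]),
           "id": np.int64(3),
           "conf": np.float32(0.91)}
print(json.dumps(payload, cls=NumpyArrayEncoder))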


@ -0,0 +1,13 @@
from .base import BasSender
from loguru import logger
import json
from .base import NumpyArrayEncoder
class JsonLogger(BasSender):
def __init__(self, log_filename:str = "tracking.log") -> None:
super().__init__()
self.logger = logger
self.logger.add(log_filename, format="{message}", level="INFO")
def send(self, messages):
self.logger.info(json.dumps(messages, cls=NumpyArrayEncoder))

14
feeder/sender/szmq.py Normal file

@ -0,0 +1,14 @@
from .base import BasSender, NumpyArrayEncoder
import zmq
import json
class ZmqLogger(BasSender):
def __init__(self, ip_addr:str = "localhost", port:int = 5555) -> None:
super().__init__()
self.context = zmq.Context()
self.producer = self.context.socket(zmq.PUB)
self.producer.connect(f"tcp://{ip_addr}:{port}")
def send(self, messages):
self.producer.send_string(json.dumps(messages, cls = NumpyArrayEncoder))
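ZmqLogger publishes on a PUB socket and connect()s it, so the consuming side needs a SUB socket that binds the endpoint. A minimal receiver sketch against the default localhost:5555; the printed fields match the track_data dict emitted by simple_track.py below:

import json
import zmq

context = zmq.Context()
consumer = context.socket(zmq.SUB)
consumer.bind("tcp://*:5555")                  # publisher connects, so the consumer binds
consumer.setsockopt_string(zmq.SUBSCRIBE, "")  # no topic filtering

while True:
    track = json.loads(consumer.recv_string())
    print(track["id"], track["class_name"], track["conf"])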

245
feeder/simple_track.py Normal file

@ -0,0 +1,245 @@
import argparse
import cv2
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
import sys
import numpy as np
from pathlib import Path
import torch
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]
WEIGHTS = ROOT / 'weights'
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT))
if str(ROOT / 'trackers' / 'strongsort') not in sys.path:
sys.path.append(str(ROOT / 'trackers' / 'strongsort'))
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
from ultralytics.yolo.data.utils import VID_FORMATS
from ultralytics.yolo.utils import LOGGER, colorstr
from ultralytics.yolo.utils.checks import check_file, check_imgsz
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device
from ultralytics.yolo.utils.ops import Profile, non_max_suppression, scale_boxes
from ultralytics.yolo.utils.plotting import Annotator, colors
from trackers.multi_tracker_zoo import create_tracker
from sender.jsonlogger import JsonLogger
from sender.szmq import ZmqLogger
@torch.no_grad()
def run(
source='0',
yolo_weights=WEIGHTS / 'yolov8n.pt',
reid_weights=WEIGHTS / 'osnet_x0_25_msmt17.pt',
imgsz=(640, 640),
conf_thres=0.7,
iou_thres=0.45,
max_det=1000,
device='',
show_vid=True,
save_vid=True,
project=ROOT / 'runs' / 'track',
name='exp',
exist_ok=False,
line_thickness=2,
hide_labels=False,
hide_conf=False,
half=False,
vid_stride=1,
enable_json_log=False,
enable_zmq=False,
zmq_ip='localhost',
zmq_port=5555,
):
source = str(source)
is_file = Path(source).suffix[1:] in (VID_FORMATS)
if is_file:
source = check_file(source)
device = select_device(device)
model = AutoBackend(yolo_weights, device=device, dnn=False, fp16=half)
stride, names, pt = model.stride, model.names, model.pt
imgsz = check_imgsz(imgsz, stride=stride)
dataset = LoadImages(
source,
imgsz=imgsz,
stride=stride,
auto=pt,
transforms=getattr(model.model, 'transforms', None),
vid_stride=vid_stride
)
bs = len(dataset)
tracking_config = ROOT / 'trackers' / 'strongsort' / 'configs' / 'strongsort.yaml'
tracker = create_tracker('strongsort', tracking_config, reid_weights, device, half)
save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)
(save_dir / 'tracks').mkdir(parents=True, exist_ok=True)
# Initialize loggers
json_logger = JsonLogger(f"{source}-strongsort.log") if enable_json_log else None
zmq_logger = ZmqLogger(zmq_ip, zmq_port) if enable_zmq else None
vid_path, vid_writer = [None] * bs, [None] * bs
dt = (Profile(), Profile(), Profile())
for frame_idx, (path, im, im0s, vid_cap, s) in enumerate(dataset):
with dt[0]:
im = torch.from_numpy(im).to(model.device)
im = im.half() if model.fp16 else im.float()
im /= 255.0
if len(im.shape) == 3:
im = im[None]
with dt[1]:
pred = model(im, augment=False, visualize=False)
with dt[2]:
pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=max_det)
for i, det in enumerate(pred):
seen = 0
p, im0, _ = path, im0s.copy(), dataset.count
p = Path(p)
annotator = Annotator(im0, line_width=line_thickness, example=str(names))
if len(det):
# Filter detections for 'car' class only (class 2 in COCO dataset)
car_mask = det[:, 5] == 2 # car class index is 2
det = det[car_mask]
if len(det):
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
for *xyxy, conf, cls in reversed(det):
c = int(cls)
id = f'{c}'
label = None if hide_labels else (f'{id} {names[c]}' if hide_conf else f'{id} {names[c]} {conf:.2f}')
annotator.box_label(xyxy, label, color=colors(c, True))
t_outputs = tracker.update(det.cpu(), im0)
if len(t_outputs) > 0:
for j, (output) in enumerate(t_outputs):
bbox = output[0:4]
id = output[4]
cls = output[5]
conf = output[6]
# Log tracking data
if json_logger or zmq_logger:
track_data = {
'bbox': bbox.tolist() if hasattr(bbox, 'tolist') else list(bbox),
'id': int(id),
'cls': int(cls),
'conf': float(conf),
'frame_idx': frame_idx,
'source': source,
'class_name': names[int(cls)]
}
if json_logger:
json_logger.send(track_data)
if zmq_logger:
zmq_logger.send(track_data)
if save_vid or show_vid:
c = int(cls)
id = int(id)
label = f'{id} {names[c]}' if not hide_labels else f'{id}'
if not hide_conf:
label += f' {conf:.2f}'
annotator.box_label(bbox, label, color=colors(c, True))
im0 = annotator.result()
if show_vid:
cv2.imshow(str(p), im0)
if cv2.waitKey(1) == ord('q'):
break
if save_vid:
if vid_path[i] != str(save_dir / p.name):
vid_path[i] = str(save_dir / p.name)
if isinstance(vid_writer[i], cv2.VideoWriter):
vid_writer[i].release()
if vid_cap:
fps = vid_cap.get(cv2.CAP_PROP_FPS)
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
else:
fps, w, h = 30, im0.shape[1], im0.shape[0]
vid_writer[i] = cv2.VideoWriter(vid_path[i], cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
vid_writer[i].write(im0)
LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
for i, vid_writer_obj in enumerate(vid_writer):
if isinstance(vid_writer_obj, cv2.VideoWriter):
vid_writer_obj.release()
cv2.destroyAllWindows()
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
def xyxy2xywh(x):
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
y[:, 2] = x[:, 2] - x[:, 0] # width
y[:, 3] = x[:, 3] - x[:, 1] # height
return y
def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--source', type=str, default='0', help='file/dir/URL/glob, 0 for webcam')
parser.add_argument('--yolo-weights', nargs='+', type=str, default=WEIGHTS / 'yolov8n.pt', help='model path')
parser.add_argument('--reid-weights', type=str, default=WEIGHTS / 'osnet_x0_25_msmt17.pt')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
parser.add_argument('--conf-thres', type=float, default=0.7, help='confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--show-vid', action='store_true', help='display results')
parser.add_argument('--save-vid', action='store_true', help='save video tracking results')
parser.add_argument('--project', default=ROOT / 'runs' / 'track', help='save results to project/name')
parser.add_argument('--name', default='exp', help='save results to project/name')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
parser.add_argument('--line-thickness', default=2, type=int, help='bounding box thickness (pixels)')
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
parser.add_argument('--enable-json-log', action='store_true', help='enable JSON file logging')
parser.add_argument('--enable-zmq', action='store_true', help='enable ZMQ messaging')
parser.add_argument('--zmq-ip', type=str, default='localhost', help='ZMQ server IP')
parser.add_argument('--zmq-port', type=int, default=5555, help='ZMQ server port')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1
return opt
def main(opt):
run(**vars(opt))
if __name__ == "__main__":
opt = parse_opt()
main(opt)



@ -0,0 +1,60 @@
import numpy as np
from collections import OrderedDict
class TrackState(object):
New = 0
Tracked = 1
Lost = 2
LongLost = 3
Removed = 4
class BaseTrack(object):
_count = 0
track_id = 0
is_activated = False
state = TrackState.New
history = OrderedDict()
features = []
curr_feature = None
score = 0
start_frame = 0
frame_id = 0
time_since_update = 0
# multi-camera
location = (np.inf, np.inf)
@property
def end_frame(self):
return self.frame_id
@staticmethod
def next_id():
BaseTrack._count += 1
return BaseTrack._count
def activate(self, *args):
raise NotImplementedError
def predict(self):
raise NotImplementedError
def update(self, *args, **kwargs):
raise NotImplementedError
def mark_lost(self):
self.state = TrackState.Lost
def mark_long_lost(self):
self.state = TrackState.LongLost
def mark_removed(self):
self.state = TrackState.Removed
@staticmethod
def clear_count():
BaseTrack._count = 0
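BaseTrack hands out track IDs from a single class-level counter, and the BoTSORT constructor below calls clear_count(), so numbering restarts at 1 per tracker instance; a quick illustration:

from trackers.botsort.basetrack import BaseTrack

print(BaseTrack.next_id(), BaseTrack.next_id())  # 1 2
BaseTrack.clear_count()
print(BaseTrack.next_id())                       # 1 again after the reset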


@ -0,0 +1,534 @@
import cv2
import matplotlib.pyplot as plt
import numpy as np
from collections import deque
from trackers.botsort import matching
from trackers.botsort.gmc import GMC
from trackers.botsort.basetrack import BaseTrack, TrackState
from trackers.botsort.kalman_filter import KalmanFilter
# from fast_reid.fast_reid_interfece import FastReIDInterface
from reid_multibackend import ReIDDetectMultiBackend
from ultralytics.yolo.utils.ops import xyxy2xywh, xywh2xyxy
class STrack(BaseTrack):
shared_kalman = KalmanFilter()
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
# wait activate
self._tlwh = np.asarray(tlwh, dtype=np.float32)
self.kalman_filter = None
self.mean, self.covariance = None, None
self.is_activated = False
self.cls = -1
self.cls_hist = [] # (cls id, freq)
self.update_cls(cls, score)
self.score = score
self.tracklet_len = 0
self.smooth_feat = None
self.curr_feat = None
if feat is not None:
self.update_features(feat)
self.features = deque([], maxlen=feat_history)
self.alpha = 0.9
def update_features(self, feat):
feat /= np.linalg.norm(feat)
self.curr_feat = feat
if self.smooth_feat is None:
self.smooth_feat = feat
else:
self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
self.features.append(feat)
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
def update_cls(self, cls, score):
if len(self.cls_hist) > 0:
max_freq = 0
found = False
for c in self.cls_hist:
if cls == c[0]:
c[1] += score
found = True
if c[1] > max_freq:
max_freq = c[1]
self.cls = c[0]
if not found:
self.cls_hist.append([cls, score])
self.cls = cls
else:
self.cls_hist.append([cls, score])
self.cls = cls
def predict(self):
mean_state = self.mean.copy()
if self.state != TrackState.Tracked:
mean_state[6] = 0
mean_state[7] = 0
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
@staticmethod
def multi_predict(stracks):
if len(stracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks])
for i, st in enumerate(stracks):
if st.state != TrackState.Tracked:
multi_mean[i][6] = 0
multi_mean[i][7] = 0
multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
stracks[i].mean = mean
stracks[i].covariance = cov
@staticmethod
def multi_gmc(stracks, H=np.eye(2, 3)):
if len(stracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks])
R = H[:2, :2]
R8x8 = np.kron(np.eye(4, dtype=float), R)
t = H[:2, 2]
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
mean = R8x8.dot(mean)
mean[:2] += t
cov = R8x8.dot(cov).dot(R8x8.transpose())
stracks[i].mean = mean
stracks[i].covariance = cov
def activate(self, kalman_filter, frame_id):
"""Start a new tracklet"""
self.kalman_filter = kalman_filter
self.track_id = self.next_id()
self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xywh(self._tlwh))
self.tracklet_len = 0
self.state = TrackState.Tracked
if frame_id == 1:
self.is_activated = True
self.frame_id = frame_id
self.start_frame = frame_id
def re_activate(self, new_track, frame_id, new_id=False):
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, self.tlwh_to_xywh(new_track.tlwh))
if new_track.curr_feat is not None:
self.update_features(new_track.curr_feat)
self.tracklet_len = 0
self.state = TrackState.Tracked
self.is_activated = True
self.frame_id = frame_id
if new_id:
self.track_id = self.next_id()
self.score = new_track.score
self.update_cls(new_track.cls, new_track.score)
def update(self, new_track, frame_id):
"""
Update a matched track
:type new_track: STrack
:type frame_id: int
:type update_feature: bool
:return:
"""
self.frame_id = frame_id
self.tracklet_len += 1
new_tlwh = new_track.tlwh
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, self.tlwh_to_xywh(new_tlwh))
if new_track.curr_feat is not None:
self.update_features(new_track.curr_feat)
self.state = TrackState.Tracked
self.is_activated = True
self.score = new_track.score
self.update_cls(new_track.cls, new_track.score)
@property
def tlwh(self):
"""Get current position in bounding box format `(top left x, top left y,
width, height)`.
"""
if self.mean is None:
return self._tlwh.copy()
ret = self.mean[:4].copy()
ret[:2] -= ret[2:] / 2
return ret
@property
def tlbr(self):
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret = self.tlwh.copy()
ret[2:] += ret[:2]
return ret
@property
def xywh(self):
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret = self.tlwh.copy()
ret[:2] += ret[2:] / 2.0
return ret
@staticmethod
def tlwh_to_xyah(tlwh):
"""Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""
ret = np.asarray(tlwh).copy()
ret[:2] += ret[2:] / 2
ret[2] /= ret[3]
return ret
@staticmethod
def tlwh_to_xywh(tlwh):
"""Convert bounding box to format `(center x, center y, width,
height)`.
"""
ret = np.asarray(tlwh).copy()
ret[:2] += ret[2:] / 2
return ret
def to_xywh(self):
return self.tlwh_to_xywh(self.tlwh)
@staticmethod
def tlbr_to_tlwh(tlbr):
ret = np.asarray(tlbr).copy()
ret[2:] -= ret[:2]
return ret
@staticmethod
def tlwh_to_tlbr(tlwh):
ret = np.asarray(tlwh).copy()
ret[2:] += ret[:2]
return ret
def __repr__(self):
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
class BoTSORT(object):
def __init__(self,
model_weights,
device,
fp16,
track_high_thresh:float = 0.45,
new_track_thresh:float = 0.6,
track_buffer:int = 30,
match_thresh:float = 0.8,
proximity_thresh:float = 0.5,
appearance_thresh:float = 0.25,
cmc_method:str = 'sparseOptFlow',
frame_rate=30,
lambda_=0.985
):
self.tracked_stracks = [] # type: list[STrack]
self.lost_stracks = [] # type: list[STrack]
self.removed_stracks = [] # type: list[STrack]
BaseTrack.clear_count()
self.frame_id = 0
self.lambda_ = lambda_
self.track_high_thresh = track_high_thresh
self.new_track_thresh = new_track_thresh
self.buffer_size = int(frame_rate / 30.0 * track_buffer)
self.max_time_lost = self.buffer_size
self.kalman_filter = KalmanFilter()
# ReID module
self.proximity_thresh = proximity_thresh
self.appearance_thresh = appearance_thresh
self.match_thresh = match_thresh
self.model = ReIDDetectMultiBackend(weights=model_weights, device=device, fp16=fp16)
self.gmc = GMC(method=cmc_method, verbose=[None,False])
def update(self, output_results, img):
self.frame_id += 1
activated_starcks = []
refind_stracks = []
lost_stracks = []
removed_stracks = []
xyxys = output_results[:, 0:4]
xywh = xyxy2xywh(xyxys.numpy())
confs = output_results[:, 4]
clss = output_results[:, 5]
classes = clss.numpy()
xyxys = xyxys.numpy()
confs = confs.numpy()
remain_inds = confs > self.track_high_thresh
inds_low = confs > 0.1
inds_high = confs < self.track_high_thresh
inds_second = np.logical_and(inds_low, inds_high)
dets_second = xywh[inds_second]
dets = xywh[remain_inds]
scores_keep = confs[remain_inds]
scores_second = confs[inds_second]
classes_keep = classes[remain_inds]
clss_second = classes[inds_second]
self.height, self.width = img.shape[:2]
'''Extract embeddings '''
features_keep = self._get_features(dets, img)
if len(dets) > 0:
'''Detections'''
detections = [STrack(xyxy, s, c, f.cpu().numpy()) for
(xyxy, s, c, f) in zip(dets, scores_keep, classes_keep, features_keep)]
else:
detections = []
''' Add newly detected tracklets to tracked_stracks'''
unconfirmed = []
tracked_stracks = [] # type: list[STrack]
for track in self.tracked_stracks:
if not track.is_activated:
unconfirmed.append(track)
else:
tracked_stracks.append(track)
''' Step 2: First association, with high score detection boxes'''
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
# Predict the current location with KF
STrack.multi_predict(strack_pool)
# Fix camera motion
warp = self.gmc.apply(img, dets)
STrack.multi_gmc(strack_pool, warp)
STrack.multi_gmc(unconfirmed, warp)
# Associate with high score detection boxes
raw_emb_dists = matching.embedding_distance(strack_pool, detections)
dists = matching.fuse_motion(self.kalman_filter, raw_emb_dists, strack_pool, detections, only_position=False, lambda_=self.lambda_)
# ious_dists = matching.iou_distance(strack_pool, detections)
# ious_dists_mask = (ious_dists > self.proximity_thresh)
# ious_dists = matching.fuse_score(ious_dists, detections)
# emb_dists = matching.embedding_distance(strack_pool, detections) / 2.0
# raw_emb_dists = emb_dists.copy()
# emb_dists[emb_dists > self.appearance_thresh] = 1.0
# emb_dists[ious_dists_mask] = 1.0
# dists = np.minimum(ious_dists, emb_dists)
# Popular ReID method (JDE / FairMOT)
# raw_emb_dists = matching.embedding_distance(strack_pool, detections)
# dists = matching.fuse_motion(self.kalman_filter, raw_emb_dists, strack_pool, detections)
# emb_dists = dists
# IoU-masked ReID
# dists = matching.embedding_distance(strack_pool, detections)
# dists[ious_dists_mask] = 1.0
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.match_thresh)
for itracked, idet in matches:
track = strack_pool[itracked]
det = detections[idet]
if track.state == TrackState.Tracked:
track.update(detections[idet], self.frame_id)
activated_starcks.append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
''' Step 3: Second association, with low score detection boxes'''
# if len(scores):
# inds_high = scores < self.track_high_thresh
# inds_low = scores > self.track_low_thresh
# inds_second = np.logical_and(inds_low, inds_high)
# dets_second = bboxes[inds_second]
# scores_second = scores[inds_second]
# classes_second = classes[inds_second]
# else:
# dets_second = []
# scores_second = []
# classes_second = []
# association the untrack to the low score detections
if len(dets_second) > 0:
'''Detections'''
detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s, c) for
(tlbr, s, c) in zip(dets_second, scores_second, clss_second)]
else:
detections_second = []
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
dists = matching.iou_distance(r_tracked_stracks, detections_second)
matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
for itracked, idet in matches:
track = r_tracked_stracks[itracked]
det = detections_second[idet]
if track.state == TrackState.Tracked:
track.update(det, self.frame_id)
activated_starcks.append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
for it in u_track:
track = r_tracked_stracks[it]
if not track.state == TrackState.Lost:
track.mark_lost()
lost_stracks.append(track)
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
detections = [detections[i] for i in u_detection]
ious_dists = matching.iou_distance(unconfirmed, detections)
ious_dists_mask = (ious_dists > self.proximity_thresh)
ious_dists = matching.fuse_score(ious_dists, detections)
emb_dists = matching.embedding_distance(unconfirmed, detections) / 2.0
raw_emb_dists = emb_dists.copy()
emb_dists[emb_dists > self.appearance_thresh] = 1.0
emb_dists[ious_dists_mask] = 1.0
dists = np.minimum(ious_dists, emb_dists)
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
for itracked, idet in matches:
unconfirmed[itracked].update(detections[idet], self.frame_id)
activated_starcks.append(unconfirmed[itracked])
for it in u_unconfirmed:
track = unconfirmed[it]
track.mark_removed()
removed_stracks.append(track)
""" Step 4: Init new stracks"""
for inew in u_detection:
track = detections[inew]
if track.score < self.new_track_thresh:
continue
track.activate(self.kalman_filter, self.frame_id)
activated_starcks.append(track)
""" Step 5: Update state"""
for track in self.lost_stracks:
if self.frame_id - track.end_frame > self.max_time_lost:
track.mark_removed()
removed_stracks.append(track)
""" Merge """
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
self.lost_stracks.extend(lost_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
self.removed_stracks.extend(removed_stracks)
self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
# output_stracks = [track for track in self.tracked_stracks if track.is_activated]
output_stracks = [track for track in self.tracked_stracks if track.is_activated]
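# each output row built below is [x1, y1, x2, y2, track_id, cls, score]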
outputs = []
for t in output_stracks:
output= []
tlwh = t.tlwh
tid = t.track_id
tlwh = np.expand_dims(tlwh, axis=0)
xyxy = xywh2xyxy(tlwh)
xyxy = np.squeeze(xyxy, axis=0)
output.extend(xyxy)
output.append(tid)
output.append(t.cls)
output.append(t.score)
outputs.append(output)
return outputs
def _xywh_to_xyxy(self, bbox_xywh):
x, y, w, h = bbox_xywh
x1 = max(int(x - w / 2), 0)
x2 = min(int(x + w / 2), self.width - 1)
y1 = max(int(y - h / 2), 0)
y2 = min(int(y + h / 2), self.height - 1)
return x1, y1, x2, y2
def _get_features(self, bbox_xywh, ori_img):
im_crops = []
for box in bbox_xywh:
x1, y1, x2, y2 = self._xywh_to_xyxy(box)
im = ori_img[y1:y2, x1:x2]
im_crops.append(im)
if im_crops:
features = self.model(im_crops)
else:
features = np.array([])
return features
def joint_stracks(tlista, tlistb):
exists = {}
res = []
for t in tlista:
exists[t.track_id] = 1
res.append(t)
for t in tlistb:
tid = t.track_id
if not exists.get(tid, 0):
exists[tid] = 1
res.append(t)
return res
def sub_stracks(tlista, tlistb):
stracks = {}
for t in tlista:
stracks[t.track_id] = t
for t in tlistb:
tid = t.track_id
if stracks.get(tid, 0):
del stracks[tid]
return list(stracks.values())
def remove_duplicate_stracks(stracksa, stracksb):
pdist = matching.iou_distance(stracksa, stracksb)
pairs = np.where(pdist < 0.15)
dupa, dupb = list(), list()
for p, q in zip(*pairs):
timep = stracksa[p].frame_id - stracksa[p].start_frame
timeq = stracksb[q].frame_id - stracksb[q].start_frame
if timep > timeq:
dupb.append(q)
else:
dupa.append(p)
resa = [t for i, t in enumerate(stracksa) if not i in dupa]
resb = [t for i, t in enumerate(stracksb) if not i in dupb]
return resa, resb

View file

@ -0,0 +1,13 @@
# Trial number: 232
# HOTA, MOTA, IDF1: [45.31]
botsort:
appearance_thresh: 0.4818211117541298
cmc_method: sparseOptFlow
conf_thres: 0.3501265956918775
frame_rate: 30
lambda_: 0.9896143462366406
match_thresh: 0.22734550911325851
new_track_thresh: 0.21144301345190655
proximity_thresh: 0.5945380911899254
track_buffer: 60
track_high_thresh: 0.33824964456239337
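A minimal sketch of consuming a tuned config like the one above; the file path and the assumption that the tracker constructor accepts these keys as keyword arguments are illustrative, not taken from this diff.

import yaml

# hypothetical location of the tuned BoT-SORT parameters shown above
with open("trackers/botsort/configs/botsort.yaml") as f:
    cfg = yaml.safe_load(f)["botsort"]

# cfg maps the tuned keys (appearance_thresh, cmc_method, match_thresh, ...) to values
# and could be splatted into a tracker constructor that accepts them as keyword arguments
print(cfg["cmc_method"], cfg["track_buffer"])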

View file

@ -0,0 +1,316 @@
import cv2
import matplotlib.pyplot as plt
import numpy as np
import copy
import time
class GMC:
def __init__(self, method='sparseOptFlow', downscale=2, verbose=None):
super(GMC, self).__init__()
self.method = method
self.downscale = max(1, int(downscale))
if self.method == 'orb':
self.detector = cv2.FastFeatureDetector_create(20)
self.extractor = cv2.ORB_create()
self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
elif self.method == 'sift':
self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
self.matcher = cv2.BFMatcher(cv2.NORM_L2)
elif self.method == 'ecc':
number_of_iterations = 5000
termination_eps = 1e-6
self.warp_mode = cv2.MOTION_EUCLIDEAN
self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
elif self.method == 'sparseOptFlow':
self.feature_params = dict(maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3,
useHarrisDetector=False, k=0.04)
# self.gmc_file = open('GMC_results.txt', 'w')
elif self.method == 'file' or self.method == 'files':
seqName = verbose[0]
ablation = verbose[1]
if ablation:
filePath = r'tracker/GMC_files/MOT17_ablation'
else:
filePath = r'tracker/GMC_files/MOTChallenge'
if '-FRCNN' in seqName:
seqName = seqName[:-6]
elif '-DPM' in seqName:
seqName = seqName[:-4]
elif '-SDP' in seqName:
seqName = seqName[:-4]
self.gmcFile = open(filePath + "/GMC-" + seqName + ".txt", 'r')
if self.gmcFile is None:
raise ValueError("Error: Unable to open GMC file in directory:" + filePath)
elif self.method == 'none' or self.method == 'None':
self.method = 'none'
else:
raise ValueError("Error: Unknown CMC method:" + method)
self.prevFrame = None
self.prevKeyPoints = None
self.prevDescriptors = None
self.initializedFirstFrame = False
def apply(self, raw_frame, detections=None):
if self.method == 'orb' or self.method == 'sift':
return self.applyFeatures(raw_frame, detections)
elif self.method == 'ecc':
return self.applyEcc(raw_frame, detections)
elif self.method == 'sparseOptFlow':
return self.applySparseOptFlow(raw_frame, detections)
elif self.method == 'file':
return self.applyFile(raw_frame, detections)
elif self.method == 'none':
return np.eye(2, 3)
else:
return np.eye(2, 3)
def applyEcc(self, raw_frame, detections=None):
# Initialize
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3, dtype=np.float32)
# Downscale image (TODO: consider using pyramids)
if self.downscale > 1.0:
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
width = width // self.downscale
height = height // self.downscale
# Handle first frame
if not self.initializedFirstFrame:
# Initialize data
self.prevFrame = frame.copy()
# Initialization done
self.initializedFirstFrame = True
return H
# Run the ECC algorithm. The results are stored in warp_matrix.
# (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
try:
(cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
except cv2.error:
print('Warning: findTransformECC failed; using identity warp')
return H
def applyFeatures(self, raw_frame, detections=None):
# Initialize
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3)
# Downscale image (TODO: consider using pyramids)
if self.downscale > 1.0:
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
width = width // self.downscale
height = height // self.downscale
# find the keypoints
mask = np.zeros_like(frame)
# mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
mask[int(0.02 * height): int(0.98 * height), int(0.02 * width): int(0.98 * width)] = 255
if detections is not None:
for det in detections:
tlbr = (det[:4] / self.downscale).astype(np.int_)
mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
keypoints = self.detector.detect(frame, mask)
# compute the descriptors
keypoints, descriptors = self.extractor.compute(frame, keypoints)
# Handle first frame
if not self.initializedFirstFrame:
# Initialize data
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
self.prevDescriptors = copy.copy(descriptors)
# Initialization done
self.initializedFirstFrame = True
return H
# Match descriptors.
knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
# Filtered matches based on smallest spatial distance
matches = []
spatialDistances = []
maxSpatialDistance = 0.25 * np.array([width, height])
# Handle empty matches case
if len(knnMatches) == 0:
# Store to next iteration
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
self.prevDescriptors = copy.copy(descriptors)
return H
for m, n in knnMatches:
if m.distance < 0.9 * n.distance:
prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
currKeyPointLocation = keypoints[m.trainIdx].pt
spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0],
prevKeyPointLocation[1] - currKeyPointLocation[1])
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
(np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
spatialDistances.append(spatialDistance)
matches.append(m)
meanSpatialDistances = np.mean(spatialDistances, 0)
stdSpatialDistances = np.std(spatialDistances, 0)
inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances
goodMatches = []
prevPoints = []
currPoints = []
for i in range(len(matches)):
if inliers[i, 0] and inliers[i, 1]:
goodMatches.append(matches[i])
prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt)
currPoints.append(keypoints[matches[i].trainIdx].pt)
prevPoints = np.array(prevPoints)
currPoints = np.array(currPoints)
# Draw the keypoint matches on the output image
if 0:
matches_img = np.hstack((self.prevFrame, frame))
matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
W = np.size(self.prevFrame, 1)
for m in goodMatches:
prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
curr_pt[0] += W
color = np.random.randint(0, 255, (3,))
color = (int(color[0]), int(color[1]), int(color[2]))
matches_img = cv2.line(matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA)
matches_img = cv2.circle(matches_img, prev_pt, 2, tuple(color), -1)
matches_img = cv2.circle(matches_img, curr_pt, 2, tuple(color), -1)
plt.figure()
plt.imshow(matches_img)
plt.show()
# Find rigid matrix
if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(currPoints, 0)):
H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
# Handle downscale
if self.downscale > 1.0:
H[0, 2] *= self.downscale
H[1, 2] *= self.downscale
else:
print('Warning: not enough matching points')
# Store to next iteration
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
self.prevDescriptors = copy.copy(descriptors)
return H
def applySparseOptFlow(self, raw_frame, detections=None):
t0 = time.time()
# Initialize
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3)
# Downscale image
if self.downscale > 1.0:
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
# find the keypoints
keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)
# Handle first frame
if not self.initializedFirstFrame:
# Initialize data
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
# Initialization done
self.initializedFirstFrame = True
return H
# find correspondences
matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
# leave good correspondences only
prevPoints = []
currPoints = []
for i in range(len(status)):
if status[i]:
prevPoints.append(self.prevKeyPoints[i])
currPoints.append(matchedKeypoints[i])
prevPoints = np.array(prevPoints)
currPoints = np.array(currPoints)
# Find rigid matrix
if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(currPoints, 0)):
H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
# Handle downscale
if self.downscale > 1.0:
H[0, 2] *= self.downscale
H[1, 2] *= self.downscale
else:
print('Warning: not enough matching points')
# Store to next iteration
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
t1 = time.time()
# gmc_line = str(1000 * (t1 - t0)) + "\t" + str(H[0, 0]) + "\t" + str(H[0, 1]) + "\t" + str(
# H[0, 2]) + "\t" + str(H[1, 0]) + "\t" + str(H[1, 1]) + "\t" + str(H[1, 2]) + "\n"
# self.gmc_file.write(gmc_line)
return H
def applyFile(self, raw_frame, detections=None):
line = self.gmcFile.readline()
tokens = line.split("\t")
H = np.eye(2, 3, dtype=np.float64)
H[0, 0] = float(tokens[1])
H[0, 1] = float(tokens[2])
H[0, 2] = float(tokens[3])
H[1, 0] = float(tokens[4])
H[1, 1] = float(tokens[5])
H[1, 2] = float(tokens[6])
return H
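# Usage sketch (assumptions: frames come from cv2.VideoCapture and there are no
# detections to mask out); it only exercises the apply() entry point defined above
# and relies on this module's existing cv2/numpy imports.
if __name__ == "__main__":
    gmc = GMC(method='sparseOptFlow', downscale=2)
    cap = cv2.VideoCapture('demo.mp4')               # hypothetical input video
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        warp = gmc.apply(frame, np.empty((0, 4)))    # 2x3 affine compensating camera motion
        # warp is what the tracker feeds to STrack.multi_gmc(...) before association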

View file

@ -0,0 +1,269 @@
# vim: expandtab:ts=4:sw=4
import numpy as np
import scipy.linalg
"""
Table for the 0.95 quantile of the chi-square distribution with N degrees of
freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
function and used as Mahalanobis gating threshold.
"""
chi2inv95 = {
1: 3.8415,
2: 5.9915,
3: 7.8147,
4: 9.4877,
5: 11.070,
6: 12.592,
7: 14.067,
8: 15.507,
9: 16.919}
class KalmanFilter(object):
"""
A simple Kalman filter for tracking bounding boxes in image space.
The 8-dimensional state space
x, y, w, h, vx, vy, vw, vh
contains the bounding box center position (x, y), width w, height h,
and their respective velocities.
Object motion follows a constant velocity model. The bounding box location
(x, y, w, h) is taken as direct observation of the state space (linear
observation model).
"""
def __init__(self):
ndim, dt = 4, 1.
# Create Kalman filter model matrices.
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
for i in range(ndim):
self._motion_mat[i, ndim + i] = dt
self._update_mat = np.eye(ndim, 2 * ndim)
# Motion and observation uncertainty are chosen relative to the current
# state estimate. These weights control the amount of uncertainty in
# the model. This is a bit hacky.
self._std_weight_position = 1. / 20
self._std_weight_velocity = 1. / 160
def initiate(self, measurement):
"""Create track from unassociated measurement.
Parameters
----------
measurement : ndarray
Bounding box coordinates (x, y, w, h) with center position (x, y),
width w, and height h.
Returns
-------
(ndarray, ndarray)
Returns the mean vector (8 dimensional) and covariance matrix (8x8
dimensional) of the new track. Unobserved velocities are initialized
to 0 mean.
"""
mean_pos = measurement
mean_vel = np.zeros_like(mean_pos)
mean = np.r_[mean_pos, mean_vel]
std = [
2 * self._std_weight_position * measurement[2],
2 * self._std_weight_position * measurement[3],
2 * self._std_weight_position * measurement[2],
2 * self._std_weight_position * measurement[3],
10 * self._std_weight_velocity * measurement[2],
10 * self._std_weight_velocity * measurement[3],
10 * self._std_weight_velocity * measurement[2],
10 * self._std_weight_velocity * measurement[3]]
covariance = np.diag(np.square(std))
return mean, covariance
def predict(self, mean, covariance):
"""Run Kalman filter prediction step.
Parameters
----------
mean : ndarray
The 8 dimensional mean vector of the object state at the previous
time step.
covariance : ndarray
The 8x8 dimensional covariance matrix of the object state at the
previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted
state. Unobserved velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[2],
self._std_weight_position * mean[3],
self._std_weight_position * mean[2],
self._std_weight_position * mean[3]]
std_vel = [
self._std_weight_velocity * mean[2],
self._std_weight_velocity * mean[3],
self._std_weight_velocity * mean[2],
self._std_weight_velocity * mean[3]]
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
mean = np.dot(mean, self._motion_mat.T)
covariance = np.linalg.multi_dot((
self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
return mean, covariance
def project(self, mean, covariance):
"""Project state distribution to measurement space.
Parameters
----------
mean : ndarray
The state's mean vector (8 dimensional array).
covariance : ndarray
The state's covariance matrix (8x8 dimensional).
Returns
-------
(ndarray, ndarray)
Returns the projected mean and covariance matrix of the given state
estimate.
"""
std = [
self._std_weight_position * mean[2],
self._std_weight_position * mean[3],
self._std_weight_position * mean[2],
self._std_weight_position * mean[3]]
innovation_cov = np.diag(np.square(std))
mean = np.dot(self._update_mat, mean)
covariance = np.linalg.multi_dot((
self._update_mat, covariance, self._update_mat.T))
return mean, covariance + innovation_cov
def multi_predict(self, mean, covariance):
"""Run Kalman filter prediction step (Vectorized version).
Parameters
----------
mean : ndarray
The Nx8 dimensional mean matrix of the object states at the previous
time step.
covariance : ndarray
The Nx8x8 dimensional covariance matrices of the object states at the
previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted
state. Unobserved velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[:, 2],
self._std_weight_position * mean[:, 3],
self._std_weight_position * mean[:, 2],
self._std_weight_position * mean[:, 3]]
std_vel = [
self._std_weight_velocity * mean[:, 2],
self._std_weight_velocity * mean[:, 3],
self._std_weight_velocity * mean[:, 2],
self._std_weight_velocity * mean[:, 3]]
sqr = np.square(np.r_[std_pos, std_vel]).T
motion_cov = []
for i in range(len(mean)):
motion_cov.append(np.diag(sqr[i]))
motion_cov = np.asarray(motion_cov)
mean = np.dot(mean, self._motion_mat.T)
left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
covariance = np.dot(left, self._motion_mat.T) + motion_cov
return mean, covariance
def update(self, mean, covariance, measurement):
"""Run Kalman filter correction step.
Parameters
----------
mean : ndarray
The predicted state's mean vector (8 dimensional).
covariance : ndarray
The state's covariance matrix (8x8 dimensional).
measurement : ndarray
The 4 dimensional measurement vector (x, y, w, h), where (x, y)
is the center position, w the width, and h the height of the
bounding box.
Returns
-------
(ndarray, ndarray)
Returns the measurement-corrected state distribution.
"""
projected_mean, projected_cov = self.project(mean, covariance)
chol_factor, lower = scipy.linalg.cho_factor(
projected_cov, lower=True, check_finite=False)
kalman_gain = scipy.linalg.cho_solve(
(chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
check_finite=False).T
innovation = measurement - projected_mean
new_mean = mean + np.dot(innovation, kalman_gain.T)
new_covariance = covariance - np.linalg.multi_dot((
kalman_gain, projected_cov, kalman_gain.T))
return new_mean, new_covariance
def gating_distance(self, mean, covariance, measurements,
only_position=False, metric='maha'):
"""Compute gating distance between state distribution and measurements.
A suitable distance threshold can be obtained from `chi2inv95`. If
`only_position` is False, the chi-square distribution has 4 degrees of
freedom, otherwise 2.
Parameters
----------
mean : ndarray
Mean vector over the state distribution (8 dimensional).
covariance : ndarray
Covariance of the state distribution (8x8 dimensional).
measurements : ndarray
An Nx4 dimensional matrix of N measurements, each in
format (x, y, a, h) where (x, y) is the bounding box center
position, a the aspect ratio, and h the height.
only_position : Optional[bool]
If True, distance computation is done with respect to the bounding
box center position only.
Returns
-------
ndarray
Returns an array of length N, where the i-th element contains the
squared Mahalanobis distance between (mean, covariance) and
`measurements[i]`.
"""
mean, covariance = self.project(mean, covariance)
if only_position:
mean, covariance = mean[:2], covariance[:2, :2]
measurements = measurements[:, :2]
d = measurements - mean
if metric == 'gaussian':
return np.sum(d * d, axis=1)
elif metric == 'maha':
cholesky_factor = np.linalg.cholesky(covariance)
z = scipy.linalg.solve_triangular(
cholesky_factor, d.T, lower=True, check_finite=False,
overwrite_b=True)
squared_maha = np.sum(z * z, axis=0)
return squared_maha
else:
raise ValueError('invalid distance metric')
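# Worked example of the predict/update/gating cycle above (numbers illustrative;
# measurements are in the (center x, center y, w, h) form the docstrings describe).
if __name__ == "__main__":
    kf = KalmanFilter()
    mean, cov = kf.initiate(np.array([50., 40., 20., 60.]))            # first box
    mean, cov = kf.predict(mean, cov)                                   # propagate one frame
    mean, cov = kf.update(mean, cov, np.array([52., 41., 20., 61.]))    # correct with a new box
    # gate candidate boxes with the chi-square table above: 4 DOF -> threshold 9.4877
    cands = np.array([[53., 42., 20., 60.], [200., 200., 30., 80.]])
    print(kf.gating_distance(mean, cov, cands) <= chi2inv95[4])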

View file

@ -0,0 +1,234 @@
import numpy as np
import scipy
import scipy.sparse
import lap
from scipy.spatial.distance import cdist
from trackers.botsort import kalman_filter
def merge_matches(m1, m2, shape):
O,P,Q = shape
m1 = np.asarray(m1)
m2 = np.asarray(m2)
M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))
mask = M1*M2
match = mask.nonzero()
match = list(zip(match[0], match[1]))
unmatched_O = tuple(set(range(O)) - set([i for i, j in match]))
unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match]))
return match, unmatched_O, unmatched_Q
def _indices_to_matches(cost_matrix, indices, thresh):
matched_cost = cost_matrix[tuple(zip(*indices))]
matched_mask = (matched_cost <= thresh)
matches = indices[matched_mask]
unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
return matches, unmatched_a, unmatched_b
def linear_assignment(cost_matrix, thresh):
if cost_matrix.size == 0:
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
matches, unmatched_a, unmatched_b = [], [], []
cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
for ix, mx in enumerate(x):
if mx >= 0:
matches.append([ix, mx])
unmatched_a = np.where(x < 0)[0]
unmatched_b = np.where(y < 0)[0]
matches = np.asarray(matches)
return matches, unmatched_a, unmatched_b
def ious(atlbrs, btlbrs):
"""
Compute cost based on IoU
:type atlbrs: list[tlbr] | np.ndarray
:type btlbrs: list[tlbr] | np.ndarray
:rtype ious np.ndarray
"""
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
if ious.size == 0:
return ious
ious = bbox_ious(
np.ascontiguousarray(atlbrs, dtype=np.float32),
np.ascontiguousarray(btlbrs, dtype=np.float32)
)
return ious
def tlbr_expand(tlbr, scale=1.2):
w = tlbr[2] - tlbr[0]
h = tlbr[3] - tlbr[1]
half_scale = 0.5 * scale
tlbr[0] -= half_scale * w
tlbr[1] -= half_scale * h
tlbr[2] += half_scale * w
tlbr[3] += half_scale * h
return tlbr
def iou_distance(atracks, btracks):
"""
Compute cost based on IoU
:type atracks: list[STrack]
:type btracks: list[STrack]
:rtype cost_matrix np.ndarray
"""
if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
atlbrs = atracks
btlbrs = btracks
else:
atlbrs = [track.tlbr for track in atracks]
btlbrs = [track.tlbr for track in btracks]
_ious = ious(atlbrs, btlbrs)
cost_matrix = 1 - _ious
return cost_matrix
def v_iou_distance(atracks, btracks):
"""
Compute cost based on IoU
:type atracks: list[STrack]
:type btracks: list[STrack]
:rtype cost_matrix np.ndarray
"""
if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
atlbrs = atracks
btlbrs = btracks
else:
atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks]
btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks]
_ious = ious(atlbrs, btlbrs)
cost_matrix = 1 - _ious
return cost_matrix
def embedding_distance(tracks, detections, metric='cosine'):
"""
:param tracks: list[STrack]
:param detections: list[BaseTrack]
:param metric:
:return: cost_matrix np.ndarray
"""
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
if cost_matrix.size == 0:
return cost_matrix
det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32)
track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)
cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # / 2.0 # Normalized features
return cost_matrix
def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
if cost_matrix.size == 0:
return cost_matrix
gating_dim = 2 if only_position else 4
gating_threshold = kalman_filter.chi2inv95[gating_dim]
# measurements = np.asarray([det.to_xyah() for det in detections])
measurements = np.asarray([det.to_xywh() for det in detections])
for row, track in enumerate(tracks):
gating_distance = kf.gating_distance(
track.mean, track.covariance, measurements, only_position)
cost_matrix[row, gating_distance > gating_threshold] = np.inf
return cost_matrix
def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
if cost_matrix.size == 0:
return cost_matrix
gating_dim = 2 if only_position else 4
gating_threshold = kalman_filter.chi2inv95[gating_dim]
# measurements = np.asarray([det.to_xyah() for det in detections])
measurements = np.asarray([det.to_xywh() for det in detections])
for row, track in enumerate(tracks):
gating_distance = kf.gating_distance(
track.mean, track.covariance, measurements, only_position, metric='maha')
cost_matrix[row, gating_distance > gating_threshold] = np.inf
cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
return cost_matrix
def fuse_iou(cost_matrix, tracks, detections):
if cost_matrix.size == 0:
return cost_matrix
reid_sim = 1 - cost_matrix
iou_dist = iou_distance(tracks, detections)
iou_sim = 1 - iou_dist
fuse_sim = reid_sim * (1 + iou_sim) / 2
det_scores = np.array([det.score for det in detections])
det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
#fuse_sim = fuse_sim * (1 + det_scores) / 2
fuse_cost = 1 - fuse_sim
return fuse_cost
def fuse_score(cost_matrix, detections):
if cost_matrix.size == 0:
return cost_matrix
iou_sim = 1 - cost_matrix
det_scores = np.array([det.score for det in detections])
det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
fuse_sim = iou_sim * det_scores
fuse_cost = 1 - fuse_sim
return fuse_cost
def bbox_ious(boxes, query_boxes):
"""
Parameters
----------
boxes: (N, 4) ndarray of float
query_boxes: (K, 4) ndarray of float
Returns
-------
overlaps: (N, K) ndarray of overlap between boxes and query_boxes
"""
N = boxes.shape[0]
K = query_boxes.shape[0]
overlaps = np.zeros((N, K), dtype=np.float32)
for k in range(K):
box_area = (
(query_boxes[k, 2] - query_boxes[k, 0] + 1) *
(query_boxes[k, 3] - query_boxes[k, 1] + 1)
)
for n in range(N):
iw = (
min(boxes[n, 2], query_boxes[k, 2]) -
max(boxes[n, 0], query_boxes[k, 0]) + 1
)
if iw > 0:
ih = (
min(boxes[n, 3], query_boxes[k, 3]) -
max(boxes[n, 1], query_boxes[k, 1]) + 1
)
if ih > 0:
ua = float(
(boxes[n, 2] - boxes[n, 0] + 1) *
(boxes[n, 3] - boxes[n, 1] + 1) +
box_area - iw * ih
)
overlaps[n, k] = iw * ih / ua
return overlaps
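# Tiny worked check of bbox_ious with the +1 pixel convention used above: two 11x11
# boxes offset horizontally by 5 px intersect in a 6x11 region, so
# IoU = 66 / (121 + 121 - 66) = 0.375 (values chosen for illustration).
if __name__ == "__main__":
    a = np.array([[0., 0., 10., 10.]], dtype=np.float32)   # 11x11 box, inclusive endpoints
    b = np.array([[5., 0., 15., 10.]], dtype=np.float32)   # same box shifted right by 5
    print(bbox_ious(a, b))                                  # [[0.375]]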

View file

@ -0,0 +1,52 @@
import numpy as np
from collections import OrderedDict
class TrackState(object):
New = 0
Tracked = 1
Lost = 2
Removed = 3
class BaseTrack(object):
_count = 0
track_id = 0
is_activated = False
state = TrackState.New
history = OrderedDict()
features = []
curr_feature = None
score = 0
start_frame = 0
frame_id = 0
time_since_update = 0
# multi-camera
location = (np.inf, np.inf)
@property
def end_frame(self):
return self.frame_id
@staticmethod
def next_id():
BaseTrack._count += 1
return BaseTrack._count
def activate(self, *args):
raise NotImplementedError
def predict(self):
raise NotImplementedError
def update(self, *args, **kwargs):
raise NotImplementedError
def mark_lost(self):
self.state = TrackState.Lost
def mark_removed(self):
self.state = TrackState.Removed

View file

@ -0,0 +1,348 @@
import numpy as np
from ultralytics.yolo.utils.ops import xywh2xyxy, xyxy2xywh
from trackers.bytetrack.kalman_filter import KalmanFilter
from trackers.bytetrack import matching
from trackers.bytetrack.basetrack import BaseTrack, TrackState
class STrack(BaseTrack):
shared_kalman = KalmanFilter()
def __init__(self, tlwh, score, cls):
# wait activate
self._tlwh = np.asarray(tlwh, dtype=np.float32)
self.kalman_filter = None
self.mean, self.covariance = None, None
self.is_activated = False
self.score = score
self.tracklet_len = 0
self.cls = cls
def predict(self):
mean_state = self.mean.copy()
if self.state != TrackState.Tracked:
mean_state[7] = 0
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
@staticmethod
def multi_predict(stracks):
if len(stracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks])
for i, st in enumerate(stracks):
if st.state != TrackState.Tracked:
multi_mean[i][7] = 0
multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
stracks[i].mean = mean
stracks[i].covariance = cov
def activate(self, kalman_filter, frame_id):
"""Start a new tracklet"""
self.kalman_filter = kalman_filter
self.track_id = self.next_id()
self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))
self.tracklet_len = 0
self.state = TrackState.Tracked
if frame_id == 1:
self.is_activated = True
# self.is_activated = True
self.frame_id = frame_id
self.start_frame = frame_id
def re_activate(self, new_track, frame_id, new_id=False):
self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
)
self.tracklet_len = 0
self.state = TrackState.Tracked
self.is_activated = True
self.frame_id = frame_id
if new_id:
self.track_id = self.next_id()
self.score = new_track.score
self.cls = new_track.cls
def update(self, new_track, frame_id):
"""
Update a matched track
:type new_track: STrack
:type frame_id: int
:type update_feature: bool
:return:
"""
self.frame_id = frame_id
self.tracklet_len += 1
# self.cls = cls
new_tlwh = new_track.tlwh
self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
self.state = TrackState.Tracked
self.is_activated = True
self.score = new_track.score
@property
# @jit(nopython=True)
def tlwh(self):
"""Get current position in bounding box format `(top left x, top left y,
width, height)`.
"""
if self.mean is None:
return self._tlwh.copy()
ret = self.mean[:4].copy()
ret[2] *= ret[3]
ret[:2] -= ret[2:] / 2
return ret
@property
# @jit(nopython=True)
def tlbr(self):
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret = self.tlwh.copy()
ret[2:] += ret[:2]
return ret
@staticmethod
# @jit(nopython=True)
def tlwh_to_xyah(tlwh):
"""Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""
ret = np.asarray(tlwh).copy()
ret[:2] += ret[2:] / 2
ret[2] /= ret[3]
return ret
def to_xyah(self):
return self.tlwh_to_xyah(self.tlwh)
@staticmethod
# @jit(nopython=True)
def tlbr_to_tlwh(tlbr):
ret = np.asarray(tlbr).copy()
ret[2:] -= ret[:2]
return ret
@staticmethod
# @jit(nopython=True)
def tlwh_to_tlbr(tlwh):
ret = np.asarray(tlwh).copy()
ret[2:] += ret[:2]
return ret
def __repr__(self):
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
class BYTETracker(object):
def __init__(self, track_thresh=0.45, match_thresh=0.8, track_buffer=25, frame_rate=30):
self.tracked_stracks = [] # type: list[STrack]
self.lost_stracks = [] # type: list[STrack]
self.removed_stracks = [] # type: list[STrack]
self.frame_id = 0
self.track_buffer=track_buffer
self.track_thresh = track_thresh
self.match_thresh = match_thresh
self.det_thresh = track_thresh + 0.1
self.buffer_size = int(frame_rate / 30.0 * track_buffer)
self.max_time_lost = self.buffer_size
self.kalman_filter = KalmanFilter()
def update(self, dets, _):
self.frame_id += 1
activated_starcks = []
refind_stracks = []
lost_stracks = []
removed_stracks = []
xyxys = dets[:, 0:4]
xywh = xyxy2xywh(xyxys.numpy())
confs = dets[:, 4]
clss = dets[:, 5]
classes = clss.numpy()
xyxys = xyxys.numpy()
confs = confs.numpy()
remain_inds = confs > self.track_thresh
inds_low = confs > 0.1
inds_high = confs < self.track_thresh
inds_second = np.logical_and(inds_low, inds_high)
dets_second = xywh[inds_second]
dets = xywh[remain_inds]
scores_keep = confs[remain_inds]
scores_second = confs[inds_second]
clss_keep = classes[remain_inds]
clss_second = classes[inds_second]
if len(dets) > 0:
'''Detections'''
detections = [STrack(xyxy, s, c) for
(xyxy, s, c) in zip(dets, scores_keep, clss_keep)]
else:
detections = []
''' Add newly detected tracklets to tracked_stracks'''
unconfirmed = []
tracked_stracks = [] # type: list[STrack]
for track in self.tracked_stracks:
if not track.is_activated:
unconfirmed.append(track)
else:
tracked_stracks.append(track)
''' Step 2: First association, with high score detection boxes'''
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
# Predict the current location with KF
STrack.multi_predict(strack_pool)
dists = matching.iou_distance(strack_pool, detections)
#if not self.args.mot20:
dists = matching.fuse_score(dists, detections)
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.match_thresh)
for itracked, idet in matches:
track = strack_pool[itracked]
det = detections[idet]
if track.state == TrackState.Tracked:
track.update(detections[idet], self.frame_id)
activated_starcks.append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
''' Step 3: Second association, with low score detection boxes'''
# associate the remaining unmatched tracks with the low-score detections
if len(dets_second) > 0:
'''Detections'''
detections_second = [STrack(xywh, s, c) for (xywh, s, c) in zip(dets_second, scores_second, clss_second)]
else:
detections_second = []
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
dists = matching.iou_distance(r_tracked_stracks, detections_second)
matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
for itracked, idet in matches:
track = r_tracked_stracks[itracked]
det = detections_second[idet]
if track.state == TrackState.Tracked:
track.update(det, self.frame_id)
activated_starcks.append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
for it in u_track:
track = r_tracked_stracks[it]
if not track.state == TrackState.Lost:
track.mark_lost()
lost_stracks.append(track)
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
detections = [detections[i] for i in u_detection]
dists = matching.iou_distance(unconfirmed, detections)
#if not self.args.mot20:
dists = matching.fuse_score(dists, detections)
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
for itracked, idet in matches:
unconfirmed[itracked].update(detections[idet], self.frame_id)
activated_starcks.append(unconfirmed[itracked])
for it in u_unconfirmed:
track = unconfirmed[it]
track.mark_removed()
removed_stracks.append(track)
""" Step 4: Init new stracks"""
for inew in u_detection:
track = detections[inew]
if track.score < self.det_thresh:
continue
track.activate(self.kalman_filter, self.frame_id)
activated_starcks.append(track)
""" Step 5: Update state"""
for track in self.lost_stracks:
if self.frame_id - track.end_frame > self.max_time_lost:
track.mark_removed()
removed_stracks.append(track)
# print('Remaining match {} s'.format(t4 - t3))
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
self.lost_stracks.extend(lost_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
self.removed_stracks.extend(removed_stracks)
self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
# get scores of lost tracks
output_stracks = [track for track in self.tracked_stracks if track.is_activated]
outputs = []
for t in output_stracks:
output= []
tlwh = t.tlwh
tid = t.track_id
tlwh = np.expand_dims(tlwh, axis=0)
xyxy = xywh2xyxy(tlwh)
xyxy = np.squeeze(xyxy, axis=0)
output.extend(xyxy)
output.append(tid)
output.append(t.cls)
output.append(t.score)
outputs.append(output)
return outputs
#track_id, class_id, conf
def joint_stracks(tlista, tlistb):
exists = {}
res = []
for t in tlista:
exists[t.track_id] = 1
res.append(t)
for t in tlistb:
tid = t.track_id
if not exists.get(tid, 0):
exists[tid] = 1
res.append(t)
return res
def sub_stracks(tlista, tlistb):
stracks = {}
for t in tlista:
stracks[t.track_id] = t
for t in tlistb:
tid = t.track_id
if stracks.get(tid, 0):
del stracks[tid]
return list(stracks.values())
def remove_duplicate_stracks(stracksa, stracksb):
pdist = matching.iou_distance(stracksa, stracksb)
pairs = np.where(pdist < 0.15)
dupa, dupb = list(), list()
for p, q in zip(*pairs):
timep = stracksa[p].frame_id - stracksa[p].start_frame
timeq = stracksb[q].frame_id - stracksb[q].start_frame
if timep > timeq:
dupb.append(q)
else:
dupa.append(p)
resa = [t for i, t in enumerate(stracksa) if not i in dupa]
resb = [t for i, t in enumerate(stracksb) if not i in dupb]
return resa, resb
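# Sketch of driving BYTETracker frame by frame; it assumes detections arrive as a CPU
# torch tensor with one row per box in [x1, y1, x2, y2, conf, cls] order, which is
# what the slicing at the top of update() expects (the torch import is an assumption here).
if __name__ == "__main__":
    import torch
    tracker = BYTETracker(track_thresh=0.6, match_thresh=0.8, track_buffer=30, frame_rate=30)
    dets = torch.tensor([[100., 120., 180., 260., 0.91, 0.],
                         [400., 150., 470., 300., 0.35, 0.]])
    for row in tracker.update(dets, None):           # second argument is unused by this tracker
        x1, y1, x2, y2, tid, cls_id, score = row     # one entry per confirmed track
        print(int(tid), int(cls_id), round(float(score), 2))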

View file

@ -0,0 +1,7 @@
bytetrack:
track_thresh: 0.6 # tracking confidence threshold
track_buffer: 30 # number of frames to keep lost tracks
match_thresh: 0.8 # matching threshold for tracking
frame_rate: 30 # FPS
conf_thres: 0.5122620708221085
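These keys line up with the BYTETracker constructor shown earlier (track_thresh, track_buffer, match_thresh, frame_rate); conf_thres is presumably applied upstream to the detector output. A hedged wiring sketch, with the file path and tracker module path assumed for illustration:

import yaml
from trackers.bytetrack.byte_tracker import BYTETracker   # assumed module name

with open("trackers/bytetrack/configs/bytetrack.yaml") as f:   # assumed path
    cfg = yaml.safe_load(f)["bytetrack"]

conf_thres = cfg.pop("conf_thres")   # detector-side threshold, not a tracker argument
tracker = BYTETracker(**cfg)         # track_thresh, track_buffer, match_thresh, frame_rate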

View file

@ -0,0 +1,270 @@
# vim: expandtab:ts=4:sw=4
import numpy as np
import scipy.linalg
"""
Table for the 0.95 quantile of the chi-square distribution with N degrees of
freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
function and used as Mahalanobis gating threshold.
"""
chi2inv95 = {
1: 3.8415,
2: 5.9915,
3: 7.8147,
4: 9.4877,
5: 11.070,
6: 12.592,
7: 14.067,
8: 15.507,
9: 16.919}
class KalmanFilter(object):
"""
A simple Kalman filter for tracking bounding boxes in image space.
The 8-dimensional state space
x, y, a, h, vx, vy, va, vh
contains the bounding box center position (x, y), aspect ratio a, height h,
and their respective velocities.
Object motion follows a constant velocity model. The bounding box location
(x, y, a, h) is taken as direct observation of the state space (linear
observation model).
"""
def __init__(self):
ndim, dt = 4, 1.
# Create Kalman filter model matrices.
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
for i in range(ndim):
self._motion_mat[i, ndim + i] = dt
self._update_mat = np.eye(ndim, 2 * ndim)
# Motion and observation uncertainty are chosen relative to the current
# state estimate. These weights control the amount of uncertainty in
# the model. This is a bit hacky.
self._std_weight_position = 1. / 20
self._std_weight_velocity = 1. / 160
def initiate(self, measurement):
"""Create track from unassociated measurement.
Parameters
----------
measurement : ndarray
Bounding box coordinates (x, y, a, h) with center position (x, y),
aspect ratio a, and height h.
Returns
-------
(ndarray, ndarray)
Returns the mean vector (8 dimensional) and covariance matrix (8x8
dimensional) of the new track. Unobserved velocities are initialized
to 0 mean.
"""
mean_pos = measurement
mean_vel = np.zeros_like(mean_pos)
mean = np.r_[mean_pos, mean_vel]
std = [
2 * self._std_weight_position * measurement[3],
2 * self._std_weight_position * measurement[3],
1e-2,
2 * self._std_weight_position * measurement[3],
10 * self._std_weight_velocity * measurement[3],
10 * self._std_weight_velocity * measurement[3],
1e-5,
10 * self._std_weight_velocity * measurement[3]]
covariance = np.diag(np.square(std))
return mean, covariance
def predict(self, mean, covariance):
"""Run Kalman filter prediction step.
Parameters
----------
mean : ndarray
The 8 dimensional mean vector of the object state at the previous
time step.
covariance : ndarray
The 8x8 dimensional covariance matrix of the object state at the
previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted
state. Unobserved velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[3],
self._std_weight_position * mean[3],
1e-2,
self._std_weight_position * mean[3]]
std_vel = [
self._std_weight_velocity * mean[3],
self._std_weight_velocity * mean[3],
1e-5,
self._std_weight_velocity * mean[3]]
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
#mean = np.dot(self._motion_mat, mean)
mean = np.dot(mean, self._motion_mat.T)
covariance = np.linalg.multi_dot((
self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
return mean, covariance
def project(self, mean, covariance):
"""Project state distribution to measurement space.
Parameters
----------
mean : ndarray
The state's mean vector (8 dimensional array).
covariance : ndarray
The state's covariance matrix (8x8 dimensional).
Returns
-------
(ndarray, ndarray)
Returns the projected mean and covariance matrix of the given state
estimate.
"""
std = [
self._std_weight_position * mean[3],
self._std_weight_position * mean[3],
1e-1,
self._std_weight_position * mean[3]]
innovation_cov = np.diag(np.square(std))
mean = np.dot(self._update_mat, mean)
covariance = np.linalg.multi_dot((
self._update_mat, covariance, self._update_mat.T))
return mean, covariance + innovation_cov
def multi_predict(self, mean, covariance):
"""Run Kalman filter prediction step (Vectorized version).
Parameters
----------
mean : ndarray
The Nx8 dimensional mean matrix of the object states at the previous
time step.
covariance : ndarray
The Nx8x8 dimensional covariance matrices of the object states at the
previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted
state. Unobserved velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[:, 3],
self._std_weight_position * mean[:, 3],
1e-2 * np.ones_like(mean[:, 3]),
self._std_weight_position * mean[:, 3]]
std_vel = [
self._std_weight_velocity * mean[:, 3],
self._std_weight_velocity * mean[:, 3],
1e-5 * np.ones_like(mean[:, 3]),
self._std_weight_velocity * mean[:, 3]]
sqr = np.square(np.r_[std_pos, std_vel]).T
motion_cov = []
for i in range(len(mean)):
motion_cov.append(np.diag(sqr[i]))
motion_cov = np.asarray(motion_cov)
mean = np.dot(mean, self._motion_mat.T)
left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
covariance = np.dot(left, self._motion_mat.T) + motion_cov
return mean, covariance
def update(self, mean, covariance, measurement):
"""Run Kalman filter correction step.
Parameters
----------
mean : ndarray
The predicted state's mean vector (8 dimensional).
covariance : ndarray
The state's covariance matrix (8x8 dimensional).
measurement : ndarray
The 4 dimensional measurement vector (x, y, a, h), where (x, y)
is the center position, a the aspect ratio, and h the height of the
bounding box.
Returns
-------
(ndarray, ndarray)
Returns the measurement-corrected state distribution.
"""
projected_mean, projected_cov = self.project(mean, covariance)
chol_factor, lower = scipy.linalg.cho_factor(
projected_cov, lower=True, check_finite=False)
kalman_gain = scipy.linalg.cho_solve(
(chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
check_finite=False).T
innovation = measurement - projected_mean
new_mean = mean + np.dot(innovation, kalman_gain.T)
new_covariance = covariance - np.linalg.multi_dot((
kalman_gain, projected_cov, kalman_gain.T))
return new_mean, new_covariance
def gating_distance(self, mean, covariance, measurements,
only_position=False, metric='maha'):
"""Compute gating distance between state distribution and measurements.
A suitable distance threshold can be obtained from `chi2inv95`. If
`only_position` is False, the chi-square distribution has 4 degrees of
freedom, otherwise 2.
Parameters
----------
mean : ndarray
Mean vector over the state distribution (8 dimensional).
covariance : ndarray
Covariance of the state distribution (8x8 dimensional).
measurements : ndarray
An Nx4 dimensional matrix of N measurements, each in
format (x, y, a, h) where (x, y) is the bounding box center
position, a the aspect ratio, and h the height.
only_position : Optional[bool]
If True, distance computation is done with respect to the bounding
box center position only.
Returns
-------
ndarray
Returns an array of length N, where the i-th element contains the
squared Mahalanobis distance between (mean, covariance) and
`measurements[i]`.
"""
mean, covariance = self.project(mean, covariance)
if only_position:
mean, covariance = mean[:2], covariance[:2, :2]
measurements = measurements[:, :2]
d = measurements - mean
if metric == 'gaussian':
return np.sum(d * d, axis=1)
elif metric == 'maha':
cholesky_factor = np.linalg.cholesky(covariance)
z = scipy.linalg.solve_triangular(
cholesky_factor, d.T, lower=True, check_finite=False,
overwrite_b=True)
squared_maha = np.sum(z * z, axis=0)
return squared_maha
else:
raise ValueError('invalid distance metric')

View file

@ -0,0 +1,219 @@
import cv2
import numpy as np
import scipy
import scipy.sparse
import lap
from scipy.spatial.distance import cdist
from trackers.bytetrack import kalman_filter
import time
def merge_matches(m1, m2, shape):
O,P,Q = shape
m1 = np.asarray(m1)
m2 = np.asarray(m2)
M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))
mask = M1*M2
match = mask.nonzero()
match = list(zip(match[0], match[1]))
unmatched_O = tuple(set(range(O)) - set([i for i, j in match]))
unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match]))
return match, unmatched_O, unmatched_Q
def _indices_to_matches(cost_matrix, indices, thresh):
matched_cost = cost_matrix[tuple(zip(*indices))]
matched_mask = (matched_cost <= thresh)
matches = indices[matched_mask]
unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
return matches, unmatched_a, unmatched_b
def linear_assignment(cost_matrix, thresh):
if cost_matrix.size == 0:
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
matches, unmatched_a, unmatched_b = [], [], []
cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
for ix, mx in enumerate(x):
if mx >= 0:
matches.append([ix, mx])
unmatched_a = np.where(x < 0)[0]
unmatched_b = np.where(y < 0)[0]
matches = np.asarray(matches)
return matches, unmatched_a, unmatched_b
def ious(atlbrs, btlbrs):
"""
Compute cost based on IoU
:type atlbrs: list[tlbr] | np.ndarray
:type btlbrs: list[tlbr] | np.ndarray
:rtype ious np.ndarray
"""
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
if ious.size == 0:
return ious
ious = bbox_ious(
np.ascontiguousarray(atlbrs, dtype=np.float32),
np.ascontiguousarray(btlbrs, dtype=np.float32)
)
return ious
def iou_distance(atracks, btracks):
"""
Compute cost based on IoU
:type atracks: list[STrack]
:type btracks: list[STrack]
:rtype cost_matrix np.ndarray
"""
if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
atlbrs = atracks
btlbrs = btracks
else:
atlbrs = [track.tlbr for track in atracks]
btlbrs = [track.tlbr for track in btracks]
_ious = ious(atlbrs, btlbrs)
cost_matrix = 1 - _ious
return cost_matrix
def v_iou_distance(atracks, btracks):
"""
Compute cost based on IoU
:type atracks: list[STrack]
:type btracks: list[STrack]
:rtype cost_matrix np.ndarray
"""
if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
atlbrs = atracks
btlbrs = btracks
else:
atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks]
btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks]
_ious = ious(atlbrs, btlbrs)
cost_matrix = 1 - _ious
return cost_matrix
def embedding_distance(tracks, detections, metric='cosine'):
"""
:param tracks: list[STrack]
:param detections: list[BaseTrack]
:param metric:
:return: cost_matrix np.ndarray
"""
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
if cost_matrix.size == 0:
return cost_matrix
det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32)
#for i, track in enumerate(tracks):
#cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric))
track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)
cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Normalized features
return cost_matrix
def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
if cost_matrix.size == 0:
return cost_matrix
gating_dim = 2 if only_position else 4
gating_threshold = kalman_filter.chi2inv95[gating_dim]
measurements = np.asarray([det.to_xyah() for det in detections])
for row, track in enumerate(tracks):
gating_distance = kf.gating_distance(
track.mean, track.covariance, measurements, only_position)
cost_matrix[row, gating_distance > gating_threshold] = np.inf
return cost_matrix
def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
if cost_matrix.size == 0:
return cost_matrix
gating_dim = 2 if only_position else 4
gating_threshold = kalman_filter.chi2inv95[gating_dim]
measurements = np.asarray([det.to_xyah() for det in detections])
for row, track in enumerate(tracks):
gating_distance = kf.gating_distance(
track.mean, track.covariance, measurements, only_position, metric='maha')
cost_matrix[row, gating_distance > gating_threshold] = np.inf
cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
return cost_matrix
def fuse_iou(cost_matrix, tracks, detections):
if cost_matrix.size == 0:
return cost_matrix
reid_sim = 1 - cost_matrix
iou_dist = iou_distance(tracks, detections)
iou_sim = 1 - iou_dist
fuse_sim = reid_sim * (1 + iou_sim) / 2
det_scores = np.array([det.score for det in detections])
det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
#fuse_sim = fuse_sim * (1 + det_scores) / 2
fuse_cost = 1 - fuse_sim
return fuse_cost
def fuse_score(cost_matrix, detections):
if cost_matrix.size == 0:
return cost_matrix
iou_sim = 1 - cost_matrix
det_scores = np.array([det.score for det in detections])
det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
fuse_sim = iou_sim * det_scores
fuse_cost = 1 - fuse_sim
return fuse_cost
def bbox_ious(boxes, query_boxes):
"""
Parameters
----------
boxes: (N, 4) ndarray of float
query_boxes: (K, 4) ndarray of float
Returns
-------
overlaps: (N, K) ndarray of overlap between boxes and query_boxes
"""
N = boxes.shape[0]
K = query_boxes.shape[0]
overlaps = np.zeros((N, K), dtype=np.float32)
for k in range(K):
box_area = (
(query_boxes[k, 2] - query_boxes[k, 0] + 1) *
(query_boxes[k, 3] - query_boxes[k, 1] + 1)
)
for n in range(N):
iw = (
min(boxes[n, 2], query_boxes[k, 2]) -
max(boxes[n, 0], query_boxes[k, 0]) + 1
)
if iw > 0:
ih = (
min(boxes[n, 3], query_boxes[k, 3]) -
max(boxes[n, 1], query_boxes[k, 1]) + 1
)
if ih > 0:
ua = float(
(boxes[n, 2] - boxes[n, 0] + 1) *
(boxes[n, 3] - boxes[n, 1] + 1) +
box_area - iw * ih
)
overlaps[n, k] = iw * ih / ua
return overlaps

View file

@ -0,0 +1,2 @@
from . import args
from . import ocsort

View file

@ -0,0 +1,110 @@
import argparse
def make_parser():
parser = argparse.ArgumentParser("OC-SORT parameters")
# distributed
parser.add_argument("-b", "--batch-size", type=int, default=1, help="batch size")
parser.add_argument("-d", "--devices", default=None, type=int, help="device for training")
parser.add_argument("--local_rank", default=0, type=int, help="local rank for dist training")
parser.add_argument("--num_machines", default=1, type=int, help="num of node for training")
parser.add_argument("--machine_rank", default=0, type=int, help="node rank for multi-node training")
parser.add_argument(
"-f",
"--exp_file",
default=None,
type=str,
help="pls input your expriment description file",
)
parser.add_argument(
"--test",
dest="test",
default=False,
action="store_true",
help="Evaluating on test-dev set.",
)
parser.add_argument(
"opts",
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER,
)
# det args
parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
parser.add_argument("--conf", default=0.1, type=float, help="test conf")
parser.add_argument("--nms", default=0.7, type=float, help="test nms threshold")
parser.add_argument("--tsize", default=[800, 1440], nargs="+", type=int, help="test img size")
parser.add_argument("--seed", default=None, type=int, help="eval seed")
# tracking args
parser.add_argument("--track_thresh", type=float, default=0.6, help="detection confidence threshold")
parser.add_argument(
"--iou_thresh",
type=float,
default=0.3,
help="the iou threshold in Sort for matching",
)
parser.add_argument("--min_hits", type=int, default=3, help="min hits to create track in SORT")
parser.add_argument(
"--inertia",
type=float,
default=0.2,
help="the weight of VDC term in cost matrix",
)
parser.add_argument(
"--deltat",
type=int,
default=3,
help="time step difference to estimate direction",
)
parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
parser.add_argument(
"--match_thresh",
type=float,
default=0.9,
help="matching threshold for tracking",
)
parser.add_argument(
"--gt-type",
type=str,
default="_val_half",
help="suffix to find the gt annotation",
)
parser.add_argument("--public", action="store_true", help="use public detection")
parser.add_argument("--asso", default="iou", help="similarity function: iou/giou/diou/ciou/ctdis")
# for kitti/bdd100k inference with public detections
parser.add_argument(
"--raw_results_path",
type=str,
default="exps/permatrack_kitti_test/",
help="path to the raw tracking results from other tracks",
)
parser.add_argument("--out_path", type=str, help="path to save output results")
parser.add_argument(
"--hp",
action="store_true",
help="use head padding to add the missing objects during \
initializing the tracks (offline).",
)
# for demo video
parser.add_argument("--demo_type", default="image", help="demo type, eg. image, video and webcam")
parser.add_argument("--path", default="./videos/demo.mp4", help="path to images or video")
parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
parser.add_argument(
"--save_result",
action="store_true",
help="whether to save the inference result of image/video",
)
parser.add_argument(
"--device",
default="gpu",
type=str,
help="device to run our model, can either be cpu or gpu",
)
return parser
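# Quick smoke test of the parser above; the flag values are illustrative only.
if __name__ == "__main__":
    args = make_parser().parse_args(["--track_thresh", "0.6", "--asso", "giou"])
    print(args.track_thresh, args.iou_thresh, args.asso)   # 0.6 0.3 giou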

View file

@ -0,0 +1,445 @@
import os
import pdb
import numpy as np
from scipy.special import softmax
def iou_batch(bboxes1, bboxes2):
"""
From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
"""
bboxes2 = np.expand_dims(bboxes2, 0)
bboxes1 = np.expand_dims(bboxes1, 1)
xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
w = np.maximum(0.0, xx2 - xx1)
h = np.maximum(0.0, yy2 - yy1)
wh = w * h
o = wh / (
(bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
+ (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
- wh
)
return o
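# For reference, iou_batch broadcasts its inputs: an (N, 4) and a (K, 4) array of
# [x1, y1, x2, y2] boxes yield an (N, K) IoU matrix. For example (values illustrative),
# iou_batch(np.array([[0., 0., 10., 10.], [50., 50., 60., 60.]]),
#           np.array([[0., 0., 10., 10.]]))
# returns a (2, 1) matrix: [[1.0], [0.0]].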
def giou_batch(bboxes1, bboxes2):
"""
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
:return:
"""
# for details, see https://arxiv.org/pdf/1902.09630.pdf
# ensure predict's bbox form
bboxes2 = np.expand_dims(bboxes2, 0)
bboxes1 = np.expand_dims(bboxes1, 1)
xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
w = np.maximum(0.0, xx2 - xx1)
h = np.maximum(0.0, yy2 - yy1)
wh = w * h
iou = wh / (
(bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
+ (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
- wh
)
xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
wc = xxc2 - xxc1
hc = yyc2 - yyc1
assert (wc > 0).all() and (hc > 0).all()
area_enclose = wc * hc
giou = iou - (area_enclose - wh) / area_enclose
giou = (giou + 1.0) / 2.0 # resize from (-1,1) to (0,1)
return giou
def diou_batch(bboxes1, bboxes2):
"""
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
:return:
"""
# for details should go to https://arxiv.org/pdf/1902.09630.pdf
# ensure predict's bbox form
bboxes2 = np.expand_dims(bboxes2, 0)
bboxes1 = np.expand_dims(bboxes1, 1)
# calculate the intersection box
xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
w = np.maximum(0.0, xx2 - xx1)
h = np.maximum(0.0, yy2 - yy1)
wh = w * h
iou = wh / (
(bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
+ (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
- wh
)
centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0
inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2
xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2
diou = iou - inner_diag / outer_diag
return (diou + 1) / 2.0 # resize from (-1,1) to (0,1)
def ciou_batch(bboxes1, bboxes2):
"""
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
:return:
"""
# for details should go to https://arxiv.org/pdf/1902.09630.pdf
# ensure predict's bbox form
bboxes2 = np.expand_dims(bboxes2, 0)
bboxes1 = np.expand_dims(bboxes1, 1)
# calculate the intersection box
xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
w = np.maximum(0.0, xx2 - xx1)
h = np.maximum(0.0, yy2 - yy1)
wh = w * h
iou = wh / (
(bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
+ (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
- wh
)
centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0
inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2
xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2
w1 = bboxes1[..., 2] - bboxes1[..., 0]
h1 = bboxes1[..., 3] - bboxes1[..., 1]
w2 = bboxes2[..., 2] - bboxes2[..., 0]
h2 = bboxes2[..., 3] - bboxes2[..., 1]
# prevent dividing over zero. add one pixel shift
h2 = h2 + 1.0
h1 = h1 + 1.0
arctan = np.arctan(w2 / h2) - np.arctan(w1 / h1)
v = (4 / (np.pi**2)) * (arctan**2)
S = 1 - iou
alpha = v / (S + v)
ciou = iou - inner_diag / outer_diag - alpha * v
return (ciou + 1) / 2.0 # resize from (-1,1) to (0,1)
def ct_dist(bboxes1, bboxes2):
"""
Measure the center distance between two sets of bounding boxes,
this is a coarse implementation, we don't recommend using it only
for association, which can be unstable and sensitive to frame rate
and object speed.
"""
bboxes2 = np.expand_dims(bboxes2, 0)
bboxes1 = np.expand_dims(bboxes1, 1)
centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0
ct_dist2 = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2
ct_dist = np.sqrt(ct_dist2)
# The linear rescaling is a naive version and needs more study
ct_dist = ct_dist / ct_dist.max()
return ct_dist.max() - ct_dist # resize to (0,1)
def speed_direction_batch(dets, tracks):
tracks = tracks[..., np.newaxis]
CX1, CY1 = (dets[:, 0] + dets[:, 2]) / 2.0, (dets[:, 1] + dets[:, 3]) / 2.0
CX2, CY2 = (tracks[:, 0] + tracks[:, 2]) / 2.0, (tracks[:, 1] + tracks[:, 3]) / 2.0
dx = CX1 - CX2
dy = CY1 - CY2
norm = np.sqrt(dx**2 + dy**2) + 1e-6
dx = dx / norm
dy = dy / norm
return dy, dx # size: num_track x num_det
def linear_assignment(cost_matrix):
try:
import lap
_, x, y = lap.lapjv(cost_matrix, extend_cost=True)
return np.array([[y[i], i] for i in x if i >= 0]) #
except ImportError:
from scipy.optimize import linear_sum_assignment
x, y = linear_sum_assignment(cost_matrix)
return np.array(list(zip(x, y)))
def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
"""
Assigns detections to tracked object (both represented as bounding boxes)
Returns 3 lists of matches, unmatched_detections and unmatched_trackers
"""
if len(trackers) == 0:
return (
np.empty((0, 2), dtype=int),
np.arange(len(detections)),
np.empty((0, 5), dtype=int),
)
iou_matrix = iou_batch(detections, trackers)
if min(iou_matrix.shape) > 0:
a = (iou_matrix > iou_threshold).astype(np.int32)
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
matched_indices = np.stack(np.where(a), axis=1)
else:
matched_indices = linear_assignment(-iou_matrix)
else:
matched_indices = np.empty(shape=(0, 2))
unmatched_detections = []
for d, det in enumerate(detections):
if d not in matched_indices[:, 0]:
unmatched_detections.append(d)
unmatched_trackers = []
for t, trk in enumerate(trackers):
if t not in matched_indices[:, 1]:
unmatched_trackers.append(t)
# filter out matched with low IOU
matches = []
for m in matched_indices:
if iou_matrix[m[0], m[1]] < iou_threshold:
unmatched_detections.append(m[0])
unmatched_trackers.append(m[1])
else:
matches.append(m.reshape(1, 2))
if len(matches) == 0:
matches = np.empty((0, 2), dtype=int)
else:
matches = np.concatenate(matches, axis=0)
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
def compute_aw_max_metric(emb_cost, w_association_emb, bottom=0.5):
w_emb = np.full_like(emb_cost, w_association_emb)
for idx in range(emb_cost.shape[0]):
inds = np.argsort(-emb_cost[idx])
# If there's less than two matches, just keep original weight
if len(inds) < 2:
continue
if emb_cost[idx, inds[0]] == 0:
row_weight = 0
else:
row_weight = 1 - max((emb_cost[idx, inds[1]] / emb_cost[idx, inds[0]]) - bottom, 0) / (1 - bottom)
w_emb[idx] *= row_weight
for idj in range(emb_cost.shape[1]):
inds = np.argsort(-emb_cost[:, idj])
# If there's less than two matches, just keep original weight
if len(inds) < 2:
continue
if emb_cost[inds[0], idj] == 0:
col_weight = 0
else:
col_weight = 1 - max((emb_cost[inds[1], idj] / emb_cost[inds[0], idj]) - bottom, 0) / (1 - bottom)
w_emb[:, idj] *= col_weight
return w_emb * emb_cost
def associate(
detections, trackers, iou_threshold, velocities, previous_obs, vdc_weight, emb_cost, w_assoc_emb, aw_off, aw_param
):
if len(trackers) == 0:
return (
np.empty((0, 2), dtype=int),
np.arange(len(detections)),
np.empty((0, 5), dtype=int),
)
Y, X = speed_direction_batch(detections, previous_obs)
inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
diff_angle_cos = inertia_X * X + inertia_Y * Y
diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
diff_angle = np.arccos(diff_angle_cos)
diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi
valid_mask = np.ones(previous_obs.shape[0])
valid_mask[np.where(previous_obs[:, 4] < 0)] = 0
iou_matrix = iou_batch(detections, trackers)
scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
# iou_matrix = iou_matrix * scores # a trick that sometimes works; we don't encourage this
valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)
angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
angle_diff_cost = angle_diff_cost.T
angle_diff_cost = angle_diff_cost * scores
if min(iou_matrix.shape) > 0:
a = (iou_matrix > iou_threshold).astype(np.int32)
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
matched_indices = np.stack(np.where(a), axis=1)
else:
if emb_cost is None:
emb_cost = 0
else:
emb_cost = emb_cost.cpu().numpy()
emb_cost[iou_matrix <= 0] = 0
if not aw_off:
emb_cost = compute_aw_max_metric(emb_cost, w_assoc_emb, bottom=aw_param)
else:
emb_cost *= w_assoc_emb
final_cost = -(iou_matrix + angle_diff_cost + emb_cost)
matched_indices = linear_assignment(final_cost)
else:
matched_indices = np.empty(shape=(0, 2))
unmatched_detections = []
for d, det in enumerate(detections):
if d not in matched_indices[:, 0]:
unmatched_detections.append(d)
unmatched_trackers = []
for t, trk in enumerate(trackers):
if t not in matched_indices[:, 1]:
unmatched_trackers.append(t)
# filter out matched with low IOU
matches = []
for m in matched_indices:
if iou_matrix[m[0], m[1]] < iou_threshold:
unmatched_detections.append(m[0])
unmatched_trackers.append(m[1])
else:
matches.append(m.reshape(1, 2))
if len(matches) == 0:
matches = np.empty((0, 2), dtype=int)
else:
matches = np.concatenate(matches, axis=0)
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
def associate_kitti(detections, trackers, det_cates, iou_threshold, velocities, previous_obs, vdc_weight):
if len(trackers) == 0:
return (
np.empty((0, 2), dtype=int),
np.arange(len(detections)),
np.empty((0, 5), dtype=int),
)
"""
Cost from the velocity direction consistency
"""
Y, X = speed_direction_batch(detections, previous_obs)
inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
diff_angle_cos = inertia_X * X + inertia_Y * Y
diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
diff_angle = np.arccos(diff_angle_cos)
diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi
valid_mask = np.ones(previous_obs.shape[0])
valid_mask[np.where(previous_obs[:, 4] < 0)] = 0
valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)
scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
angle_diff_cost = angle_diff_cost.T
angle_diff_cost = angle_diff_cost * scores
"""
Cost from IoU
"""
iou_matrix = iou_batch(detections, trackers)
"""
With multiple categories, generate the cost for category mismatch
"""
num_dets = detections.shape[0]
num_trk = trackers.shape[0]
cate_matrix = np.zeros((num_dets, num_trk))
for i in range(num_dets):
for j in range(num_trk):
if det_cates[i] != trackers[j, 4]:
cate_matrix[i][j] = -1e6
cost_matrix = -iou_matrix - angle_diff_cost - cate_matrix
if min(iou_matrix.shape) > 0:
a = (iou_matrix > iou_threshold).astype(np.int32)
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
matched_indices = np.stack(np.where(a), axis=1)
else:
matched_indices = linear_assignment(cost_matrix)
else:
matched_indices = np.empty(shape=(0, 2))
unmatched_detections = []
for d, det in enumerate(detections):
if d not in matched_indices[:, 0]:
unmatched_detections.append(d)
unmatched_trackers = []
for t, trk in enumerate(trackers):
if t not in matched_indices[:, 1]:
unmatched_trackers.append(t)
# filter out matched with low IOU
matches = []
for m in matched_indices:
if iou_matrix[m[0], m[1]] < iou_threshold:
unmatched_detections.append(m[0])
unmatched_trackers.append(m[1])
else:
matches.append(m.reshape(1, 2))
if len(matches) == 0:
matches = np.empty((0, 2), dtype=int)
else:
matches = np.concatenate(matches, axis=0)
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
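
For orientation, here is a small self-contained sketch (not part of the diff) that exercises the same IoU-then-Hungarian matching implemented by iou_batch and associate_detections_to_trackers above, using two toy detections and two toy tracks; scipy's linear_sum_assignment stands in for the lap fallback.

import numpy as np
from scipy.optimize import linear_sum_assignment

def toy_iou(b1, b2):
    # Pairwise IoU for [x1, y1, x2, y2] boxes, mirroring iou_batch above
    b2 = np.expand_dims(b2, 0)
    b1 = np.expand_dims(b1, 1)
    xx1 = np.maximum(b1[..., 0], b2[..., 0])
    yy1 = np.maximum(b1[..., 1], b2[..., 1])
    xx2 = np.minimum(b1[..., 2], b2[..., 2])
    yy2 = np.minimum(b1[..., 3], b2[..., 3])
    inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
    area1 = (b1[..., 2] - b1[..., 0]) * (b1[..., 3] - b1[..., 1])
    area2 = (b2[..., 2] - b2[..., 0]) * (b2[..., 3] - b2[..., 1])
    return inter / (area1 + area2 - inter)

dets = np.array([[10, 10, 50, 50], [200, 200, 240, 260]], dtype=float)
trks = np.array([[12, 11, 52, 49], [400, 400, 440, 460]], dtype=float)

iou = toy_iou(dets, trks)                # shape (num_dets, num_trks)
row, col = linear_sum_assignment(-iou)   # maximise total IoU, like linear_assignment above
matches = [(d, t) for d, t in zip(row, col) if iou[d, t] >= 0.3]
print(matches)                           # only detection 0 matches track 0; the others stay unmatched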

View file

@ -0,0 +1,170 @@
import pdb
import pickle
import os
import cv2
import numpy as np
class CMCComputer:
def __init__(self, minimum_features=10, method="sparse"):
assert method in ["file", "sparse", "sift"]
os.makedirs("./cache", exist_ok=True)
self.cache_path = "./cache/affine_ocsort.pkl"
self.cache = {}
if os.path.exists(self.cache_path):
with open(self.cache_path, "rb") as fp:
self.cache = pickle.load(fp)
self.minimum_features = minimum_features
self.prev_img = None
self.prev_desc = None
self.sparse_flow_param = dict(
maxCorners=3000,
qualityLevel=0.01,
minDistance=1,
blockSize=3,
useHarrisDetector=False,
k=0.04,
)
self.file_computed = {}
self.comp_function = None
if method == "sparse":
self.comp_function = self._affine_sparse_flow
elif method == "sift":
self.comp_function = self._affine_sift
# Same BoT-SORT CMC arrays
elif method == "file":
self.comp_function = self._affine_file
self.file_affines = {}
# Maps from tag name to file name
self.file_names = {}
# All the ablation file names
for f_name in os.listdir("./cache/cmc_files/MOT17_ablation/"):
# The tag that'll be passed into compute_affine based on image name
tag = f_name.replace("GMC-", "").replace(".txt", "") + "-FRCNN"
f_name = os.path.join("./cache/cmc_files/MOT17_ablation/", f_name)
self.file_names[tag] = f_name
for f_name in os.listdir("./cache/cmc_files/MOT20_ablation/"):
tag = f_name.replace("GMC-", "").replace(".txt", "")
f_name = os.path.join("./cache/cmc_files/MOT20_ablation/", f_name)
self.file_names[tag] = f_name
# All the test file names
for f_name in os.listdir("./cache/cmc_files/MOTChallenge/"):
tag = f_name.replace("GMC-", "").replace(".txt", "")
if "MOT17" in tag:
tag = tag + "-FRCNN"
# If it's an ablation one (not test) don't overwrite it
if tag in self.file_names:
continue
f_name = os.path.join("./cache/cmc_files/MOTChallenge/", f_name)
self.file_names[tag] = f_name
def compute_affine(self, img, bbox, tag):
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
if tag in self.cache:
A = self.cache[tag]
return A
mask = np.ones_like(img, dtype=np.uint8)
if bbox.shape[0] > 0:
bbox = np.round(bbox).astype(np.int32)
bbox[bbox < 0] = 0
for bb in bbox:
mask[bb[1] : bb[3], bb[0] : bb[2]] = 0
A = self.comp_function(img, mask, tag)
self.cache[tag] = A
return A
def _load_file(self, name):
affines = []
with open(self.file_names[name], "r") as fp:
for line in fp:
tokens = [float(f) for f in line.split("\t")[1:7]]
A = np.eye(2, 3)
A[0, 0] = tokens[0]
A[0, 1] = tokens[1]
A[0, 2] = tokens[2]
A[1, 0] = tokens[3]
A[1, 1] = tokens[4]
A[1, 2] = tokens[5]
affines.append(A)
self.file_affines[name] = affines
def _affine_file(self, frame, mask, tag):
name, num = tag.split(":")
if name not in self.file_affines:
self._load_file(name)
if name not in self.file_affines:
raise RuntimeError("Error loading file affines for CMC.")
return self.file_affines[name][int(num) - 1]
def _affine_sift(self, frame, mask, tag):
A = np.eye(2, 3)
detector = cv2.SIFT_create()
kp, desc = detector.detectAndCompute(frame, mask)
if self.prev_desc is None:
self.prev_desc = [kp, desc]
return A
if desc.shape[0] < self.minimum_features or self.prev_desc[1].shape[0] < self.minimum_features:
return A
bf = cv2.BFMatcher(cv2.NORM_L2)
matches = bf.knnMatch(self.prev_desc[1], desc, k=2)
good = []
for m, n in matches:
if m.distance < 0.7 * n.distance:
good.append(m)
if len(good) > self.minimum_features:
src_pts = np.float32([self.prev_desc[0][m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
dst_pts = np.float32([kp[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
A, _ = cv2.estimateAffinePartial2D(src_pts, dst_pts, method=cv2.RANSAC)
else:
print("Warning: not enough matching points")
if A is None:
A = np.eye(2, 3)
self.prev_desc = [kp, desc]
return A
def _affine_sparse_flow(self, frame, mask, tag):
# Initialize
A = np.eye(2, 3)
# find the keypoints
keypoints = cv2.goodFeaturesToTrack(frame, mask=mask, **self.sparse_flow_param)
# Handle first frame
if self.prev_img is None:
self.prev_img = frame
self.prev_desc = keypoints
return A
matched_kp, status, err = cv2.calcOpticalFlowPyrLK(self.prev_img, frame, self.prev_desc, None)
matched_kp = matched_kp.reshape(-1, 2)
status = status.reshape(-1)
prev_points = self.prev_desc.reshape(-1, 2)
prev_points = prev_points[status]
curr_points = matched_kp[status]
# Find rigid matrix
if prev_points.shape[0] > self.minimum_features:
A, _ = cv2.estimateAffinePartial2D(prev_points, curr_points, method=cv2.RANSAC)
else:
print("Warning: not enough matching points")
if A is None:
A = np.eye(2, 3)
self.prev_img = frame
self.prev_desc = keypoints
return A
def dump_cache(self):
with open(self.cache_path, "wb") as fp:
pickle.dump(self.cache, fp)
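
As a quick illustration (not part of the diff), this is the corner-warping step that the tracker's apply_affine_correction performs with the 2x3 matrix returned by compute_affine above; the affine here is a made-up pure translation.

import numpy as np

# Hypothetical camera-motion affine: a shift of (+5, -3) pixels, no rotation or scale
A = np.array([[1.0, 0.0, 5.0],
              [0.0, 1.0, -3.0]])

box = np.array([100.0, 50.0, 180.0, 120.0])   # [x1, y1, x2, y2]

m, t = A[:, :2], A[:, 2].reshape(2, 1)
ps = box.reshape(2, 2).T        # corners as columns: [[x1, x2], [y1, y2]]
ps = m @ ps + t                 # warp both corners with the affine
print(ps.T.reshape(-1))         # -> [105, 47, 185, 117]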

View file

@ -0,0 +1,12 @@
# Trial number: 137
# HOTA, MOTA, IDF1: [55.567]
deepocsort:
asso_func: giou
conf_thres: 0.5122620708221085
delta_t: 1
det_thresh: 0
inertia: 0.3941737016672115
iou_thresh: 0.22136877277096445
max_age: 50
min_hits: 1
use_byte: false
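
A minimal sketch of reading this config with PyYAML, assuming the block above is saved as configs/deepocsort.yaml (the path actually used by the repo's config loader is not shown in this hunk).

import yaml  # pip install pyyaml

with open("configs/deepocsort.yaml") as fp:
    cfg = yaml.safe_load(fp)

params = cfg["deepocsort"]
print(params["asso_func"], params["iou_thresh"])   # giou 0.2213...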

View file

@ -0,0 +1,116 @@
import pdb
from collections import OrderedDict
import os
import pickle
import torch
import cv2
import torchvision
import numpy as np
class EmbeddingComputer:
def __init__(self, dataset):
self.model = None
self.dataset = dataset
self.crop_size = (128, 384)
os.makedirs("./cache/embeddings/", exist_ok=True)
self.cache_path = "./cache/embeddings/{}_embedding.pkl"
self.cache = {}
self.cache_name = ""
def load_cache(self, path):
self.cache_name = path
cache_path = self.cache_path.format(path)
if os.path.exists(cache_path):
with open(cache_path, "rb") as fp:
self.cache = pickle.load(fp)
def compute_embedding(self, img, bbox, tag, is_numpy=True):
if self.cache_name != tag.split(":")[0]:
self.load_cache(tag.split(":")[0])
if tag in self.cache:
embs = self.cache[tag]
if embs.shape[0] != bbox.shape[0]:
raise RuntimeError(
"ERROR: The number of cached embeddings don't match the "
"number of detections.\nWas the detector model changed? Delete cache if so."
)
return embs
if self.model is None:
self.initialize_model()
# Make sure bbox is within image frame
if is_numpy:
h, w = img.shape[:2]
else:
h, w = img.shape[2:]
results = np.round(bbox).astype(np.int32)
results[:, 0] = results[:, 0].clip(0, w)
results[:, 1] = results[:, 1].clip(0, h)
results[:, 2] = results[:, 2].clip(0, w)
results[:, 3] = results[:, 3].clip(0, h)
# Generate all the crops
crops = []
for p in results:
if is_numpy:
crop = img[p[1] : p[3], p[0] : p[2]]
crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
crop = cv2.resize(crop, self.crop_size, interpolation=cv2.INTER_LINEAR)
crop = torch.as_tensor(crop.astype("float32").transpose(2, 0, 1))
crop = crop.unsqueeze(0)
else:
crop = img[:, :, p[1] : p[3], p[0] : p[2]]
crop = torchvision.transforms.functional.resize(crop, self.crop_size)
crops.append(crop)
crops = torch.cat(crops, dim=0)
# Create embeddings and l2 normalize them
with torch.no_grad():
crops = crops.cuda()
crops = crops.half()
embs = self.model(crops)
embs = torch.nn.functional.normalize(embs)
embs = embs.cpu().numpy()
self.cache[tag] = embs
return embs
def initialize_model(self):
"""
model = torchreid.models.build_model(name="osnet_ain_x1_0", num_classes=2510, loss="softmax", pretrained=False)
sd = torch.load("external/weights/osnet_ain_ms_d_c.pth.tar")["state_dict"]
new_state_dict = OrderedDict()
for k, v in sd.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)
model.eval()
model.cuda()
"""
if self.dataset == "mot17":
path = "external/weights/mot17_sbs_S50.pth"
elif self.dataset == "mot20":
path = "external/weights/mot20_sbs_S50.pth"
elif self.dataset == "dance":
path = None
else:
raise RuntimeError("Need the path for a new ReID model.")
model = FastReID(path)
model.eval()
model.cuda()
model.half()
self.model = model
def dump_cache(self):
if self.cache_name:
with open(self.cache_path.format(self.cache_name), "wb") as fp:
pickle.dump(self.cache, fp)
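
Since compute_embedding returns L2-normalised vectors, the appearance cost used later by the tracker reduces to a matrix of cosine similarities. A toy illustration (the 512-dim size is just an example; the tracker's comment mentions 512/1024/2048 depending on the backbone):

import numpy as np

# Toy L2-normalised appearance embeddings, as compute_embedding returns them
det_embs = np.random.randn(3, 512)
det_embs /= np.linalg.norm(det_embs, axis=1, keepdims=True)
trk_embs = np.random.randn(2, 512)
trk_embs /= np.linalg.norm(trk_embs, axis=1, keepdims=True)

# With unit vectors, a plain dot product is cosine similarity; this is exactly
# the stage-1 emb_cost (dets_embs @ trk_embs.T) used by the tracker below
emb_cost = det_embs @ trk_embs.T
print(emb_cost.shape)   # (3, 2), values in [-1, 1]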

File diff suppressed because it is too large

View file

@ -0,0 +1,670 @@
"""
This script is adopted from the SORT script by Alex Bewley alex@bewley.ai
"""
from __future__ import print_function
import pdb
import pickle
import cv2
import torch
import torchvision
import numpy as np
from .association import *
from .embedding import EmbeddingComputer
from .cmc import CMCComputer
from reid_multibackend import ReIDDetectMultiBackend
def k_previous_obs(observations, cur_age, k):
if len(observations) == 0:
return [-1, -1, -1, -1, -1]
for i in range(k):
dt = k - i
if cur_age - dt in observations:
return observations[cur_age - dt]
max_age = max(observations.keys())
return observations[max_age]
def convert_bbox_to_z(bbox):
"""
Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
[x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
the aspect ratio
"""
w = bbox[2] - bbox[0]
h = bbox[3] - bbox[1]
x = bbox[0] + w / 2.0
y = bbox[1] + h / 2.0
s = w * h # scale is just area
r = w / float(h + 1e-6)
return np.array([x, y, s, r]).reshape((4, 1))
def convert_bbox_to_z_new(bbox):
w = bbox[2] - bbox[0]
h = bbox[3] - bbox[1]
x = bbox[0] + w / 2.0
y = bbox[1] + h / 2.0
return np.array([x, y, w, h]).reshape((4, 1))
def convert_x_to_bbox_new(x):
x, y, w, h = x.reshape(-1)[:4]
return np.array([x - w / 2, y - h / 2, x + w / 2, y + h / 2]).reshape(1, 4)
def convert_x_to_bbox(x, score=None):
"""
Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
[x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
"""
w = np.sqrt(x[2] * x[3])
h = x[2] / w
if score == None:
return np.array([x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0]).reshape((1, 4))
else:
return np.array([x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0, score]).reshape((1, 5))
def speed_direction(bbox1, bbox2):
cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
speed = np.array([cy2 - cy1, cx2 - cx1])
norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6
return speed / norm
def new_kf_process_noise(w, h, p=1 / 20, v=1 / 160):
Q = np.diag(
((p * w) ** 2, (p * h) ** 2, (p * w) ** 2, (p * h) ** 2, (v * w) ** 2, (v * h) ** 2, (v * w) ** 2, (v * h) ** 2)
)
return Q
def new_kf_measurement_noise(w, h, m=1 / 20):
w_var = (m * w) ** 2
h_var = (m * h) ** 2
R = np.diag((w_var, h_var, w_var, h_var))
return R
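
To make the scaling concrete, here is what the process-noise diagonal above evaluates to for a 40x80 box (a small self-contained restatement of the function, not a change to the file):

import numpy as np

def new_kf_process_noise(w, h, p=1 / 20, v=1 / 160):
    # Same as above: per-state variances scaled by box size
    return np.diag(((p * w) ** 2, (p * h) ** 2, (p * w) ** 2, (p * h) ** 2,
                    (v * w) ** 2, (v * h) ** 2, (v * w) ** 2, (v * h) ** 2))

# For a 40x80 box: position noise std is w/20 = 2 px and h/20 = 4 px,
# velocity noise std is w/160 = 0.25 px and h/160 = 0.5 px per frame
Q = new_kf_process_noise(40, 80)
print(np.sqrt(np.diag(Q)))   # -> 2, 4, 2, 4, 0.25, 0.5, 0.25, 0.5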
class KalmanBoxTracker(object):
"""
This class represents the internal state of individual tracked objects observed as bbox.
"""
count = 0
def __init__(self, bbox, cls, delta_t=3, orig=False, emb=None, alpha=0, new_kf=False):
"""
Initialises a tracker using initial bounding box.
"""
# define constant velocity model
if not orig:
from .kalmanfilter import KalmanFilterNew as KalmanFilter
else:
from filterpy.kalman import KalmanFilter
self.cls = cls
self.conf = bbox[-1]
self.new_kf = new_kf
if new_kf:
self.kf = KalmanFilter(dim_x=8, dim_z=4)
self.kf.F = np.array(
[
# x y w h x' y' w' h'
[1, 0, 0, 0, 1, 0, 0, 0],
[0, 1, 0, 0, 0, 1, 0, 0],
[0, 0, 1, 0, 0, 0, 1, 0],
[0, 0, 0, 1, 0, 0, 0, 1],
[0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1],
]
)
self.kf.H = np.array(
[
[1, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0],
]
)
_, _, w, h = convert_bbox_to_z_new(bbox).reshape(-1)
self.kf.P = new_kf_process_noise(w, h)
self.kf.P[:4, :4] *= 4
self.kf.P[4:, 4:] *= 100
# Process and measurement uncertainty happen in functions
self.bbox_to_z_func = convert_bbox_to_z_new
self.x_to_bbox_func = convert_x_to_bbox_new
else:
self.kf = KalmanFilter(dim_x=7, dim_z=4)
self.kf.F = np.array(
[
# x y s r x' y' s'
[1, 0, 0, 0, 1, 0, 0],
[0, 1, 0, 0, 0, 1, 0],
[0, 0, 1, 0, 0, 0, 1],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1],
]
)
self.kf.H = np.array(
[
[1, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
]
)
self.kf.R[2:, 2:] *= 10.0
self.kf.P[4:, 4:] *= 1000.0 # give high uncertainty to the unobservable initial velocities
self.kf.P *= 10.0
self.kf.Q[-1, -1] *= 0.01
self.kf.Q[4:, 4:] *= 0.01
self.bbox_to_z_func = convert_bbox_to_z
self.x_to_bbox_func = convert_x_to_bbox
self.kf.x[:4] = self.bbox_to_z_func(bbox)
self.time_since_update = 0
self.id = KalmanBoxTracker.count
KalmanBoxTracker.count += 1
self.history = []
self.hits = 0
self.hit_streak = 0
self.age = 0
"""
NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of
function k_previous_obs. It is ugly and I do not like it. But to support generating the observation array in a
fast and unified way (see k_observations = np.array([k_previous_obs(...)]) below), let's bear it for now.
"""
# Used for OCR
self.last_observation = np.array([-1, -1, -1, -1, -1]) # placeholder
# Used to output track after min_hits reached
self.history_observations = []
# Used for velocity
self.observations = dict()
self.velocity = None
self.delta_t = delta_t
self.emb = emb
self.frozen = False
def update(self, bbox, cls):
"""
Updates the state vector with observed bbox.
"""
if bbox is not None:
self.frozen = False
self.cls = cls
if self.last_observation.sum() >= 0:  # a previous observation exists
previous_box = None
for dt in range(self.delta_t, 0, -1):
if self.age - dt in self.observations:
previous_box = self.observations[self.age - dt]
break
if previous_box is None:
previous_box = self.last_observation
"""
Estimate the track speed direction with observations \Delta t steps away
"""
self.velocity = speed_direction(previous_box, bbox)
"""
Insert new observations. This is an ugly way to maintain both self.observations
and self.history_observations. Bear it for the moment.
"""
self.last_observation = bbox
self.observations[self.age] = bbox
self.history_observations.append(bbox)
self.time_since_update = 0
self.history = []
self.hits += 1
self.hit_streak += 1
if self.new_kf:
R = new_kf_measurement_noise(self.kf.x[2, 0], self.kf.x[3, 0])
self.kf.update(self.bbox_to_z_func(bbox), R=R)
else:
self.kf.update(self.bbox_to_z_func(bbox))
else:
self.kf.update(bbox)
self.frozen = True
def update_emb(self, emb, alpha=0.9):
self.emb = alpha * self.emb + (1 - alpha) * emb
self.emb /= np.linalg.norm(self.emb)
def get_emb(self):
return self.emb.cpu()
def apply_affine_correction(self, affine):
m = affine[:, :2]
t = affine[:, 2].reshape(2, 1)
# For OCR
if self.last_observation.sum() > 0:
ps = self.last_observation[:4].reshape(2, 2).T
ps = m @ ps + t
self.last_observation[:4] = ps.T.reshape(-1)
# Apply to each box in the range of velocity computation
for dt in range(self.delta_t, -1, -1):
if self.age - dt in self.observations:
ps = self.observations[self.age - dt][:4].reshape(2, 2).T
ps = m @ ps + t
self.observations[self.age - dt][:4] = ps.T.reshape(-1)
# Also need to change kf state, but might be frozen
self.kf.apply_affine_correction(m, t, self.new_kf)
def predict(self):
"""
Advances the state vector and returns the predicted bounding box estimate.
"""
# Don't allow negative bounding boxes
if self.new_kf:
if self.kf.x[2] + self.kf.x[6] <= 0:
self.kf.x[6] = 0
if self.kf.x[3] + self.kf.x[7] <= 0:
self.kf.x[7] = 0
# Stop velocity, will update in kf during OOS
if self.frozen:
self.kf.x[6] = self.kf.x[7] = 0
Q = new_kf_process_noise(self.kf.x[2, 0], self.kf.x[3, 0])
else:
if (self.kf.x[6] + self.kf.x[2]) <= 0:
self.kf.x[6] *= 0.0
Q = None
self.kf.predict(Q=Q)
self.age += 1
if self.time_since_update > 0:
self.hit_streak = 0
self.time_since_update += 1
self.history.append(self.x_to_bbox_func(self.kf.x))
return self.history[-1]
def get_state(self):
"""
Returns the current bounding box estimate.
"""
return self.x_to_bbox_func(self.kf.x)
def mahalanobis(self, bbox):
"""Should be run after a predict() call for accuracy."""
return self.kf.md_for_measurement(self.bbox_to_z_func(bbox))
"""
We support multiple ways for association cost calculation, by default
we use IoU. GIoU may have better performance in some situations. We note
that we hardly normalize the cost by all methods to (0,1) which may not be
the best practice.
"""
ASSO_FUNCS = {
"iou": iou_batch,
"giou": giou_batch,
"ciou": ciou_batch,
"diou": diou_batch,
"ct_dist": ct_dist,
}
class OCSort(object):
def __init__(
self,
model_weights,
device,
fp16,
det_thresh,
max_age=30,
min_hits=3,
iou_threshold=0.3,
delta_t=3,
asso_func="iou",
inertia=0.2,
w_association_emb=0.75,
alpha_fixed_emb=0.95,
aw_param=0.5,
embedding_off=False,
cmc_off=False,
aw_off=False,
new_kf_off=False,
**kwargs
):
"""
Sets key parameters for SORT
"""
self.max_age = max_age
self.min_hits = min_hits
self.iou_threshold = iou_threshold
self.trackers = []
self.frame_count = 0
self.det_thresh = det_thresh
self.delta_t = delta_t
self.asso_func = ASSO_FUNCS[asso_func]
self.inertia = inertia
self.w_association_emb = w_association_emb
self.alpha_fixed_emb = alpha_fixed_emb
self.aw_param = aw_param
KalmanBoxTracker.count = 0
self.embedder = ReIDDetectMultiBackend(weights=model_weights, device=device, fp16=fp16)
self.cmc = CMCComputer()
self.embedding_off = embedding_off
self.cmc_off = cmc_off
self.aw_off = aw_off
self.new_kf_off = new_kf_off
def update(self, dets, img_numpy, tag='blub'):
"""
Params:
dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
Requires: this method must be called once for each frame even with empty detections (use an empty (0, 6) array for frames without detections, since this variant also carries a class column).
Returns a similar array, where the last column is the object ID.
NOTE: The number of objects returned may differ from the number of detections provided.
"""
xyxys = dets[:, 0:4]
scores = dets[:, 4]
clss = dets[:, 5]
classes = clss.numpy()
xyxys = xyxys.numpy()
scores = scores.numpy()
dets = dets[:, 0:6].numpy()
remain_inds = scores > self.det_thresh
dets = dets[remain_inds]
self.height, self.width = img_numpy.shape[:2]
# Rescale
#scale = min(img_tensor.shape[2] / img_numpy.shape[0], img_tensor.shape[3] / img_numpy.shape[1])
#dets[:, :4] /= scale
# Embedding
if self.embedding_off or dets.shape[0] == 0:
dets_embs = np.ones((dets.shape[0], 1))
else:
# (Ndets x X) [512, 1024, 2048]
#dets_embs = self.embedder.compute_embedding(img_numpy, dets[:, :4], tag)
dets_embs = self._get_features(dets[:, :4], img_numpy)
# CMC
if not self.cmc_off:
transform = self.cmc.compute_affine(img_numpy, dets[:, :4], tag)
for trk in self.trackers:
trk.apply_affine_correction(transform)
trust = (dets[:, 4] - self.det_thresh) / (1 - self.det_thresh)
af = self.alpha_fixed_emb
# From [self.alpha_fixed_emb, 1], goes to 1 as detector is less confident
dets_alpha = af + (1 - af) * (1 - trust)
# get predicted locations from existing trackers.
trks = np.zeros((len(self.trackers), 5))
trk_embs = []
to_del = []
ret = []
for t, trk in enumerate(trks):
pos = self.trackers[t].predict()[0]
trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
if np.any(np.isnan(pos)):
to_del.append(t)
else:
trk_embs.append(self.trackers[t].get_emb())
trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
if len(trk_embs) > 0:
trk_embs = np.vstack(trk_embs)
else:
trk_embs = np.array(trk_embs)
for t in reversed(to_del):
self.trackers.pop(t)
velocities = np.array([trk.velocity if trk.velocity is not None else np.array((0, 0)) for trk in self.trackers])
last_boxes = np.array([trk.last_observation for trk in self.trackers])
k_observations = np.array([k_previous_obs(trk.observations, trk.age, self.delta_t) for trk in self.trackers])
"""
First round of association
"""
# (M detections X N tracks, final score)
if self.embedding_off or dets.shape[0] == 0 or trk_embs.shape[0] == 0:
stage1_emb_cost = None
else:
stage1_emb_cost = dets_embs @ trk_embs.T
matched, unmatched_dets, unmatched_trks = associate(
dets,
trks,
self.iou_threshold,
velocities,
k_observations,
self.inertia,
stage1_emb_cost,
self.w_association_emb,
self.aw_off,
self.aw_param,
)
for m in matched:
self.trackers[m[1]].update(dets[m[0], :5], dets[m[0], 5])
self.trackers[m[1]].update_emb(dets_embs[m[0]], alpha=dets_alpha[m[0]])
"""
Second round of association by OCR
"""
if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
left_dets = dets[unmatched_dets]
left_dets_embs = dets_embs[unmatched_dets]
left_trks = last_boxes[unmatched_trks]
left_trks_embs = trk_embs[unmatched_trks]
iou_left = self.asso_func(left_dets, left_trks)
# TODO: is better without this
emb_cost_left = left_dets_embs @ left_trks_embs.T
if self.embedding_off:
emb_cost_left = np.zeros_like(emb_cost_left)
iou_left = np.array(iou_left)
if iou_left.max() > self.iou_threshold:
"""
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
uniform here for simplicity
"""
rematched_indices = linear_assignment(-iou_left)
to_remove_det_indices = []
to_remove_trk_indices = []
for m in rematched_indices:
det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
continue
self.trackers[trk_ind].update(dets[det_ind, :5], dets[det_ind, 5])
self.trackers[trk_ind].update_emb(dets_embs[det_ind], alpha=dets_alpha[det_ind])
to_remove_det_indices.append(det_ind)
to_remove_trk_indices.append(trk_ind)
unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices))
unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))
for m in unmatched_trks:
self.trackers[m].update(None, None)
# create and initialise new trackers for unmatched detections
for i in unmatched_dets:
trk = KalmanBoxTracker(
dets[i, :5], dets[i, 5], delta_t=self.delta_t, emb=dets_embs[i], alpha=dets_alpha[i], new_kf=not self.new_kf_off
)
self.trackers.append(trk)
i = len(self.trackers)
for trk in reversed(self.trackers):
if trk.last_observation.sum() < 0:
d = trk.get_state()[0]
else:
"""
it is optional to use either the recent observation or the Kalman filter prediction;
we did not notice a significant difference here
"""
d = trk.last_observation[:4]
if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
# +1 as MOT benchmark requires positive
ret.append(np.concatenate((d, [trk.id + 1], [trk.cls], [trk.conf])).reshape(1, -1))
i -= 1
# remove dead tracklet
if trk.time_since_update > self.max_age:
self.trackers.pop(i)
if len(ret) > 0:
return np.concatenate(ret)
return np.empty((0, 5))
def _xywh_to_xyxy(self, bbox_xywh):
x, y, w, h = bbox_xywh
x1 = max(int(x - w / 2), 0)
x2 = min(int(x + w / 2), self.width - 1)
y1 = max(int(y - h / 2), 0)
y2 = min(int(y + h / 2), self.height - 1)
return x1, y1, x2, y2
def _get_features(self, bbox_xywh, ori_img):
im_crops = []
for box in bbox_xywh:
x1, y1, x2, y2 = self._xywh_to_xyxy(box)
im = ori_img[y1:y2, x1:x2]
im_crops.append(im)
if im_crops:
features = self.embedder(im_crops).cpu()
else:
features = np.array([])
return features
def update_public(self, dets, cates, scores):
self.frame_count += 1
det_scores = np.ones((dets.shape[0], 1))
dets = np.concatenate((dets, det_scores), axis=1)
remain_inds = scores > self.det_thresh
cates = cates[remain_inds]
dets = dets[remain_inds]
trks = np.zeros((len(self.trackers), 5))
to_del = []
ret = []
for t, trk in enumerate(trks):
pos = self.trackers[t].predict()[0]
cat = self.trackers[t].cate
trk[:] = [pos[0], pos[1], pos[2], pos[3], cat]
if np.any(np.isnan(pos)):
to_del.append(t)
trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
for t in reversed(to_del):
self.trackers.pop(t)
velocities = np.array([trk.velocity if trk.velocity is not None else np.array((0, 0)) for trk in self.trackers])
last_boxes = np.array([trk.last_observation for trk in self.trackers])
k_observations = np.array([k_previous_obs(trk.observations, trk.age, self.delta_t) for trk in self.trackers])
matched, unmatched_dets, unmatched_trks = associate_kitti(
dets,
trks,
cates,
self.iou_threshold,
velocities,
k_observations,
self.inertia,
)
for m in matched:
self.trackers[m[1]].update(dets[m[0], :])
if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
"""
The re-association stage by OCR.
NOTE: at this stage, adding other strategies might further improve
the performance, such as the BYTE association from ByteTrack.
"""
left_dets = dets[unmatched_dets]
left_trks = last_boxes[unmatched_trks]
left_dets_c = left_dets.copy()
left_trks_c = left_trks.copy()
iou_left = self.asso_func(left_dets_c, left_trks_c)
iou_left = np.array(iou_left)
det_cates_left = cates[unmatched_dets]
trk_cates_left = trks[unmatched_trks][:, 4]
num_dets = unmatched_dets.shape[0]
num_trks = unmatched_trks.shape[0]
cate_matrix = np.zeros((num_dets, num_trks))
for i in range(num_dets):
for j in range(num_trks):
if det_cates_left[i] != trk_cates_left[j]:
"""
For some datasets, such as KITTI, there are different categories,
we have to avoid associating them together.
"""
cate_matrix[i][j] = -1e6
iou_left = iou_left + cate_matrix
if iou_left.max() > self.iou_threshold - 0.1:
rematched_indices = linear_assignment(-iou_left)
to_remove_det_indices = []
to_remove_trk_indices = []
for m in rematched_indices:
det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.iou_threshold - 0.1:
continue
self.trackers[trk_ind].update(dets[det_ind, :])
to_remove_det_indices.append(det_ind)
to_remove_trk_indices.append(trk_ind)
unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices))
unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))
for i in unmatched_dets:
trk = KalmanBoxTracker(dets[i, :])
trk.cate = cates[i]
self.trackers.append(trk)
i = len(self.trackers)
for trk in reversed(self.trackers):
if trk.last_observation.sum() > 0:
d = trk.last_observation[:4]
else:
d = trk.get_state()[0]
if trk.time_since_update < 1:
if (self.frame_count <= self.min_hits) or (trk.hit_streak >= self.min_hits):
# id+1 as MOT benchmark requires positive
ret.append(np.concatenate((d, [trk.id + 1], [trk.cls], [trk.conf])).reshape(1, -1))
if trk.hit_streak == self.min_hits:
# Head Padding (HP): recover the lost steps during initializing the track
for prev_i in range(self.min_hits - 1):
prev_observation = trk.history_observations[-(prev_i + 2)]
ret.append(
(
np.concatenate(
(
prev_observation[:4],
[trk.id + 1],
[trk.cls],
[trk.conf],
)
)
).reshape(1, -1)
)
i -= 1
if trk.time_since_update > self.max_age:
self.trackers.pop(i)
if len(ret) > 0:
return np.concatenate(ret)
return np.empty((0, 7))
def dump_cache(self):
self.cmc.dump_cache()
self.embedder.dump_cache()
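
A sketch of the calling convention that update() above expects, per its docstring: one call per frame, detections as an (N, 6) tensor laid out [x1, y1, x2, y2, score, cls], and an empty tensor on frames with no detections. The tracker construction (ReID weights, device, fp16) is assumed to have happened elsewhere, and detector is a hypothetical callable.

import torch

def track_frames(tracker, frames, detector):
    # tracker: an already-constructed OCSort instance; detector: hypothetical callable
    for frame in frames:                            # frame: HxWx3 BGR numpy image
        dets = detector(frame)                      # rows: [x1, y1, x2, y2, score, cls]
        if dets is None or len(dets) == 0:
            dets = torch.empty((0, 6))              # update() must still be called on empty frames
        else:
            dets = torch.as_tensor(dets, dtype=torch.float32)
        online = tracker.update(dets, frame)        # rows: [x1, y1, x2, y2, id, cls, conf]
        yield online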

View file

@ -0,0 +1,237 @@
import torch.nn as nn
import torch
from pathlib import Path
import numpy as np
from itertools import islice
import torchvision.transforms as transforms
import cv2
import sys
import torchvision.transforms as T
from collections import OrderedDict, namedtuple
import gdown
from os.path import exists as file_exists
from yolov8.ultralytics.yolo.utils.checks import check_requirements, check_version
from yolov8.ultralytics.yolo.utils import LOGGER
from trackers.strongsort.deep.reid_model_factory import (show_downloadeable_models, get_model_url, get_model_name,
download_url, load_pretrained_weights)
from trackers.strongsort.deep.models import build_model
def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
# Check file(s) for acceptable suffix
if file and suffix:
if isinstance(suffix, str):
suffix = [suffix]
for f in file if isinstance(file, (list, tuple)) else [file]:
s = Path(f).suffix.lower() # file suffix
if len(s):
assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
class ReIDDetectMultiBackend(nn.Module):
# ReID models MultiBackend class for python inference on various backends
def __init__(self, weights='osnet_x0_25_msmt17.pt', device=torch.device('cpu'), fp16=False):
super().__init__()
w = weights[0] if isinstance(weights, list) else weights
self.pt, self.jit, self.onnx, self.xml, self.engine, self.tflite = self.model_type(w) # get backend
self.fp16 = fp16
self.fp16 &= self.pt or self.jit or self.engine # FP16
# Build transform functions
self.device = device
self.image_size=(256, 128)
self.pixel_mean=[0.485, 0.456, 0.406]
self.pixel_std=[0.229, 0.224, 0.225]
self.transforms = []
self.transforms += [T.Resize(self.image_size)]
self.transforms += [T.ToTensor()]
self.transforms += [T.Normalize(mean=self.pixel_mean, std=self.pixel_std)]
self.preprocess = T.Compose(self.transforms)
self.to_pil = T.ToPILImage()
model_name = get_model_name(w)
if w.suffix == '.pt':
model_url = get_model_url(w)
if not file_exists(w) and model_url is not None:
gdown.download(model_url, str(w), quiet=False)
elif file_exists(w):
pass
else:
print(f'No URL associated with the chosen StrongSORT weights ({w}). Choose between:')
show_downloadeable_models()
exit()
# Build model
self.model = build_model(
model_name,
num_classes=1,
pretrained=not (w and w.is_file()),
use_gpu=device
)
if self.pt: # PyTorch
# populate model arch with weights
if w and w.is_file() and w.suffix == '.pt':
load_pretrained_weights(self.model, w)
self.model.to(device).eval()
self.model.half() if self.fp16 else self.model.float()
elif self.jit:
LOGGER.info(f'Loading {w} for TorchScript inference...')
self.model = torch.jit.load(w)
self.model.half() if self.fp16 else self.model.float()
elif self.onnx: # ONNX Runtime
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
cuda = torch.cuda.is_available() and device.type != 'cpu'
#check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
import onnxruntime
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
self.session = onnxruntime.InferenceSession(str(w), providers=providers)
elif self.engine: # TensorRT
LOGGER.info(f'Loading {w} for TensorRT inference...')
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
if device.type == 'cpu':
device = torch.device('cuda:0')
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
logger = trt.Logger(trt.Logger.INFO)
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
self.model_ = runtime.deserialize_cuda_engine(f.read())
self.context = self.model_.create_execution_context()
self.bindings = OrderedDict()
self.fp16 = False # default updated below
dynamic = False
for index in range(self.model_.num_bindings):
name = self.model_.get_binding_name(index)
dtype = trt.nptype(self.model_.get_binding_dtype(index))
if self.model_.binding_is_input(index):
if -1 in tuple(self.model_.get_binding_shape(index)): # dynamic
dynamic = True
self.context.set_binding_shape(index, tuple(self.model_.get_profile_shape(0, index)[2]))
if dtype == np.float16:
self.fp16 = True
shape = tuple(self.context.get_binding_shape(index))
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
self.bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items())
batch_size = self.bindings['images'].shape[0] # if dynamic, this is instead max batch size
elif self.xml: # OpenVINO
LOGGER.info(f'Loading {w} for OpenVINO inference...')
check_requirements(('openvino',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
from openvino.runtime import Core, Layout, get_batch
ie = Core()
if not Path(w).is_file(): # if not *.xml
w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
if network.get_parameters()[0].get_layout().empty:
network.get_parameters()[0].set_layout(Layout("NCWH"))
batch_dim = get_batch(network)
if batch_dim.is_static:
batch_size = batch_dim.get_length()
self.executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
self.output_layer = next(iter(self.executable_network.outputs))
elif self.tflite:
LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
from tflite_runtime.interpreter import Interpreter, load_delegate
except ImportError:
import tensorflow as tf
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
self.interpreter = tf.lite.Interpreter(model_path=w)
self.interpreter.allocate_tensors()
# Get input and output tensors.
self.input_details = self.interpreter.get_input_details()
self.output_details = self.interpreter.get_output_details()
# Test model on random input data.
input_data = np.array(np.random.random_sample((1,256,128,3)), dtype=np.float32)
self.interpreter.set_tensor(self.input_details[0]['index'], input_data)
self.interpreter.invoke()
# The function `get_tensor()` returns a copy of the tensor data.
output_data = self.interpreter.get_tensor(self.output_details[0]['index'])
else:
print('This model framework is not supported yet!')
exit()
@staticmethod
def model_type(p='path/to/model.pt'):
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
from trackers.reid_export import export_formats
sf = list(export_formats().Suffix) # export suffixes
check_suffix(p, sf) # checks
types = [s in Path(p).name for s in sf]
return types
def _preprocess(self, im_batch):
images = []
for element in im_batch:
image = self.to_pil(element)
image = self.preprocess(image)
images.append(image)
images = torch.stack(images, dim=0)
images = images.to(self.device)
return images
def forward(self, im_batch):
# preprocess batch
im_batch = self._preprocess(im_batch)
# batch to half
if self.fp16 and im_batch.dtype != torch.float16:
im_batch = im_batch.half()
# batch processing
features = []
if self.pt:
features = self.model(im_batch)
elif self.jit: # TorchScript
features = self.model(im_batch)
elif self.onnx: # ONNX Runtime
im_batch = im_batch.cpu().numpy() # torch to numpy
features = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im_batch})[0]
elif self.engine: # TensorRT
if True and im_batch.shape != self.bindings['images'].shape:
i_in, i_out = (self.model_.get_binding_index(x) for x in ('images', 'output'))
self.context.set_binding_shape(i_in, im_batch.shape) # reshape if dynamic
self.bindings['images'] = self.bindings['images']._replace(shape=im_batch.shape)
self.bindings['output'].data.resize_(tuple(self.context.get_binding_shape(i_out)))
s = self.bindings['images'].shape
assert im_batch.shape == s, f"input size {im_batch.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
self.binding_addrs['images'] = int(im_batch.data_ptr())
self.context.execute_v2(list(self.binding_addrs.values()))
features = self.bindings['output'].data
elif self.xml: # OpenVINO
im_batch = im_batch.cpu().numpy() # FP32
features = self.executable_network([im_batch])[self.output_layer]
else:
print('Framework not supported at the moment, we are working on it...')
exit()
if isinstance(features, (list, tuple)):
return self.from_numpy(features[0]) if len(features) == 1 else [self.from_numpy(x) for x in features]
else:
return self.from_numpy(features)
def from_numpy(self, x):
return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x
def warmup(self, imgsz=[(256, 128, 3)]):
# Warmup model by running inference once
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.tflite
if any(warmup_types) and self.device.type != 'cpu':
im = [np.empty(*imgsz).astype(np.uint8)] # input
for _ in range(2 if self.jit else 1): #
self.forward(im) # warmup
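
A hedged usage sketch of the backend above: weights are resolved from a Path (and, for known OSNet model names, auto-downloaded), and calling the module on a list of BGR crops returns one embedding per crop. The weight file name and image path here are placeholders.

from pathlib import Path
import cv2
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
reid = ReIDDetectMultiBackend(weights=Path('osnet_x0_25_msmt17.pt'), device=device, fp16=False)
reid.warmup()

img = cv2.imread('person.jpg')                       # placeholder image path
crops = [img[0:200, 0:100], img[50:260, 120:220]]    # person crops as numpy arrays
features = reid(crops)                               # tensor of shape (len(crops), feature_dim)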

View file

@ -0,0 +1,84 @@
from trackers.strongsort.utils.parser import get_config
def create_tracker(tracker_type, tracker_config, reid_weights, device, half):
cfg = get_config()
cfg.merge_from_file(tracker_config)
if tracker_type == 'strongsort':
from trackers.strongsort.strong_sort import StrongSORT
strongsort = StrongSORT(
reid_weights,
device,
half,
max_dist=cfg.strongsort.max_dist,
max_iou_dist=cfg.strongsort.max_iou_dist,
max_age=cfg.strongsort.max_age,
max_unmatched_preds=cfg.strongsort.max_unmatched_preds,
n_init=cfg.strongsort.n_init,
nn_budget=cfg.strongsort.nn_budget,
mc_lambda=cfg.strongsort.mc_lambda,
ema_alpha=cfg.strongsort.ema_alpha,
)
return strongsort
elif tracker_type == 'ocsort':
from trackers.ocsort.ocsort import OCSort
ocsort = OCSort(
det_thresh=cfg.ocsort.det_thresh,
max_age=cfg.ocsort.max_age,
min_hits=cfg.ocsort.min_hits,
iou_threshold=cfg.ocsort.iou_thresh,
delta_t=cfg.ocsort.delta_t,
asso_func=cfg.ocsort.asso_func,
inertia=cfg.ocsort.inertia,
use_byte=cfg.ocsort.use_byte,
)
return ocsort
elif tracker_type == 'bytetrack':
from trackers.bytetrack.byte_tracker import BYTETracker
bytetracker = BYTETracker(
track_thresh=cfg.bytetrack.track_thresh,
match_thresh=cfg.bytetrack.match_thresh,
track_buffer=cfg.bytetrack.track_buffer,
frame_rate=cfg.bytetrack.frame_rate
)
return bytetracker
elif tracker_type == 'botsort':
from trackers.botsort.bot_sort import BoTSORT
botsort = BoTSORT(
reid_weights,
device,
half,
track_high_thresh=cfg.botsort.track_high_thresh,
new_track_thresh=cfg.botsort.new_track_thresh,
track_buffer =cfg.botsort.track_buffer,
match_thresh=cfg.botsort.match_thresh,
proximity_thresh=cfg.botsort.proximity_thresh,
appearance_thresh=cfg.botsort.appearance_thresh,
cmc_method =cfg.botsort.cmc_method,
frame_rate=cfg.botsort.frame_rate,
lambda_=cfg.botsort.lambda_
)
return botsort
elif tracker_type == 'deepocsort':
from trackers.deepocsort.ocsort import OCSort
botsort = OCSort(
reid_weights,
device,
half,
det_thresh=cfg.deepocsort.det_thresh,
max_age=cfg.deepocsort.max_age,
min_hits=cfg.deepocsort.min_hits,
iou_threshold=cfg.deepocsort.iou_thresh,
delta_t=cfg.deepocsort.delta_t,
asso_func=cfg.deepocsort.asso_func,
inertia=cfg.deepocsort.inertia,
)
return botsort
else:
print('No such tracker')
exit()
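
And a sketch of calling the factory above; the config and weight paths are hypothetical, and a CUDA device is only used if available.

from pathlib import Path
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tracker = create_tracker(
    tracker_type='deepocsort',
    tracker_config='trackers/deepocsort/configs/deepocsort.yaml',   # hypothetical path
    reid_weights=Path('osnet_x0_25_msmt17.pt'),                     # hypothetical weights
    device=device,
    half=False,
)
# tracker.update(...) is then called once per frame, as in the OCSort sketch above.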

View file

@ -0,0 +1,377 @@
import os
import numpy as np
def iou_batch(bboxes1, bboxes2):
"""
From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
"""
bboxes2 = np.expand_dims(bboxes2, 0)
bboxes1 = np.expand_dims(bboxes1, 1)
xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
w = np.maximum(0., xx2 - xx1)
h = np.maximum(0., yy2 - yy1)
wh = w * h
o = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
+ (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh)
return(o)
def giou_batch(bboxes1, bboxes2):
"""
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
:return:
"""
# for details should go to https://arxiv.org/pdf/1902.09630.pdf
# ensure predict's bbox form
bboxes2 = np.expand_dims(bboxes2, 0)
bboxes1 = np.expand_dims(bboxes1, 1)
xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
w = np.maximum(0., xx2 - xx1)
h = np.maximum(0., yy2 - yy1)
wh = w * h
iou = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
+ (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh)
xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
wc = xxc2 - xxc1
hc = yyc2 - yyc1
assert((wc > 0).all() and (hc > 0).all())
area_enclose = wc * hc
giou = iou - (area_enclose - wh) / area_enclose
giou = (giou + 1.)/2.0 # resize from (-1,1) to (0,1)
return giou
def diou_batch(bboxes1, bboxes2):
"""
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
:return:
"""
# for details should go to https://arxiv.org/pdf/1902.09630.pdf
# ensure predict's bbox form
bboxes2 = np.expand_dims(bboxes2, 0)
bboxes1 = np.expand_dims(bboxes1, 1)
# calculate the intersection box
xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
w = np.maximum(0., xx2 - xx1)
h = np.maximum(0., yy2 - yy1)
wh = w * h
iou = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
+ (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh)
centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0
inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2
xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2
diou = iou - inner_diag / outer_diag
return (diou + 1) / 2.0 # resize from (-1,1) to (0,1)
def ciou_batch(bboxes1, bboxes2):
"""
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
:return:
"""
# for details should go to https://arxiv.org/pdf/1902.09630.pdf
# ensure predict's bbox form
bboxes2 = np.expand_dims(bboxes2, 0)
bboxes1 = np.expand_dims(bboxes1, 1)
# calculate the intersection box
xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
w = np.maximum(0., xx2 - xx1)
h = np.maximum(0., yy2 - yy1)
wh = w * h
iou = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
+ (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh)
centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0
inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2
xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2
w1 = bboxes1[..., 2] - bboxes1[..., 0]
h1 = bboxes1[..., 3] - bboxes1[..., 1]
w2 = bboxes2[..., 2] - bboxes2[..., 0]
h2 = bboxes2[..., 3] - bboxes2[..., 1]
# prevent dividing over zero. add one pixel shift
h2 = h2 + 1.
h1 = h1 + 1.
arctan = np.arctan(w2/h2) - np.arctan(w1/h1)
v = (4 / (np.pi ** 2)) * (arctan ** 2)
S = 1 - iou
alpha = v / (S+v)
ciou = iou - inner_diag / outer_diag - alpha * v
return (ciou + 1) / 2.0 # resize from (-1,1) to (0,1)
def ct_dist(bboxes1, bboxes2):
"""
Measure the center distance between two sets of bounding boxes,
this is a coarse implementation, we don't recommend using it only
for association, which can be unstable and sensitive to frame rate
and object speed.
"""
bboxes2 = np.expand_dims(bboxes2, 0)
bboxes1 = np.expand_dims(bboxes1, 1)
centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0
ct_dist2 = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2
ct_dist = np.sqrt(ct_dist2)
# The linear rescaling is a naive version and needs more study
ct_dist = ct_dist / ct_dist.max()
return ct_dist.max() - ct_dist # resize to (0,1)
def speed_direction_batch(dets, tracks):
tracks = tracks[..., np.newaxis]
CX1, CY1 = (dets[:,0] + dets[:,2])/2.0, (dets[:,1]+dets[:,3])/2.0
CX2, CY2 = (tracks[:,0] + tracks[:,2]) /2.0, (tracks[:,1]+tracks[:,3])/2.0
dx = CX1 - CX2
dy = CY1 - CY2
norm = np.sqrt(dx**2 + dy**2) + 1e-6
dx = dx / norm
dy = dy / norm
return dy, dx # size: num_track x num_det
def linear_assignment(cost_matrix):
try:
import lap
_, x, y = lap.lapjv(cost_matrix, extend_cost=True)
return np.array([[y[i],i] for i in x if i >= 0]) #
except ImportError:
from scipy.optimize import linear_sum_assignment
x, y = linear_sum_assignment(cost_matrix)
return np.array(list(zip(x, y)))
def associate_detections_to_trackers(detections,trackers, iou_threshold = 0.3):
"""
Assigns detections to tracked object (both represented as bounding boxes)
Returns 3 lists of matches, unmatched_detections and unmatched_trackers
"""
if(len(trackers)==0):
return np.empty((0,2),dtype=int), np.arange(len(detections)), np.empty((0,5),dtype=int)
iou_matrix = iou_batch(detections, trackers)
if min(iou_matrix.shape) > 0:
a = (iou_matrix > iou_threshold).astype(np.int32)
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
matched_indices = np.stack(np.where(a), axis=1)
else:
matched_indices = linear_assignment(-iou_matrix)
else:
matched_indices = np.empty(shape=(0,2))
unmatched_detections = []
for d, det in enumerate(detections):
if(d not in matched_indices[:,0]):
unmatched_detections.append(d)
unmatched_trackers = []
for t, trk in enumerate(trackers):
if(t not in matched_indices[:,1]):
unmatched_trackers.append(t)
#filter out matched with low IOU
matches = []
for m in matched_indices:
if(iou_matrix[m[0], m[1]]<iou_threshold):
unmatched_detections.append(m[0])
unmatched_trackers.append(m[1])
else:
matches.append(m.reshape(1,2))
if(len(matches)==0):
matches = np.empty((0,2),dtype=int)
else:
matches = np.concatenate(matches,axis=0)
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
def associate(detections, trackers, iou_threshold, velocities, previous_obs, vdc_weight):
if(len(trackers)==0):
return np.empty((0,2),dtype=int), np.arange(len(detections)), np.empty((0,5),dtype=int)
Y, X = speed_direction_batch(detections, previous_obs)
inertia_Y, inertia_X = velocities[:,0], velocities[:,1]
inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
diff_angle_cos = inertia_X * X + inertia_Y * Y
diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
diff_angle = np.arccos(diff_angle_cos)
diff_angle = (np.pi /2.0 - np.abs(diff_angle)) / np.pi
valid_mask = np.ones(previous_obs.shape[0])
valid_mask[np.where(previous_obs[:,4]<0)] = 0
iou_matrix = iou_batch(detections, trackers)
scores = np.repeat(detections[:,-1][:, np.newaxis], trackers.shape[0], axis=1)
# iou_matrix = iou_matrix * scores # a trick that sometimes works; we don't encourage this
valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)
angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
angle_diff_cost = angle_diff_cost.T
angle_diff_cost = angle_diff_cost * scores
if min(iou_matrix.shape) > 0:
a = (iou_matrix > iou_threshold).astype(np.int32)
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
matched_indices = np.stack(np.where(a), axis=1)
else:
matched_indices = linear_assignment(-(iou_matrix+angle_diff_cost))
else:
matched_indices = np.empty(shape=(0,2))
unmatched_detections = []
for d, det in enumerate(detections):
if(d not in matched_indices[:,0]):
unmatched_detections.append(d)
unmatched_trackers = []
for t, trk in enumerate(trackers):
if(t not in matched_indices[:,1]):
unmatched_trackers.append(t)
# filter out matched with low IOU
matches = []
for m in matched_indices:
if(iou_matrix[m[0], m[1]]<iou_threshold):
unmatched_detections.append(m[0])
unmatched_trackers.append(m[1])
else:
matches.append(m.reshape(1,2))
if(len(matches)==0):
matches = np.empty((0,2),dtype=int)
else:
matches = np.concatenate(matches,axis=0)
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
def associate_kitti(detections, trackers, det_cates, iou_threshold,
velocities, previous_obs, vdc_weight):
if(len(trackers)==0):
return np.empty((0,2),dtype=int), np.arange(len(detections)), np.empty((0,5),dtype=int)
"""
Cost from the velocity direction consistency
"""
Y, X = speed_direction_batch(detections, previous_obs)
inertia_Y, inertia_X = velocities[:,0], velocities[:,1]
inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
diff_angle_cos = inertia_X * X + inertia_Y * Y
diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
diff_angle = np.arccos(diff_angle_cos)
diff_angle = (np.pi /2.0 - np.abs(diff_angle)) / np.pi
valid_mask = np.ones(previous_obs.shape[0])
valid_mask[np.where(previous_obs[:,4]<0)]=0
valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)
scores = np.repeat(detections[:,-1][:, np.newaxis], trackers.shape[0], axis=1)
angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
angle_diff_cost = angle_diff_cost.T
angle_diff_cost = angle_diff_cost * scores
"""
Cost from IoU
"""
iou_matrix = iou_batch(detections, trackers)
"""
With multiple categories, generate the cost for category mismatch
"""
num_dets = detections.shape[0]
num_trk = trackers.shape[0]
cate_matrix = np.zeros((num_dets, num_trk))
for i in range(num_dets):
for j in range(num_trk):
if det_cates[i] != trackers[j, 4]:
cate_matrix[i][j] = -1e6
cost_matrix = - iou_matrix -angle_diff_cost - cate_matrix
if min(iou_matrix.shape) > 0:
a = (iou_matrix > iou_threshold).astype(np.int32)
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
matched_indices = np.stack(np.where(a), axis=1)
else:
matched_indices = linear_assignment(cost_matrix)
else:
matched_indices = np.empty(shape=(0,2))
unmatched_detections = []
for d, det in enumerate(detections):
if(d not in matched_indices[:,0]):
unmatched_detections.append(d)
unmatched_trackers = []
for t, trk in enumerate(trackers):
if(t not in matched_indices[:,1]):
unmatched_trackers.append(t)
#filter out matched with low IOU
matches = []
for m in matched_indices:
if(iou_matrix[m[0], m[1]]<iou_threshold):
unmatched_detections.append(m[0])
unmatched_trackers.append(m[1])
else:
matches.append(m.reshape(1,2))
if(len(matches)==0):
matches = np.empty((0,2),dtype=int)
else:
matches = np.concatenate(matches,axis=0)
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
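
As a quick illustration (not part of the diff), the matching helpers above can be exercised on toy boxes; the module path is assumed from the "from .association import *" in the OC-SORT tracker below.

import numpy as np
from trackers.ocsort.association import associate_detections_to_trackers  # assumed module path

# Two detections and two predicted track boxes, format [x1, y1, x2, y2, score].
dets = np.array([[10., 10., 50., 50., 0.9],
                 [200., 200., 240., 260., 0.8]])
trks = np.array([[12., 11., 52., 49., 0.0],
                 [400., 400., 440., 460., 0.0]])

matches, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets, trks, iou_threshold=0.3)
# matches -> [[0, 0]]: detection 0 is paired with track 0 (IoU ~0.86);
# detection 1 and track 1 have zero overlap and stay unmatched.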


@ -0,0 +1,12 @@
# Trial number: 137
# HOTA, MOTA, IDF1: [55.567]
ocsort:
asso_func: giou
conf_thres: 0.5122620708221085
delta_t: 1
det_thresh: 0
inertia: 0.3941737016672115
iou_thresh: 0.22136877277096445
max_age: 50
min_hits: 1
use_byte: false
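
As a sketch (not part of the diff), a tuned block like this can be read back with PyYAML; the file path is an assumption, and the keys line up with the OCSort constructor shown later in this diff.

import yaml

# Hypothetical location of the block above inside the repo.
with open("trackers/ocsort/configs/ocsort.yaml") as f:
    cfg = yaml.safe_load(f)["ocsort"]

print(cfg["asso_func"], cfg["iou_thresh"])  # 'giou' 0.22136877277096445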

File diff suppressed because it is too large.


@ -0,0 +1,328 @@
"""
This script is adapted from the SORT script by Alex Bewley alex@bewley.ai
"""
from __future__ import print_function
import numpy as np
from .association import *
from ultralytics.yolo.utils.ops import xywh2xyxy
def k_previous_obs(observations, cur_age, k):
if len(observations) == 0:
return [-1, -1, -1, -1, -1]
for i in range(k):
dt = k - i
if cur_age - dt in observations:
return observations[cur_age-dt]
max_age = max(observations.keys())
return observations[max_age]
def convert_bbox_to_z(bbox):
"""
Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
[x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
the aspect ratio
"""
w = bbox[2] - bbox[0]
h = bbox[3] - bbox[1]
x = bbox[0] + w/2.
y = bbox[1] + h/2.
s = w * h # scale is just area
r = w / float(h+1e-6)
return np.array([x, y, s, r]).reshape((4, 1))
def convert_x_to_bbox(x, score=None):
"""
Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
[x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
"""
w = np.sqrt(x[2] * x[3])
h = x[2] / w
if(score == None):
return np.array([x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2.]).reshape((1, 4))
else:
return np.array([x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2., score]).reshape((1, 5))
def speed_direction(bbox1, bbox2):
cx1, cy1 = (bbox1[0]+bbox1[2]) / 2.0, (bbox1[1]+bbox1[3])/2.0
cx2, cy2 = (bbox2[0]+bbox2[2]) / 2.0, (bbox2[1]+bbox2[3])/2.0
speed = np.array([cy2-cy1, cx2-cx1])
norm = np.sqrt((cy2-cy1)**2 + (cx2-cx1)**2) + 1e-6
return speed / norm
class KalmanBoxTracker(object):
"""
This class represents the internal state of individual tracked objects observed as bbox.
"""
count = 0
def __init__(self, bbox, cls, delta_t=3, orig=False):
"""
Initialises a tracker using initial bounding box.
"""
# define constant velocity model
if not orig:
from .kalmanfilter import KalmanFilterNew as KalmanFilter
self.kf = KalmanFilter(dim_x=7, dim_z=4)
else:
from filterpy.kalman import KalmanFilter
self.kf = KalmanFilter(dim_x=7, dim_z=4)
self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0], [0, 0, 1, 0, 0, 0, 1], [
0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 1]])
self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]])
self.kf.R[2:, 2:] *= 10.
self.kf.P[4:, 4:] *= 1000. # give high uncertainty to the unobservable initial velocities
self.kf.P *= 10.
self.kf.Q[-1, -1] *= 0.01
self.kf.Q[4:, 4:] *= 0.01
self.kf.x[:4] = convert_bbox_to_z(bbox)
self.time_since_update = 0
self.id = KalmanBoxTracker.count
KalmanBoxTracker.count += 1
self.history = []
self.hits = 0
self.hit_streak = 0
self.age = 0
self.conf = bbox[-1]
self.cls = cls
"""
NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, and the same applies to the return of
the function k_previous_obs. It is ugly, but it lets the observation array be built in a fast and unified way
(see k_observations = np.array([k_previous_obs(...)]) below), so let's bear it for now.
"""
self.last_observation = np.array([-1, -1, -1, -1, -1]) # placeholder
self.observations = dict()
self.history_observations = []
self.velocity = None
self.delta_t = delta_t
def update(self, bbox, cls):
"""
Updates the state vector with observed bbox.
"""
if bbox is not None:
self.conf = bbox[-1]
self.cls = cls
if self.last_observation.sum() >= 0: # i.e. there is a previous observation
previous_box = None
for i in range(self.delta_t):
dt = self.delta_t - i
if self.age - dt in self.observations:
previous_box = self.observations[self.age-dt]
break
if previous_box is None:
previous_box = self.last_observation
"""
Estimate the track speed direction with observations \Delta t steps away
"""
self.velocity = speed_direction(previous_box, bbox)
"""
Insert new observations. This is an ugly way to maintain both self.observations
and self.history_observations; bear with it for the moment.
"""
self.last_observation = bbox
self.observations[self.age] = bbox
self.history_observations.append(bbox)
self.time_since_update = 0
self.history = []
self.hits += 1
self.hit_streak += 1
self.kf.update(convert_bbox_to_z(bbox))
else:
self.kf.update(bbox)
def predict(self):
"""
Advances the state vector and returns the predicted bounding box estimate.
"""
if((self.kf.x[6]+self.kf.x[2]) <= 0):
self.kf.x[6] *= 0.0
self.kf.predict()
self.age += 1
if(self.time_since_update > 0):
self.hit_streak = 0
self.time_since_update += 1
self.history.append(convert_x_to_bbox(self.kf.x))
return self.history[-1]
def get_state(self):
"""
Returns the current bounding box estimate.
"""
return convert_x_to_bbox(self.kf.x)
"""
We support multiple ways of computing the association cost; by default
we use IoU. GIoU may perform better in some situations. Note that the
costs of the different methods are only roughly normalized to (0,1),
which may not be the best practice.
"""
ASSO_FUNCS = { "iou": iou_batch,
"giou": giou_batch,
"ciou": ciou_batch,
"diou": diou_batch,
"ct_dist": ct_dist}
class OCSort(object):
def __init__(self, det_thresh, max_age=30, min_hits=3,
iou_threshold=0.3, delta_t=3, asso_func="iou", inertia=0.2, use_byte=False):
"""
Sets key parameters for SORT
"""
self.max_age = max_age
self.min_hits = min_hits
self.iou_threshold = iou_threshold
self.trackers = []
self.frame_count = 0
self.det_thresh = det_thresh
self.delta_t = delta_t
self.asso_func = ASSO_FUNCS[asso_func]
self.inertia = inertia
self.use_byte = use_byte
KalmanBoxTracker.count = 0
def update(self, dets, _):
"""
Params:
dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
Requires: this method must be called once for each frame even with empty detections (use np.empty((0, 5)) for frames without detections).
Returns a similar array, where the last column is the object ID.
NOTE: The number of objects returned may differ from the number of detections provided.
"""
self.frame_count += 1
xyxys = dets[:, 0:4]
confs = dets[:, 4]
clss = dets[:, 5]
classes = clss.numpy()
xyxys = xyxys.numpy()
confs = confs.numpy()
output_results = np.column_stack((xyxys, confs, classes))
inds_low = confs > 0.1
inds_high = confs < self.det_thresh
inds_second = np.logical_and(inds_low, inds_high) # self.det_thresh > score > 0.1, for second matching
dets_second = output_results[inds_second] # detections for second matching
remain_inds = confs > self.det_thresh
dets = output_results[remain_inds]
# get predicted locations from existing trackers.
trks = np.zeros((len(self.trackers), 5))
to_del = []
ret = []
for t, trk in enumerate(trks):
pos = self.trackers[t].predict()[0]
trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
if np.any(np.isnan(pos)):
to_del.append(t)
trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
for t in reversed(to_del):
self.trackers.pop(t)
velocities = np.array(
[trk.velocity if trk.velocity is not None else np.array((0, 0)) for trk in self.trackers])
last_boxes = np.array([trk.last_observation for trk in self.trackers])
k_observations = np.array(
[k_previous_obs(trk.observations, trk.age, self.delta_t) for trk in self.trackers])
"""
First round of association
"""
matched, unmatched_dets, unmatched_trks = associate(
dets, trks, self.iou_threshold, velocities, k_observations, self.inertia)
for m in matched:
self.trackers[m[1]].update(dets[m[0], :5], dets[m[0], 5])
"""
Second round of association by OCR
"""
# BYTE association
if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[0] > 0:
u_trks = trks[unmatched_trks]
iou_left = self.asso_func(dets_second, u_trks) # iou between low score detections and unmatched tracks
iou_left = np.array(iou_left)
if iou_left.max() > self.iou_threshold:
"""
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
uniform here for simplicity
"""
matched_indices = linear_assignment(-iou_left)
to_remove_trk_indices = []
for m in matched_indices:
det_ind, trk_ind = m[0], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
continue
self.trackers[trk_ind].update(dets_second[det_ind, :5], dets_second[det_ind, 5])
to_remove_trk_indices.append(trk_ind)
unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))
if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
left_dets = dets[unmatched_dets]
left_trks = last_boxes[unmatched_trks]
iou_left = self.asso_func(left_dets, left_trks)
iou_left = np.array(iou_left)
if iou_left.max() > self.iou_threshold:
"""
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
uniform here for simplicity
"""
rematched_indices = linear_assignment(-iou_left)
to_remove_det_indices = []
to_remove_trk_indices = []
for m in rematched_indices:
det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
continue
self.trackers[trk_ind].update(dets[det_ind, :5], dets[det_ind, 5])
to_remove_det_indices.append(det_ind)
to_remove_trk_indices.append(trk_ind)
unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices))
unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))
for m in unmatched_trks:
self.trackers[m].update(None, None)
# create and initialise new trackers for unmatched detections
for i in unmatched_dets:
trk = KalmanBoxTracker(dets[i, :5], dets[i, 5], delta_t=self.delta_t)
self.trackers.append(trk)
i = len(self.trackers)
for trk in reversed(self.trackers):
if trk.last_observation.sum() < 0:
d = trk.get_state()[0]
else:
"""
It is optional to use either the recent observation or the Kalman filter prediction;
we did not notice a significant difference here.
"""
d = trk.last_observation[:4]
if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
# +1 as MOT benchmark requires positive
ret.append(np.concatenate((d, [trk.id+1], [trk.cls], [trk.conf])).reshape(1, -1))
i -= 1
# remove dead tracklet
if(trk.time_since_update > self.max_age):
self.trackers.pop(i)
if(len(ret) > 0):
return np.concatenate(ret)
return np.empty((0, 5))
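
A minimal per-frame usage sketch (not in the diff), wiring rounded values from the YAML above into the constructor; update() slices its input with .numpy(), so detections are assumed to arrive as a torch tensor of shape (N, 6) = [x1, y1, x2, y2, conf, cls].

import torch
from trackers.ocsort.ocsort import OCSort  # assumed module path

tracker = OCSort(det_thresh=0.0, max_age=50, min_hits=1, iou_threshold=0.2214,
                 delta_t=1, asso_func="giou", inertia=0.3942, use_byte=False)

dets = torch.tensor([[100., 120., 180., 260., 0.91, 2.],
                     [400.,  80., 470., 300., 0.64, 2.]])
tracks = tracker.update(dets, None)
# Each returned row is [x1, y1, x2, y2, track_id, cls, conf]; with min_hits=1 the
# first frame already yields two tracks.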


@ -0,0 +1,313 @@
import argparse
import os
# limit the number of cpus used by high performance libraries
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
import sys
import numpy as np
from pathlib import Path
import torch
import time
import platform
import pandas as pd
import subprocess
import torch.backends.cudnn as cudnn
from torch.utils.mobile_optimizer import optimize_for_mobile
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0].parents[0] # yolov5 strongsort root directory
WEIGHTS = ROOT / 'weights'
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
if str(ROOT / 'yolov5') not in sys.path:
sys.path.append(str(ROOT / 'yolov5')) # add yolov5 ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
import logging
from ultralytics.yolo.utils.torch_utils import select_device
from ultralytics.yolo.utils import LOGGER, colorstr, ops
from ultralytics.yolo.utils.checks import check_requirements, check_version
from trackers.strongsort.deep.models import build_model
from trackers.strongsort.deep.reid_model_factory import get_model_name, load_pretrained_weights
def file_size(path):
# Return file/dir size (MB)
path = Path(path)
if path.is_file():
return path.stat().st_size / 1E6
elif path.is_dir():
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
else:
return 0.0
def export_formats():
# YOLOv5 export formats
x = [
['PyTorch', '-', '.pt', True, True],
['TorchScript', 'torchscript', '.torchscript', True, True],
['ONNX', 'onnx', '.onnx', True, True],
['OpenVINO', 'openvino', '_openvino_model', True, False],
['TensorRT', 'engine', '.engine', False, True],
['TensorFlow Lite', 'tflite', '.tflite', True, False],
]
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
# YOLOv5 TorchScript model export
try:
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
f = file.with_suffix('.torchscript')
ts = torch.jit.trace(model, im, strict=False)
if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
optimize_for_mobile(ts)._save_for_lite_interpreter(str(f))
else:
ts.save(str(f))
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return f
except Exception as e:
LOGGER.info(f'{prefix} export failure: {e}')
def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
# ONNX export
try:
check_requirements(('onnx',))
import onnx
f = file.with_suffix('.onnx')
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
if dynamic:
dynamic = {'images': {0: 'batch'}} # shape(1,3,640,640)
dynamic['output'] = {0: 'batch'} # shape(1,25200,85)
torch.onnx.export(
model.cpu() if dynamic else model, # --dynamic only compatible with cpu
im.cpu() if dynamic else im,
f,
verbose=False,
opset_version=opset,
do_constant_folding=True,
input_names=['images'],
output_names=['output'],
dynamic_axes=dynamic or None
)
# Checks
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
onnx.save(model_onnx, f)
# Simplify
if simplify:
try:
cuda = torch.cuda.is_available()
check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
import onnxsim
LOGGER.info(f'simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(model_onnx)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
LOGGER.info(f'simplifier failure: {e}')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return f
except Exception as e:
LOGGER.info(f'export failure: {e}')
def export_openvino(file, half, prefix=colorstr('OpenVINO:')):
# YOLOv5 OpenVINO export
check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
import openvino.inference_engine as ie
try:
LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
f = str(file).replace('.pt', f'_openvino_model{os.sep}')
cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
subprocess.check_output(cmd.split()) # export
except Exception as e:
LOGGER.info(f'export failure: {e}')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return f
def export_tflite(file, half, prefix=colorstr('TFLite:')):
# YOLOv5 TFLite export (via openvino2tensorflow)
try:
check_requirements(('openvino2tensorflow', 'tensorflow', 'tensorflow_datasets')) # requires openvino-dev: https://pypi.org/project/openvino-dev/
import openvino.inference_engine as ie
LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
output = Path(str(file).replace(f'_openvino_model{os.sep}', f'_tflite_model{os.sep}'))
modelxml = list(Path(file).glob('*.xml'))[0]
cmd = f"openvino2tensorflow \
--model_path {modelxml} \
--model_output_path {output} \
--output_pb \
--output_saved_model \
--output_no_quant_float32_tflite \
--output_dynamic_range_quant_tflite"
subprocess.check_output(cmd.split()) # export
LOGGER.info(f'{prefix} export success, results saved in {output} ({file_size(output):.1f} MB)')
return output
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')
def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
try:
assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
try:
import tensorrt as trt
except Exception:
if platform.system() == 'Linux':
check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
import tensorrt as trt
if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
grid = model.model[-1].anchor_grid
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
export_onnx(model, im, file, 12, dynamic, simplify) # opset 12
model.model[-1].anchor_grid = grid
else: # TensorRT >= 8
check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0
export_onnx(model, im, file, 12, dynamic, simplify) # opset 12
onnx = file.with_suffix('.onnx')
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
f = file.with_suffix('.engine') # TensorRT engine file
logger = trt.Logger(trt.Logger.INFO)
if verbose:
logger.min_severity = trt.Logger.Severity.VERBOSE
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = workspace * 1 << 30
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(str(onnx)):
raise RuntimeError(f'failed to load ONNX file: {onnx}')
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
LOGGER.info(f'{prefix} Network Description:')
for inp in inputs:
LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
for out in outputs:
LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
if dynamic:
if im.shape[0] <= 1:
LOGGER.warning(f"{prefix}WARNING: --dynamic model requires maximum --batch-size argument")
profile = builder.create_optimization_profile()
for inp in inputs:
profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
config.add_optimization_profile(profile)
LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}')
if builder.platform_has_fast_fp16 and half:
config.set_flag(trt.BuilderFlag.FP16)
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
t.write(engine.serialize())
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return f
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="ReID export")
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[256, 128], help='image (h, w)')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
parser.add_argument('--weights', nargs='+', type=str, default=WEIGHTS / 'osnet_x0_25_msmt17.pt', help='model.pt path(s)')
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
parser.add_argument('--include',
nargs='+',
default=['torchscript'],
help='torchscript, onnx, openvino, engine')
args = parser.parse_args()
t = time.time()
include = [x.lower() for x in args.include] # to lowercase
fmts = tuple(export_formats()['Argument'][1:]) # --include arguments
flags = [x in include for x in fmts]
assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
jit, onnx, openvino, engine, tflite = flags # export booleans
args.device = select_device(args.device)
if args.half:
assert args.device.type != 'cpu', '--half only compatible with GPU export, i.e. use --device 0'
assert not args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
if type(args.weights) is list:
args.weights = Path(args.weights[0])
model = build_model(
get_model_name(args.weights),
num_classes=1,
pretrained=not (args.weights and args.weights.is_file() and args.weights.suffix == '.pt'),
use_gpu=args.device
).to(args.device)
load_pretrained_weights(model, args.weights)
model.eval()
if args.optimize:
assert args.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
im = torch.zeros(args.batch_size, 3, args.imgsz[0], args.imgsz[1]).to(args.device) # image tensor (batch, 3, h, w) BCHW
for _ in range(2):
y = model(im) # dry runs
if args.half:
im, model = im.half(), model.half() # to FP16
shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {args.weights} with output shape {shape} ({file_size(args.weights):.1f} MB)")
# Exports
f = [''] * len(fmts) # exported filenames
if jit:
f[0] = export_torchscript(model, im, args.weights, args.optimize) # opset 12
if engine: # TensorRT required before ONNX
f[1] = export_engine(model, im, args.weights, args.half, args.dynamic, args.simplify, args.workspace, args.verbose)
if onnx: # OpenVINO requires ONNX
f[2] = export_onnx(model, im, args.weights, args.opset, args.dynamic, args.simplify) # opset 12
if openvino:
f[3] = export_openvino(args.weights, args.half)
if tflite:
export_tflite(f, False)
# Finish
f = [str(x) for x in f if x] # filter out '' and None
if any(f):
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
f"\nResults saved to {colorstr('bold', args.weights.parent.resolve())}"
f"\nVisualize: https://netron.app")


@ -0,0 +1,11 @@
strongsort:
ecc: true
ema_alpha: 0.8962157769329083
max_age: 40
max_dist: 0.1594374041012136
max_iou_dist: 0.5431835667667874
max_unmatched_preds: 0
mc_lambda: 0.995
n_init: 3
nn_budget: 100
conf_thres: 0.5122620708221085
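
As with the OC-SORT block, this is just configuration data; a sketch (not in the diff) of reading it with PyYAML, with an assumed path. The StrongSORT constructor that consumes these keys lives elsewhere in the repo and is not shown in this diff.

import yaml

with open("trackers/strongsort/configs/strongsort.yaml") as f:  # assumed path
    cfg = yaml.safe_load(f)["strongsort"]

print(cfg["max_dist"], cfg["n_init"], cfg["nn_budget"])  # 0.1594... 3 100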


@ -0,0 +1,122 @@
from __future__ import absolute_import
import torch
from .pcb import *
from .mlfn import *
from .hacnn import *
from .osnet import *
from .senet import *
from .mudeep import *
from .nasnet import *
from .resnet import *
from .densenet import *
from .xception import *
from .osnet_ain import *
from .resnetmid import *
from .shufflenet import *
from .squeezenet import *
from .inceptionv4 import *
from .mobilenetv2 import *
from .resnet_ibn_a import *
from .resnet_ibn_b import *
from .shufflenetv2 import *
from .inceptionresnetv2 import *
__model_factory = {
# image classification models
'resnet18': resnet18,
'resnet34': resnet34,
'resnet50': resnet50,
'resnet101': resnet101,
'resnet152': resnet152,
'resnext50_32x4d': resnext50_32x4d,
'resnext101_32x8d': resnext101_32x8d,
'resnet50_fc512': resnet50_fc512,
'se_resnet50': se_resnet50,
'se_resnet50_fc512': se_resnet50_fc512,
'se_resnet101': se_resnet101,
'se_resnext50_32x4d': se_resnext50_32x4d,
'se_resnext101_32x4d': se_resnext101_32x4d,
'densenet121': densenet121,
'densenet169': densenet169,
'densenet201': densenet201,
'densenet161': densenet161,
'densenet121_fc512': densenet121_fc512,
'inceptionresnetv2': inceptionresnetv2,
'inceptionv4': inceptionv4,
'xception': xception,
'resnet50_ibn_a': resnet50_ibn_a,
'resnet50_ibn_b': resnet50_ibn_b,
# lightweight models
'nasnsetmobile': nasnetamobile,
'mobilenetv2_x1_0': mobilenetv2_x1_0,
'mobilenetv2_x1_4': mobilenetv2_x1_4,
'shufflenet': shufflenet,
'squeezenet1_0': squeezenet1_0,
'squeezenet1_0_fc512': squeezenet1_0_fc512,
'squeezenet1_1': squeezenet1_1,
'shufflenet_v2_x0_5': shufflenet_v2_x0_5,
'shufflenet_v2_x1_0': shufflenet_v2_x1_0,
'shufflenet_v2_x1_5': shufflenet_v2_x1_5,
'shufflenet_v2_x2_0': shufflenet_v2_x2_0,
# reid-specific models
'mudeep': MuDeep,
'resnet50mid': resnet50mid,
'hacnn': HACNN,
'pcb_p6': pcb_p6,
'pcb_p4': pcb_p4,
'mlfn': mlfn,
'osnet_x1_0': osnet_x1_0,
'osnet_x0_75': osnet_x0_75,
'osnet_x0_5': osnet_x0_5,
'osnet_x0_25': osnet_x0_25,
'osnet_ibn_x1_0': osnet_ibn_x1_0,
'osnet_ain_x1_0': osnet_ain_x1_0,
'osnet_ain_x0_75': osnet_ain_x0_75,
'osnet_ain_x0_5': osnet_ain_x0_5,
'osnet_ain_x0_25': osnet_ain_x0_25
}
def show_avai_models():
"""Displays available models.
Examples::
>>> from torchreid import models
>>> models.show_avai_models()
"""
print(list(__model_factory.keys()))
def build_model(
name, num_classes, loss='softmax', pretrained=True, use_gpu=True
):
"""A function wrapper for building a model.
Args:
name (str): model name.
num_classes (int): number of training identities.
loss (str, optional): loss function to optimize the model. Currently
supports "softmax" and "triplet". Default is "softmax".
pretrained (bool, optional): whether to load ImageNet-pretrained weights.
Default is True.
use_gpu (bool, optional): whether to use gpu. Default is True.
Returns:
nn.Module
Examples::
>>> from torchreid import models
>>> model = models.build_model('resnet50', 751, loss='softmax')
"""
avai_models = list(__model_factory.keys())
if name not in avai_models:
raise KeyError(
'Unknown model: {}. Must be one of {}'.format(name, avai_models)
)
return __model_factory[name](
num_classes=num_classes,
loss=loss,
pretrained=pretrained,
use_gpu=use_gpu
)


@ -0,0 +1,380 @@
"""
Code source: https://github.com/pytorch/vision
"""
from __future__ import division, absolute_import
import re
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils import model_zoo
__all__ = [
'densenet121', 'densenet169', 'densenet201', 'densenet161',
'densenet121_fc512'
]
model_urls = {
'densenet121':
'https://download.pytorch.org/models/densenet121-a639ec97.pth',
'densenet169':
'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
'densenet201':
'https://download.pytorch.org/models/densenet201-c1103571.pth',
'densenet161':
'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}
class _DenseLayer(nn.Sequential):
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
super(_DenseLayer, self).__init__()
self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
self.add_module('relu1', nn.ReLU(inplace=True)),
self.add_module(
'conv1',
nn.Conv2d(
num_input_features,
bn_size * growth_rate,
kernel_size=1,
stride=1,
bias=False
)
),
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
self.add_module('relu2', nn.ReLU(inplace=True)),
self.add_module(
'conv2',
nn.Conv2d(
bn_size * growth_rate,
growth_rate,
kernel_size=3,
stride=1,
padding=1,
bias=False
)
),
self.drop_rate = drop_rate
def forward(self, x):
new_features = super(_DenseLayer, self).forward(x)
if self.drop_rate > 0:
new_features = F.dropout(
new_features, p=self.drop_rate, training=self.training
)
return torch.cat([x, new_features], 1)
class _DenseBlock(nn.Sequential):
def __init__(
self, num_layers, num_input_features, bn_size, growth_rate, drop_rate
):
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(
num_input_features + i*growth_rate, growth_rate, bn_size,
drop_rate
)
self.add_module('denselayer%d' % (i+1), layer)
class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features):
super(_Transition, self).__init__()
self.add_module('norm', nn.BatchNorm2d(num_input_features))
self.add_module('relu', nn.ReLU(inplace=True))
self.add_module(
'conv',
nn.Conv2d(
num_input_features,
num_output_features,
kernel_size=1,
stride=1,
bias=False
)
)
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
class DenseNet(nn.Module):
"""Densely connected network.
Reference:
Huang et al. Densely Connected Convolutional Networks. CVPR 2017.
Public keys:
- ``densenet121``: DenseNet121.
- ``densenet169``: DenseNet169.
- ``densenet201``: DenseNet201.
- ``densenet161``: DenseNet161.
- ``densenet121_fc512``: DenseNet121 + FC.
"""
def __init__(
self,
num_classes,
loss,
growth_rate=32,
block_config=(6, 12, 24, 16),
num_init_features=64,
bn_size=4,
drop_rate=0,
fc_dims=None,
dropout_p=None,
**kwargs
):
super(DenseNet, self).__init__()
self.loss = loss
# First convolution
self.features = nn.Sequential(
OrderedDict(
[
(
'conv0',
nn.Conv2d(
3,
num_init_features,
kernel_size=7,
stride=2,
padding=3,
bias=False
)
),
('norm0', nn.BatchNorm2d(num_init_features)),
('relu0', nn.ReLU(inplace=True)),
(
'pool0',
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
),
]
)
)
# Each denseblock
num_features = num_init_features
for i, num_layers in enumerate(block_config):
block = _DenseBlock(
num_layers=num_layers,
num_input_features=num_features,
bn_size=bn_size,
growth_rate=growth_rate,
drop_rate=drop_rate
)
self.features.add_module('denseblock%d' % (i+1), block)
num_features = num_features + num_layers*growth_rate
if i != len(block_config) - 1:
trans = _Transition(
num_input_features=num_features,
num_output_features=num_features // 2
)
self.features.add_module('transition%d' % (i+1), trans)
num_features = num_features // 2
# Final batch norm
self.features.add_module('norm5', nn.BatchNorm2d(num_features))
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.feature_dim = num_features
self.fc = self._construct_fc_layer(fc_dims, num_features, dropout_p)
# Linear layer
self.classifier = nn.Linear(self.feature_dim, num_classes)
self._init_params()
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
"""Constructs fully connected layer.
Args:
fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
input_dim (int): input dimension
dropout_p (float): dropout probability, if None, dropout is unused
"""
if fc_dims is None:
self.feature_dim = input_dim
return None
assert isinstance(
fc_dims, (list, tuple)
), 'fc_dims must be either list or tuple, but got {}'.format(
type(fc_dims)
)
layers = []
for dim in fc_dims:
layers.append(nn.Linear(input_dim, dim))
layers.append(nn.BatchNorm1d(dim))
layers.append(nn.ReLU(inplace=True))
if dropout_p is not None:
layers.append(nn.Dropout(p=dropout_p))
input_dim = dim
self.feature_dim = fc_dims[-1]
return nn.Sequential(*layers)
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu'
)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
f = self.features(x)
f = F.relu(f, inplace=True)
v = self.global_avgpool(f)
v = v.view(v.size(0), -1)
if self.fc is not None:
v = self.fc(v)
if not self.training:
return v
y = self.classifier(v)
if self.loss == 'softmax':
return y
elif self.loss == 'triplet':
return y, v
else:
raise KeyError('Unsupported loss: {}'.format(self.loss))
def init_pretrained_weights(model, model_url):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
pretrain_dict = model_zoo.load_url(model_url)
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
)
for key in list(pretrain_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
pretrain_dict[new_key] = pretrain_dict[key]
del pretrain_dict[key]
model_dict = model.state_dict()
pretrain_dict = {
k: v
for k, v in pretrain_dict.items()
if k in model_dict and model_dict[k].size() == v.size()
}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
"""
Dense network configurations:
--
densenet121: num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)
densenet169: num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32)
densenet201: num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32)
densenet161: num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24)
"""
def densenet121(num_classes, loss='softmax', pretrained=True, **kwargs):
model = DenseNet(
num_classes=num_classes,
loss=loss,
num_init_features=64,
growth_rate=32,
block_config=(6, 12, 24, 16),
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['densenet121'])
return model
def densenet169(num_classes, loss='softmax', pretrained=True, **kwargs):
model = DenseNet(
num_classes=num_classes,
loss=loss,
num_init_features=64,
growth_rate=32,
block_config=(6, 12, 32, 32),
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['densenet169'])
return model
def densenet201(num_classes, loss='softmax', pretrained=True, **kwargs):
model = DenseNet(
num_classes=num_classes,
loss=loss,
num_init_features=64,
growth_rate=32,
block_config=(6, 12, 48, 32),
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['densenet201'])
return model
def densenet161(num_classes, loss='softmax', pretrained=True, **kwargs):
model = DenseNet(
num_classes=num_classes,
loss=loss,
num_init_features=96,
growth_rate=48,
block_config=(6, 12, 36, 24),
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['densenet161'])
return model
def densenet121_fc512(num_classes, loss='softmax', pretrained=True, **kwargs):
model = DenseNet(
num_classes=num_classes,
loss=loss,
num_init_features=64,
growth_rate=32,
block_config=(6, 12, 24, 16),
fc_dims=[512],
dropout_p=None,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['densenet121'])
return model
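
A short feature-extraction sketch (not in the diff): forward() above returns the pooled feature vector when the module is in eval mode, which is how these backbones are used for re-ID.

import torch
from trackers.strongsort.deep.models import densenet121  # re-exported by the __init__.py above

model = densenet121(num_classes=751, pretrained=False).eval()  # pretrained=False avoids the download
with torch.no_grad():
    feats = model(torch.randn(4, 3, 256, 128))  # a typical re-ID crop size (h=256, w=128)
print(feats.shape)  # torch.Size([4, 1024]) -- densenet121's final feature width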


@ -0,0 +1,414 @@
from __future__ import division, absolute_import
import torch
from torch import nn
from torch.nn import functional as F
__all__ = ['HACNN']
class ConvBlock(nn.Module):
"""Basic convolutional block.
convolution + batch normalization + relu.
Args:
in_c (int): number of input channels.
out_c (int): number of output channels.
k (int or tuple): kernel size.
s (int or tuple): stride.
p (int or tuple): padding.
"""
def __init__(self, in_c, out_c, k, s=1, p=0):
super(ConvBlock, self).__init__()
self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
self.bn = nn.BatchNorm2d(out_c)
def forward(self, x):
return F.relu(self.bn(self.conv(x)))
class InceptionA(nn.Module):
def __init__(self, in_channels, out_channels):
super(InceptionA, self).__init__()
mid_channels = out_channels // 4
self.stream1 = nn.Sequential(
ConvBlock(in_channels, mid_channels, 1),
ConvBlock(mid_channels, mid_channels, 3, p=1),
)
self.stream2 = nn.Sequential(
ConvBlock(in_channels, mid_channels, 1),
ConvBlock(mid_channels, mid_channels, 3, p=1),
)
self.stream3 = nn.Sequential(
ConvBlock(in_channels, mid_channels, 1),
ConvBlock(mid_channels, mid_channels, 3, p=1),
)
self.stream4 = nn.Sequential(
nn.AvgPool2d(3, stride=1, padding=1),
ConvBlock(in_channels, mid_channels, 1),
)
def forward(self, x):
s1 = self.stream1(x)
s2 = self.stream2(x)
s3 = self.stream3(x)
s4 = self.stream4(x)
y = torch.cat([s1, s2, s3, s4], dim=1)
return y
class InceptionB(nn.Module):
def __init__(self, in_channels, out_channels):
super(InceptionB, self).__init__()
mid_channels = out_channels // 4
self.stream1 = nn.Sequential(
ConvBlock(in_channels, mid_channels, 1),
ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
)
self.stream2 = nn.Sequential(
ConvBlock(in_channels, mid_channels, 1),
ConvBlock(mid_channels, mid_channels, 3, p=1),
ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
)
self.stream3 = nn.Sequential(
nn.MaxPool2d(3, stride=2, padding=1),
ConvBlock(in_channels, mid_channels * 2, 1),
)
def forward(self, x):
s1 = self.stream1(x)
s2 = self.stream2(x)
s3 = self.stream3(x)
y = torch.cat([s1, s2, s3], dim=1)
return y
class SpatialAttn(nn.Module):
"""Spatial Attention (Sec. 3.1.I.1)"""
def __init__(self):
super(SpatialAttn, self).__init__()
self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
self.conv2 = ConvBlock(1, 1, 1)
def forward(self, x):
# global cross-channel averaging
x = x.mean(1, keepdim=True)
# 3-by-3 conv
x = self.conv1(x)
# bilinear resizing
x = F.upsample(
x, (x.size(2) * 2, x.size(3) * 2),
mode='bilinear',
align_corners=True
)
# scaling conv
x = self.conv2(x)
return x
class ChannelAttn(nn.Module):
"""Channel Attention (Sec. 3.1.I.2)"""
def __init__(self, in_channels, reduction_rate=16):
super(ChannelAttn, self).__init__()
assert in_channels % reduction_rate == 0
self.conv1 = ConvBlock(in_channels, in_channels // reduction_rate, 1)
self.conv2 = ConvBlock(in_channels // reduction_rate, in_channels, 1)
def forward(self, x):
# squeeze operation (global average pooling)
x = F.avg_pool2d(x, x.size()[2:])
# excitation operation (2 conv layers)
x = self.conv1(x)
x = self.conv2(x)
return x
class SoftAttn(nn.Module):
"""Soft Attention (Sec. 3.1.I)
Aim: Spatial Attention + Channel Attention
Output: attention maps with shape identical to input.
"""
def __init__(self, in_channels):
super(SoftAttn, self).__init__()
self.spatial_attn = SpatialAttn()
self.channel_attn = ChannelAttn(in_channels)
self.conv = ConvBlock(in_channels, in_channels, 1)
def forward(self, x):
y_spatial = self.spatial_attn(x)
y_channel = self.channel_attn(x)
y = y_spatial * y_channel
y = torch.sigmoid(self.conv(y))
return y
class HardAttn(nn.Module):
"""Hard Attention (Sec. 3.1.II)"""
def __init__(self, in_channels):
super(HardAttn, self).__init__()
self.fc = nn.Linear(in_channels, 4 * 2)
self.init_params()
def init_params(self):
self.fc.weight.data.zero_()
self.fc.bias.data.copy_(
torch.tensor(
[0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float
)
)
def forward(self, x):
# squeeze operation (global average pooling)
x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
# predict transformation parameters
theta = torch.tanh(self.fc(x))
theta = theta.view(-1, 4, 2)
return theta
class HarmAttn(nn.Module):
"""Harmonious Attention (Sec. 3.1)"""
def __init__(self, in_channels):
super(HarmAttn, self).__init__()
self.soft_attn = SoftAttn(in_channels)
self.hard_attn = HardAttn(in_channels)
def forward(self, x):
y_soft_attn = self.soft_attn(x)
theta = self.hard_attn(x)
return y_soft_attn, theta
class HACNN(nn.Module):
"""Harmonious Attention Convolutional Neural Network.
Reference:
Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.
Public keys:
- ``hacnn``: HACNN.
"""
# Args:
# num_classes (int): number of classes to predict
# nchannels (list): number of channels AFTER concatenation
# feat_dim (int): feature dimension for a single stream
# learn_region (bool): whether to learn region features (i.e. local branch)
def __init__(
self,
num_classes,
loss='softmax',
nchannels=[128, 256, 384],
feat_dim=512,
learn_region=True,
use_gpu=True,
**kwargs
):
super(HACNN, self).__init__()
self.loss = loss
self.learn_region = learn_region
self.use_gpu = use_gpu
self.conv = ConvBlock(3, 32, 3, s=2, p=1)
# Construct Inception + HarmAttn blocks
# ============== Block 1 ==============
self.inception1 = nn.Sequential(
InceptionA(32, nchannels[0]),
InceptionB(nchannels[0], nchannels[0]),
)
self.ha1 = HarmAttn(nchannels[0])
# ============== Block 2 ==============
self.inception2 = nn.Sequential(
InceptionA(nchannels[0], nchannels[1]),
InceptionB(nchannels[1], nchannels[1]),
)
self.ha2 = HarmAttn(nchannels[1])
# ============== Block 3 ==============
self.inception3 = nn.Sequential(
InceptionA(nchannels[1], nchannels[2]),
InceptionB(nchannels[2], nchannels[2]),
)
self.ha3 = HarmAttn(nchannels[2])
self.fc_global = nn.Sequential(
nn.Linear(nchannels[2], feat_dim),
nn.BatchNorm1d(feat_dim),
nn.ReLU(),
)
self.classifier_global = nn.Linear(feat_dim, num_classes)
if self.learn_region:
self.init_scale_factors()
self.local_conv1 = InceptionB(32, nchannels[0])
self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
self.fc_local = nn.Sequential(
nn.Linear(nchannels[2] * 4, feat_dim),
nn.BatchNorm1d(feat_dim),
nn.ReLU(),
)
self.classifier_local = nn.Linear(feat_dim, num_classes)
self.feat_dim = feat_dim * 2
else:
self.feat_dim = feat_dim
def init_scale_factors(self):
# initialize scale factors (s_w, s_h) for four regions
self.scale_factors = []
self.scale_factors.append(
torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
)
self.scale_factors.append(
torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
)
self.scale_factors.append(
torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
)
self.scale_factors.append(
torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
)
def stn(self, x, theta):
"""Performs spatial transform
x: (batch, channel, height, width)
theta: (batch, 2, 3)
"""
grid = F.affine_grid(theta, x.size())
x = F.grid_sample(x, grid)
return x
def transform_theta(self, theta_i, region_idx):
"""Transforms theta to include (s_w, s_h), resulting in (batch, 2, 3)"""
scale_factors = self.scale_factors[region_idx]
theta = torch.zeros(theta_i.size(0), 2, 3)
theta[:, :, :2] = scale_factors
theta[:, :, -1] = theta_i
if self.use_gpu:
theta = theta.cuda()
return theta
def forward(self, x):
assert x.size(2) == 160 and x.size(3) == 64, \
'Input size does not match, expected (160, 64) but got ({}, {})'.format(x.size(2), x.size(3))
x = self.conv(x)
# ============== Block 1 ==============
# global branch
x1 = self.inception1(x)
x1_attn, x1_theta = self.ha1(x1)
x1_out = x1 * x1_attn
# local branch
if self.learn_region:
x1_local_list = []
for region_idx in range(4):
x1_theta_i = x1_theta[:, region_idx, :]
x1_theta_i = self.transform_theta(x1_theta_i, region_idx)
x1_trans_i = self.stn(x, x1_theta_i)
x1_trans_i = F.upsample(
x1_trans_i, (24, 28), mode='bilinear', align_corners=True
)
x1_local_i = self.local_conv1(x1_trans_i)
x1_local_list.append(x1_local_i)
# ============== Block 2 ==============
# Block 2
# global branch
x2 = self.inception2(x1_out)
x2_attn, x2_theta = self.ha2(x2)
x2_out = x2 * x2_attn
# local branch
if self.learn_region:
x2_local_list = []
for region_idx in range(4):
x2_theta_i = x2_theta[:, region_idx, :]
x2_theta_i = self.transform_theta(x2_theta_i, region_idx)
x2_trans_i = self.stn(x1_out, x2_theta_i)
x2_trans_i = F.upsample(
x2_trans_i, (12, 14), mode='bilinear', align_corners=True
)
x2_local_i = x2_trans_i + x1_local_list[region_idx]
x2_local_i = self.local_conv2(x2_local_i)
x2_local_list.append(x2_local_i)
# ============== Block 3 ==============
# Block 3
# global branch
x3 = self.inception3(x2_out)
x3_attn, x3_theta = self.ha3(x3)
x3_out = x3 * x3_attn
# local branch
if self.learn_region:
x3_local_list = []
for region_idx in range(4):
x3_theta_i = x3_theta[:, region_idx, :]
x3_theta_i = self.transform_theta(x3_theta_i, region_idx)
x3_trans_i = self.stn(x2_out, x3_theta_i)
x3_trans_i = F.upsample(
x3_trans_i, (6, 7), mode='bilinear', align_corners=True
)
x3_local_i = x3_trans_i + x2_local_list[region_idx]
x3_local_i = self.local_conv3(x3_local_i)
x3_local_list.append(x3_local_i)
# ============== Feature generation ==============
# global branch
x_global = F.avg_pool2d(x3_out,
x3_out.size()[2:]
).view(x3_out.size(0), x3_out.size(1))
x_global = self.fc_global(x_global)
# local branch
if self.learn_region:
x_local_list = []
for region_idx in range(4):
x_local_i = x3_local_list[region_idx]
x_local_i = F.avg_pool2d(x_local_i,
x_local_i.size()[2:]
).view(x_local_i.size(0), -1)
x_local_list.append(x_local_i)
x_local = torch.cat(x_local_list, 1)
x_local = self.fc_local(x_local)
if not self.training:
# l2 normalization before concatenation
if self.learn_region:
x_global = x_global / x_global.norm(p=2, dim=1, keepdim=True)
x_local = x_local / x_local.norm(p=2, dim=1, keepdim=True)
return torch.cat([x_global, x_local], 1)
else:
return x_global
prelogits_global = self.classifier_global(x_global)
if self.learn_region:
prelogits_local = self.classifier_local(x_local)
if self.loss == 'softmax':
if self.learn_region:
return (prelogits_global, prelogits_local)
else:
return prelogits_global
elif self.loss == 'triplet':
if self.learn_region:
return (prelogits_global, prelogits_local), (x_global, x_local)
else:
return prelogits_global, x_global
else:
raise KeyError("Unsupported loss: {}".format(self.loss))


@ -0,0 +1,361 @@
"""
Code imported from https://github.com/Cadene/pretrained-models.pytorch
"""
from __future__ import division, absolute_import
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
__all__ = ['inceptionresnetv2']
pretrained_settings = {
'inceptionresnetv2': {
'imagenet': {
'url':
'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
'input_space': 'RGB',
'input_size': [3, 299, 299],
'input_range': [0, 1],
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'num_classes': 1000
},
'imagenet+background': {
'url':
'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
'input_space': 'RGB',
'input_size': [3, 299, 299],
'input_range': [0, 1],
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'num_classes': 1001
}
}
}
class BasicConv2d(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False
) # verify bias false
self.bn = nn.BatchNorm2d(
out_planes,
eps=0.001, # value found in tensorflow
momentum=0.1, # default pytorch value
affine=True
)
self.relu = nn.ReLU(inplace=False)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Mixed_5b(nn.Module):
def __init__(self):
super(Mixed_5b, self).__init__()
self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
self.branch1 = nn.Sequential(
BasicConv2d(192, 48, kernel_size=1, stride=1),
BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
)
self.branch2 = nn.Sequential(
BasicConv2d(192, 64, kernel_size=1, stride=1),
BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
)
self.branch3 = nn.Sequential(
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
BasicConv2d(192, 64, kernel_size=1, stride=1)
)
def forward(self, x):
x0 = self.branch0(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
out = torch.cat((x0, x1, x2, x3), 1)
return out
class Block35(nn.Module):
def __init__(self, scale=1.0):
super(Block35, self).__init__()
self.scale = scale
self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
self.branch1 = nn.Sequential(
BasicConv2d(320, 32, kernel_size=1, stride=1),
BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
)
self.branch2 = nn.Sequential(
BasicConv2d(320, 32, kernel_size=1, stride=1),
BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
)
self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
self.relu = nn.ReLU(inplace=False)
def forward(self, x):
x0 = self.branch0(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
out = torch.cat((x0, x1, x2), 1)
out = self.conv2d(out)
out = out * self.scale + x
out = self.relu(out)
return out
class Mixed_6a(nn.Module):
def __init__(self):
super(Mixed_6a, self).__init__()
self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
self.branch1 = nn.Sequential(
BasicConv2d(320, 256, kernel_size=1, stride=1),
BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
BasicConv2d(256, 384, kernel_size=3, stride=2)
)
self.branch2 = nn.MaxPool2d(3, stride=2)
def forward(self, x):
x0 = self.branch0(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
out = torch.cat((x0, x1, x2), 1)
return out
class Block17(nn.Module):
def __init__(self, scale=1.0):
super(Block17, self).__init__()
self.scale = scale
self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
self.branch1 = nn.Sequential(
BasicConv2d(1088, 128, kernel_size=1, stride=1),
BasicConv2d(
128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)
),
BasicConv2d(
160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)
)
)
self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
self.relu = nn.ReLU(inplace=False)
def forward(self, x):
x0 = self.branch0(x)
x1 = self.branch1(x)
out = torch.cat((x0, x1), 1)
out = self.conv2d(out)
out = out * self.scale + x
out = self.relu(out)
return out
class Mixed_7a(nn.Module):
def __init__(self):
super(Mixed_7a, self).__init__()
self.branch0 = nn.Sequential(
BasicConv2d(1088, 256, kernel_size=1, stride=1),
BasicConv2d(256, 384, kernel_size=3, stride=2)
)
self.branch1 = nn.Sequential(
BasicConv2d(1088, 256, kernel_size=1, stride=1),
BasicConv2d(256, 288, kernel_size=3, stride=2)
)
self.branch2 = nn.Sequential(
BasicConv2d(1088, 256, kernel_size=1, stride=1),
BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
BasicConv2d(288, 320, kernel_size=3, stride=2)
)
self.branch3 = nn.MaxPool2d(3, stride=2)
def forward(self, x):
x0 = self.branch0(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
out = torch.cat((x0, x1, x2, x3), 1)
return out
class Block8(nn.Module):
def __init__(self, scale=1.0, noReLU=False):
super(Block8, self).__init__()
self.scale = scale
self.noReLU = noReLU
self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
self.branch1 = nn.Sequential(
BasicConv2d(2080, 192, kernel_size=1, stride=1),
BasicConv2d(
192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)
),
BasicConv2d(
224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)
)
)
self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
if not self.noReLU:
self.relu = nn.ReLU(inplace=False)
def forward(self, x):
x0 = self.branch0(x)
x1 = self.branch1(x)
out = torch.cat((x0, x1), 1)
out = self.conv2d(out)
out = out * self.scale + x
if not self.noReLU:
out = self.relu(out)
return out
# ----------------
# Model Definition
# ----------------
class InceptionResNetV2(nn.Module):
"""Inception-ResNet-V2.
Reference:
Szegedy et al. Inception-v4, Inception-ResNet and the Impact of Residual
Connections on Learning. AAAI 2017.
Public keys:
- ``inceptionresnetv2``: Inception-ResNet-V2.
"""
def __init__(self, num_classes, loss='softmax', **kwargs):
super(InceptionResNetV2, self).__init__()
self.loss = loss
# Modules
self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
self.conv2d_2b = BasicConv2d(
32, 64, kernel_size=3, stride=1, padding=1
)
self.maxpool_3a = nn.MaxPool2d(3, stride=2)
self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
self.maxpool_5a = nn.MaxPool2d(3, stride=2)
self.mixed_5b = Mixed_5b()
self.repeat = nn.Sequential(
Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17),
Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17),
Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17),
Block35(scale=0.17)
)
self.mixed_6a = Mixed_6a()
self.repeat_1 = nn.Sequential(
Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10),
Block17(scale=0.10), Block17(scale=0.10)
)
self.mixed_7a = Mixed_7a()
self.repeat_2 = nn.Sequential(
Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20),
Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20),
Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20)
)
self.block8 = Block8(noReLU=True)
self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(1536, num_classes)
def load_imagenet_weights(self):
settings = pretrained_settings['inceptionresnetv2']['imagenet']
pretrain_dict = model_zoo.load_url(settings['url'])
model_dict = self.state_dict()
pretrain_dict = {
k: v
for k, v in pretrain_dict.items()
if k in model_dict and model_dict[k].size() == v.size()
}
model_dict.update(pretrain_dict)
self.load_state_dict(model_dict)
def featuremaps(self, x):
x = self.conv2d_1a(x)
x = self.conv2d_2a(x)
x = self.conv2d_2b(x)
x = self.maxpool_3a(x)
x = self.conv2d_3b(x)
x = self.conv2d_4a(x)
x = self.maxpool_5a(x)
x = self.mixed_5b(x)
x = self.repeat(x)
x = self.mixed_6a(x)
x = self.repeat_1(x)
x = self.mixed_7a(x)
x = self.repeat_2(x)
x = self.block8(x)
x = self.conv2d_7b(x)
return x
def forward(self, x):
f = self.featuremaps(x)
v = self.global_avgpool(f)
v = v.view(v.size(0), -1)
if not self.training:
return v
y = self.classifier(v)
if self.loss == 'softmax':
return y
elif self.loss == 'triplet':
return y, v
else:
raise KeyError('Unsupported loss: {}'.format(self.loss))
def inceptionresnetv2(num_classes, loss='softmax', pretrained=True, **kwargs):
model = InceptionResNetV2(num_classes=num_classes, loss=loss, **kwargs)
if pretrained:
model.load_imagenet_weights()
return model
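
# Illustrative usage sketch (assumed example values, e.g. num_classes=751):
# in eval mode forward() returns the 1536-dim pooled feature vector instead of logits.
if __name__ == '__main__':
    import torch
    model = inceptionresnetv2(num_classes=751, pretrained=False)
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(2, 3, 299, 299))
    print(feats.shape)  # torch.Size([2, 1536])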

View file

@@ -0,0 +1,381 @@
from __future__ import division, absolute_import
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
__all__ = ['inceptionv4']
"""
Code imported from https://github.com/Cadene/pretrained-models.pytorch
"""
pretrained_settings = {
'inceptionv4': {
'imagenet': {
'url':
'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth',
'input_space': 'RGB',
'input_size': [3, 299, 299],
'input_range': [0, 1],
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'num_classes': 1000
},
'imagenet+background': {
'url':
'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth',
'input_space': 'RGB',
'input_size': [3, 299, 299],
'input_range': [0, 1],
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'num_classes': 1001
}
}
}
class BasicConv2d(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False
) # verify bias false
self.bn = nn.BatchNorm2d(
out_planes,
eps=0.001, # value found in tensorflow
momentum=0.1, # default pytorch value
affine=True
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Mixed_3a(nn.Module):
def __init__(self):
super(Mixed_3a, self).__init__()
self.maxpool = nn.MaxPool2d(3, stride=2)
self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)
def forward(self, x):
x0 = self.maxpool(x)
x1 = self.conv(x)
out = torch.cat((x0, x1), 1)
return out
class Mixed_4a(nn.Module):
def __init__(self):
super(Mixed_4a, self).__init__()
self.branch0 = nn.Sequential(
BasicConv2d(160, 64, kernel_size=1, stride=1),
BasicConv2d(64, 96, kernel_size=3, stride=1)
)
self.branch1 = nn.Sequential(
BasicConv2d(160, 64, kernel_size=1, stride=1),
BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
BasicConv2d(64, 96, kernel_size=(3, 3), stride=1)
)
def forward(self, x):
x0 = self.branch0(x)
x1 = self.branch1(x)
out = torch.cat((x0, x1), 1)
return out
class Mixed_5a(nn.Module):
def __init__(self):
super(Mixed_5a, self).__init__()
self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
self.maxpool = nn.MaxPool2d(3, stride=2)
def forward(self, x):
x0 = self.conv(x)
x1 = self.maxpool(x)
out = torch.cat((x0, x1), 1)
return out
class Inception_A(nn.Module):
def __init__(self):
super(Inception_A, self).__init__()
self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)
self.branch1 = nn.Sequential(
BasicConv2d(384, 64, kernel_size=1, stride=1),
BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1)
)
self.branch2 = nn.Sequential(
BasicConv2d(384, 64, kernel_size=1, stride=1),
BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
)
self.branch3 = nn.Sequential(
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
BasicConv2d(384, 96, kernel_size=1, stride=1)
)
def forward(self, x):
x0 = self.branch0(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
out = torch.cat((x0, x1, x2, x3), 1)
return out
class Reduction_A(nn.Module):
def __init__(self):
super(Reduction_A, self).__init__()
self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)
self.branch1 = nn.Sequential(
BasicConv2d(384, 192, kernel_size=1, stride=1),
BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1),
BasicConv2d(224, 256, kernel_size=3, stride=2)
)
self.branch2 = nn.MaxPool2d(3, stride=2)
def forward(self, x):
x0 = self.branch0(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
out = torch.cat((x0, x1, x2), 1)
return out
class Inception_B(nn.Module):
def __init__(self):
super(Inception_B, self).__init__()
self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)
self.branch1 = nn.Sequential(
BasicConv2d(1024, 192, kernel_size=1, stride=1),
BasicConv2d(
192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)
),
BasicConv2d(
224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0)
)
)
self.branch2 = nn.Sequential(
BasicConv2d(1024, 192, kernel_size=1, stride=1),
BasicConv2d(
192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)
),
BasicConv2d(
192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)
),
BasicConv2d(
224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)
),
BasicConv2d(
224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)
)
)
self.branch3 = nn.Sequential(
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
BasicConv2d(1024, 128, kernel_size=1, stride=1)
)
def forward(self, x):
x0 = self.branch0(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
out = torch.cat((x0, x1, x2, x3), 1)
return out
class Reduction_B(nn.Module):
def __init__(self):
super(Reduction_B, self).__init__()
self.branch0 = nn.Sequential(
BasicConv2d(1024, 192, kernel_size=1, stride=1),
BasicConv2d(192, 192, kernel_size=3, stride=2)
)
self.branch1 = nn.Sequential(
BasicConv2d(1024, 256, kernel_size=1, stride=1),
BasicConv2d(
256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)
),
BasicConv2d(
256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)
), BasicConv2d(320, 320, kernel_size=3, stride=2)
)
self.branch2 = nn.MaxPool2d(3, stride=2)
def forward(self, x):
x0 = self.branch0(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
out = torch.cat((x0, x1, x2), 1)
return out
class Inception_C(nn.Module):
def __init__(self):
super(Inception_C, self).__init__()
self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)
self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
self.branch1_1a = BasicConv2d(
384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)
)
self.branch1_1b = BasicConv2d(
384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)
)
self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
self.branch2_1 = BasicConv2d(
384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0)
)
self.branch2_2 = BasicConv2d(
448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1)
)
self.branch2_3a = BasicConv2d(
512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)
)
self.branch2_3b = BasicConv2d(
512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)
)
self.branch3 = nn.Sequential(
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
BasicConv2d(1536, 256, kernel_size=1, stride=1)
)
def forward(self, x):
x0 = self.branch0(x)
x1_0 = self.branch1_0(x)
x1_1a = self.branch1_1a(x1_0)
x1_1b = self.branch1_1b(x1_0)
x1 = torch.cat((x1_1a, x1_1b), 1)
x2_0 = self.branch2_0(x)
x2_1 = self.branch2_1(x2_0)
x2_2 = self.branch2_2(x2_1)
x2_3a = self.branch2_3a(x2_2)
x2_3b = self.branch2_3b(x2_2)
x2 = torch.cat((x2_3a, x2_3b), 1)
x3 = self.branch3(x)
out = torch.cat((x0, x1, x2, x3), 1)
return out
class InceptionV4(nn.Module):
"""Inception-v4.
Reference:
Szegedy et al. Inception-v4, Inception-ResNet and the Impact of Residual
Connections on Learning. AAAI 2017.
Public keys:
- ``inceptionv4``: InceptionV4.
"""
def __init__(self, num_classes, loss, **kwargs):
super(InceptionV4, self).__init__()
self.loss = loss
self.features = nn.Sequential(
BasicConv2d(3, 32, kernel_size=3, stride=2),
BasicConv2d(32, 32, kernel_size=3, stride=1),
BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
Mixed_3a(),
Mixed_4a(),
Mixed_5a(),
Inception_A(),
Inception_A(),
Inception_A(),
Inception_A(),
Reduction_A(), # Mixed_6a
Inception_B(),
Inception_B(),
Inception_B(),
Inception_B(),
Inception_B(),
Inception_B(),
Inception_B(),
Reduction_B(), # Mixed_7a
Inception_C(),
Inception_C(),
Inception_C()
)
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(1536, num_classes)
def forward(self, x):
f = self.features(x)
v = self.global_avgpool(f)
v = v.view(v.size(0), -1)
if not self.training:
return v
y = self.classifier(v)
if self.loss == 'softmax':
return y
elif self.loss == 'triplet':
return y, v
else:
raise KeyError('Unsupported loss: {}'.format(self.loss))
def init_pretrained_weights(model, model_url):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {
k: v
for k, v in pretrain_dict.items()
if k in model_dict and model_dict[k].size() == v.size()
}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
def inceptionv4(num_classes, loss='softmax', pretrained=True, **kwargs):
model = InceptionV4(num_classes, loss, **kwargs)
if pretrained:
model_url = pretrained_settings['inceptionv4']['imagenet']['url']
init_pretrained_weights(model, model_url)
return model

View file

@@ -0,0 +1,269 @@
from __future__ import division, absolute_import
import torch
import torch.utils.model_zoo as model_zoo
from torch import nn
from torch.nn import functional as F
__all__ = ['mlfn']
model_urls = {
# training epoch = 5, top1 = 51.6
'imagenet':
'https://mega.nz/#!YHxAhaxC!yu9E6zWl0x5zscSouTdbZu8gdFFytDdl-RAdD2DEfpk',
}
class MLFNBlock(nn.Module):
def __init__(
self, in_channels, out_channels, stride, fsm_channels, groups=32
):
super(MLFNBlock, self).__init__()
self.groups = groups
mid_channels = out_channels // 2
# Factor Modules
self.fm_conv1 = nn.Conv2d(in_channels, mid_channels, 1, bias=False)
self.fm_bn1 = nn.BatchNorm2d(mid_channels)
self.fm_conv2 = nn.Conv2d(
mid_channels,
mid_channels,
3,
stride=stride,
padding=1,
bias=False,
groups=self.groups
)
self.fm_bn2 = nn.BatchNorm2d(mid_channels)
self.fm_conv3 = nn.Conv2d(mid_channels, out_channels, 1, bias=False)
self.fm_bn3 = nn.BatchNorm2d(out_channels)
# Factor Selection Module
self.fsm = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, fsm_channels[0], 1),
nn.BatchNorm2d(fsm_channels[0]),
nn.ReLU(inplace=True),
nn.Conv2d(fsm_channels[0], fsm_channels[1], 1),
nn.BatchNorm2d(fsm_channels[1]),
nn.ReLU(inplace=True),
nn.Conv2d(fsm_channels[1], self.groups, 1),
nn.BatchNorm2d(self.groups),
nn.Sigmoid(),
)
self.downsample = None
if in_channels != out_channels or stride > 1:
self.downsample = nn.Sequential(
nn.Conv2d(
in_channels, out_channels, 1, stride=stride, bias=False
),
nn.BatchNorm2d(out_channels),
)
def forward(self, x):
residual = x
s = self.fsm(x)
# reduce dimension
x = self.fm_conv1(x)
x = self.fm_bn1(x)
x = F.relu(x, inplace=True)
# group convolution
x = self.fm_conv2(x)
x = self.fm_bn2(x)
x = F.relu(x, inplace=True)
# factor selection
b, c = x.size(0), x.size(1)
n = c // self.groups
ss = s.repeat(1, n, 1, 1) # from (b, g, 1, 1) to (b, g*n=c, 1, 1)
ss = ss.view(b, n, self.groups, 1, 1)
ss = ss.permute(0, 2, 1, 3, 4).contiguous()
ss = ss.view(b, c, 1, 1)
x = ss * x
# recover dimension
x = self.fm_conv3(x)
x = self.fm_bn3(x)
x = F.relu(x, inplace=True)
if self.downsample is not None:
residual = self.downsample(residual)
return F.relu(residual + x, inplace=True), s
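
# Shape sketch for the factor-selection step above (illustrative comment): for the
# first block in MLFN below, in/out channels are 64/256, so mid_channels = 128 and,
# with groups = 32, n = 128 // 32 = 4. The FSM gate s has shape (b, 32, 1, 1); the
# repeat/view/permute sequence expands it to (b, 128, 1, 1) so that the 4 channels
# belonging to each grouped-convolution group share one gate value before the
# element-wise multiplication with x.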
class MLFN(nn.Module):
"""Multi-Level Factorisation Net.
Reference:
Chang et al. Multi-Level Factorisation Net for
Person Re-Identification. CVPR 2018.
Public keys:
- ``mlfn``: MLFN (Multi-Level Factorisation Net).
"""
def __init__(
self,
num_classes,
loss='softmax',
groups=32,
channels=[64, 256, 512, 1024, 2048],
embed_dim=1024,
**kwargs
):
super(MLFN, self).__init__()
self.loss = loss
self.groups = groups
# first convolutional layer
self.conv1 = nn.Conv2d(3, channels[0], 7, stride=2, padding=3)
self.bn1 = nn.BatchNorm2d(channels[0])
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
# main body
self.feature = nn.ModuleList(
[
# layer 1-3
MLFNBlock(channels[0], channels[1], 1, [128, 64], self.groups),
MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups),
MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups),
# layer 4-7
MLFNBlock(
channels[1], channels[2], 2, [256, 128], self.groups
),
MLFNBlock(
channels[2], channels[2], 1, [256, 128], self.groups
),
MLFNBlock(
channels[2], channels[2], 1, [256, 128], self.groups
),
MLFNBlock(
channels[2], channels[2], 1, [256, 128], self.groups
),
# layer 8-13
MLFNBlock(
channels[2], channels[3], 2, [512, 128], self.groups
),
MLFNBlock(
channels[3], channels[3], 1, [512, 128], self.groups
),
MLFNBlock(
channels[3], channels[3], 1, [512, 128], self.groups
),
MLFNBlock(
channels[3], channels[3], 1, [512, 128], self.groups
),
MLFNBlock(
channels[3], channels[3], 1, [512, 128], self.groups
),
MLFNBlock(
channels[3], channels[3], 1, [512, 128], self.groups
),
# layer 14-16
MLFNBlock(
channels[3], channels[4], 2, [512, 128], self.groups
),
MLFNBlock(
channels[4], channels[4], 1, [512, 128], self.groups
),
MLFNBlock(
channels[4], channels[4], 1, [512, 128], self.groups
),
]
)
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
# projection functions
self.fc_x = nn.Sequential(
nn.Conv2d(channels[4], embed_dim, 1, bias=False),
nn.BatchNorm2d(embed_dim),
nn.ReLU(inplace=True),
)
self.fc_s = nn.Sequential(
nn.Conv2d(self.groups * 16, embed_dim, 1, bias=False),
nn.BatchNorm2d(embed_dim),
nn.ReLU(inplace=True),
)
self.classifier = nn.Linear(embed_dim, num_classes)
self.init_params()
def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu'
)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x, inplace=True)
x = self.maxpool(x)
s_hat = []
for block in self.feature:
x, s = block(x)
s_hat.append(s)
s_hat = torch.cat(s_hat, 1)
x = self.global_avgpool(x)
x = self.fc_x(x)
s_hat = self.fc_s(s_hat)
v = (x+s_hat) * 0.5
v = v.view(v.size(0), -1)
if not self.training:
return v
y = self.classifier(v)
if self.loss == 'softmax':
return y
elif self.loss == 'triplet':
return y, v
else:
raise KeyError('Unsupported loss: {}'.format(self.loss))
def init_pretrained_weights(model, model_url):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {
k: v
for k, v in pretrain_dict.items()
if k in model_dict and model_dict[k].size() == v.size()
}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
def mlfn(num_classes, loss='softmax', pretrained=True, **kwargs):
model = MLFN(num_classes, loss, **kwargs)
if pretrained:
# init_pretrained_weights(model, model_urls['imagenet'])
import warnings
warnings.warn(
'The imagenet pretrained weights need to be manually downloaded from {}'
.format(model_urls['imagenet'])
)
return model

View file

@ -0,0 +1,274 @@
from __future__ import division, absolute_import
import torch.utils.model_zoo as model_zoo
from torch import nn
from torch.nn import functional as F
__all__ = ['mobilenetv2_x1_0', 'mobilenetv2_x1_4']
model_urls = {
# 1.0: top-1 71.3
'mobilenetv2_x1_0':
'https://mega.nz/#!NKp2wAIA!1NH1pbNzY_M2hVk_hdsxNM1NUOWvvGPHhaNr-fASF6c',
# 1.4: top-1 73.9
'mobilenetv2_x1_4':
'https://mega.nz/#!RGhgEIwS!xN2s2ZdyqI6vQ3EwgmRXLEW3khr9tpXg96G9SUJugGk',
}
class ConvBlock(nn.Module):
"""Basic convolutional block.
convolution (bias discarded) + batch normalization + relu6.
Args:
in_c (int): number of input channels.
out_c (int): number of output channels.
k (int or tuple): kernel size.
s (int or tuple): stride.
p (int or tuple): padding.
g (int): number of blocked connections from input channels
to output channels (default: 1).
"""
def __init__(self, in_c, out_c, k, s=1, p=0, g=1):
super(ConvBlock, self).__init__()
self.conv = nn.Conv2d(
in_c, out_c, k, stride=s, padding=p, bias=False, groups=g
)
self.bn = nn.BatchNorm2d(out_c)
def forward(self, x):
return F.relu6(self.bn(self.conv(x)))
class Bottleneck(nn.Module):
def __init__(self, in_channels, out_channels, expansion_factor, stride=1):
super(Bottleneck, self).__init__()
mid_channels = in_channels * expansion_factor
self.use_residual = stride == 1 and in_channels == out_channels
self.conv1 = ConvBlock(in_channels, mid_channels, 1)
self.dwconv2 = ConvBlock(
mid_channels, mid_channels, 3, stride, 1, g=mid_channels
)
self.conv3 = nn.Sequential(
nn.Conv2d(mid_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
)
def forward(self, x):
m = self.conv1(x)
m = self.dwconv2(m)
m = self.conv3(m)
if self.use_residual:
return x + m
else:
return m
class MobileNetV2(nn.Module):
"""MobileNetV2.
Reference:
Sandler et al. MobileNetV2: Inverted Residuals and
Linear Bottlenecks. CVPR 2018.
Public keys:
- ``mobilenetv2_x1_0``: MobileNetV2 x1.0.
- ``mobilenetv2_x1_4``: MobileNetV2 x1.4.
"""
def __init__(
self,
num_classes,
width_mult=1,
loss='softmax',
fc_dims=None,
dropout_p=None,
**kwargs
):
super(MobileNetV2, self).__init__()
self.loss = loss
self.in_channels = int(32 * width_mult)
self.feature_dim = int(1280 * width_mult) if width_mult > 1 else 1280
# construct layers
self.conv1 = ConvBlock(3, self.in_channels, 3, s=2, p=1)
self.conv2 = self._make_layer(
Bottleneck, 1, int(16 * width_mult), 1, 1
)
self.conv3 = self._make_layer(
Bottleneck, 6, int(24 * width_mult), 2, 2
)
self.conv4 = self._make_layer(
Bottleneck, 6, int(32 * width_mult), 3, 2
)
self.conv5 = self._make_layer(
Bottleneck, 6, int(64 * width_mult), 4, 2
)
self.conv6 = self._make_layer(
Bottleneck, 6, int(96 * width_mult), 3, 1
)
self.conv7 = self._make_layer(
Bottleneck, 6, int(160 * width_mult), 3, 2
)
self.conv8 = self._make_layer(
Bottleneck, 6, int(320 * width_mult), 1, 1
)
self.conv9 = ConvBlock(self.in_channels, self.feature_dim, 1)
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.fc = self._construct_fc_layer(
fc_dims, self.feature_dim, dropout_p
)
self.classifier = nn.Linear(self.feature_dim, num_classes)
self._init_params()
def _make_layer(self, block, t, c, n, s):
# t: expansion factor
# c: output channels
# n: number of blocks
# s: stride for first layer
layers = []
layers.append(block(self.in_channels, c, t, s))
self.in_channels = c
for i in range(1, n):
layers.append(block(self.in_channels, c, t))
return nn.Sequential(*layers)
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
"""Constructs fully connected layer.
Args:
fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
input_dim (int): input dimension
dropout_p (float): dropout probability, if None, dropout is unused
"""
if fc_dims is None:
self.feature_dim = input_dim
return None
assert isinstance(
fc_dims, (list, tuple)
), 'fc_dims must be either list or tuple, but got {}'.format(
type(fc_dims)
)
layers = []
for dim in fc_dims:
layers.append(nn.Linear(input_dim, dim))
layers.append(nn.BatchNorm1d(dim))
layers.append(nn.ReLU(inplace=True))
if dropout_p is not None:
layers.append(nn.Dropout(p=dropout_p))
input_dim = dim
self.feature_dim = fc_dims[-1]
return nn.Sequential(*layers)
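
    # Illustrative note: the factory functions below pass fc_dims=None, so no extra
    # fc layers are added and self.feature_dim stays at the conv9 output width;
    # passing e.g. fc_dims=[512] with dropout_p=0.5 would append
    # Linear -> BatchNorm1d -> ReLU -> Dropout and set self.feature_dim to 512.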
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu'
)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def featuremaps(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
x = self.conv6(x)
x = self.conv7(x)
x = self.conv8(x)
x = self.conv9(x)
return x
def forward(self, x):
f = self.featuremaps(x)
v = self.global_avgpool(f)
v = v.view(v.size(0), -1)
if self.fc is not None:
v = self.fc(v)
if not self.training:
return v
y = self.classifier(v)
if self.loss == 'softmax':
return y
elif self.loss == 'triplet':
return y, v
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
def init_pretrained_weights(model, model_url):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {
k: v
for k, v in pretrain_dict.items()
if k in model_dict and model_dict[k].size() == v.size()
}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
def mobilenetv2_x1_0(num_classes, loss, pretrained=True, **kwargs):
model = MobileNetV2(
num_classes,
loss=loss,
width_mult=1,
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained:
# init_pretrained_weights(model, model_urls['mobilenetv2_x1_0'])
import warnings
warnings.warn(
'The imagenet pretrained weights need to be manually downloaded from {}'
.format(model_urls['mobilenetv2_x1_0'])
)
return model
def mobilenetv2_x1_4(num_classes, loss, pretrained=True, **kwargs):
model = MobileNetV2(
num_classes,
loss=loss,
width_mult=1.4,
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained:
# init_pretrained_weights(model, model_urls['mobilenetv2_x1_4'])
import warnings
warnings.warn(
'The imagenet pretrained weights need to be manually downloaded from {}'
.format(model_urls['mobilenetv2_x1_4'])
)
return model
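
# Illustrative sketch (assumed example values): width_mult controls both the
# channel widths and the final embedding size (1280 for x1.0, 1792 for x1.4).
if __name__ == '__main__':
    import torch
    m10 = mobilenetv2_x1_0(num_classes=751, loss='softmax', pretrained=False)
    m14 = mobilenetv2_x1_4(num_classes=751, loss='softmax', pretrained=False)
    m10.eval()
    m14.eval()
    x = torch.randn(2, 3, 256, 128)
    with torch.no_grad():
        print(m10(x).shape, m14(x).shape)  # (2, 1280) and (2, 1792)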

View file

@@ -0,0 +1,206 @@
from __future__ import division, absolute_import
import torch
from torch import nn
from torch.nn import functional as F
__all__ = ['MuDeep']
class ConvBlock(nn.Module):
"""Basic convolutional block.
convolution + batch normalization + relu.
Args:
in_c (int): number of input channels.
out_c (int): number of output channels.
k (int or tuple): kernel size.
s (int or tuple): stride.
p (int or tuple): padding.
"""
def __init__(self, in_c, out_c, k, s, p):
super(ConvBlock, self).__init__()
self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
self.bn = nn.BatchNorm2d(out_c)
def forward(self, x):
return F.relu(self.bn(self.conv(x)))
class ConvLayers(nn.Module):
"""Preprocessing layers."""
def __init__(self):
super(ConvLayers, self).__init__()
self.conv1 = ConvBlock(3, 48, k=3, s=1, p=1)
self.conv2 = ConvBlock(48, 96, k=3, s=1, p=1)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.maxpool(x)
return x
class MultiScaleA(nn.Module):
"""Multi-scale stream layer A (Sec.3.1)"""
def __init__(self):
super(MultiScaleA, self).__init__()
self.stream1 = nn.Sequential(
ConvBlock(96, 96, k=1, s=1, p=0),
ConvBlock(96, 24, k=3, s=1, p=1),
)
self.stream2 = nn.Sequential(
nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
ConvBlock(96, 24, k=1, s=1, p=0),
)
self.stream3 = ConvBlock(96, 24, k=1, s=1, p=0)
self.stream4 = nn.Sequential(
ConvBlock(96, 16, k=1, s=1, p=0),
ConvBlock(16, 24, k=3, s=1, p=1),
ConvBlock(24, 24, k=3, s=1, p=1),
)
def forward(self, x):
s1 = self.stream1(x)
s2 = self.stream2(x)
s3 = self.stream3(x)
s4 = self.stream4(x)
y = torch.cat([s1, s2, s3, s4], dim=1)
return y
class Reduction(nn.Module):
"""Reduction layer (Sec.3.1)"""
def __init__(self):
super(Reduction, self).__init__()
self.stream1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.stream2 = ConvBlock(96, 96, k=3, s=2, p=1)
self.stream3 = nn.Sequential(
ConvBlock(96, 48, k=1, s=1, p=0),
ConvBlock(48, 56, k=3, s=1, p=1),
ConvBlock(56, 64, k=3, s=2, p=1),
)
def forward(self, x):
s1 = self.stream1(x)
s2 = self.stream2(x)
s3 = self.stream3(x)
y = torch.cat([s1, s2, s3], dim=1)
return y
class MultiScaleB(nn.Module):
"""Multi-scale stream layer B (Sec.3.1)"""
def __init__(self):
super(MultiScaleB, self).__init__()
self.stream1 = nn.Sequential(
nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
ConvBlock(256, 256, k=1, s=1, p=0),
)
self.stream2 = nn.Sequential(
ConvBlock(256, 64, k=1, s=1, p=0),
ConvBlock(64, 128, k=(1, 3), s=1, p=(0, 1)),
ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)),
)
self.stream3 = ConvBlock(256, 256, k=1, s=1, p=0)
self.stream4 = nn.Sequential(
ConvBlock(256, 64, k=1, s=1, p=0),
ConvBlock(64, 64, k=(1, 3), s=1, p=(0, 1)),
ConvBlock(64, 128, k=(3, 1), s=1, p=(1, 0)),
ConvBlock(128, 128, k=(1, 3), s=1, p=(0, 1)),
ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)),
)
def forward(self, x):
s1 = self.stream1(x)
s2 = self.stream2(x)
s3 = self.stream3(x)
s4 = self.stream4(x)
return s1, s2, s3, s4
class Fusion(nn.Module):
"""Saliency-based learning fusion layer (Sec.3.2)"""
def __init__(self):
super(Fusion, self).__init__()
self.a1 = nn.Parameter(torch.rand(1, 256, 1, 1))
self.a2 = nn.Parameter(torch.rand(1, 256, 1, 1))
self.a3 = nn.Parameter(torch.rand(1, 256, 1, 1))
self.a4 = nn.Parameter(torch.rand(1, 256, 1, 1))
# We add an average pooling layer to reduce the spatial dimension
# of feature maps, which differs from the original paper.
self.avgpool = nn.AvgPool2d(kernel_size=4, stride=4, padding=0)
def forward(self, x1, x2, x3, x4):
s1 = self.a1.expand_as(x1) * x1
s2 = self.a2.expand_as(x2) * x2
s3 = self.a3.expand_as(x3) * x3
s4 = self.a4.expand_as(x4) * x4
y = self.avgpool(s1 + s2 + s3 + s4)
return y
class MuDeep(nn.Module):
"""Multiscale deep neural network.
Reference:
Qian et al. Multi-scale Deep Learning Architectures
for Person Re-identification. ICCV 2017.
Public keys:
- ``mudeep``: Multiscale deep neural network.
"""
def __init__(self, num_classes, loss='softmax', **kwargs):
super(MuDeep, self).__init__()
self.loss = loss
self.block1 = ConvLayers()
self.block2 = MultiScaleA()
self.block3 = Reduction()
self.block4 = MultiScaleB()
self.block5 = Fusion()
# Due to this fully connected layer, input image has to be fixed
# in shape, i.e. (3, 256, 128), such that the last convolutional feature
# maps are of shape (256, 16, 8). If input shape is changed,
# the input dimension of this layer has to be changed accordingly.
self.fc = nn.Sequential(
nn.Linear(256 * 16 * 8, 4096),
nn.BatchNorm1d(4096),
nn.ReLU(),
)
self.classifier = nn.Linear(4096, num_classes)
self.feat_dim = 4096
def featuremaps(self, x):
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
x = self.block4(x)
x = self.block5(*x)
return x
def forward(self, x):
x = self.featuremaps(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
y = self.classifier(x)
if not self.training:
return x
if self.loss == 'softmax':
return y
elif self.loss == 'triplet':
return y, x
else:
raise KeyError('Unsupported loss: {}'.format(self.loss))
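
# Illustrative usage sketch (assumed example values): because of the fixed fc
# input size (256 * 16 * 8), MuDeep only accepts inputs of shape (3, 256, 128);
# eval-mode forward() returns the 4096-dim feature vector.
if __name__ == '__main__':
    import torch
    model = MuDeep(num_classes=751)
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(2, 3, 256, 128))
    print(feats.shape)  # torch.Size([2, 4096])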

File diff suppressed because it is too large

View file

@@ -0,0 +1,598 @@
from __future__ import division, absolute_import
import warnings
import torch
from torch import nn
from torch.nn import functional as F
__all__ = [
'osnet_x1_0', 'osnet_x0_75', 'osnet_x0_5', 'osnet_x0_25', 'osnet_ibn_x1_0'
]
pretrained_urls = {
'osnet_x1_0':
'https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY',
'osnet_x0_75':
'https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq',
'osnet_x0_5':
'https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i',
'osnet_x0_25':
'https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs',
'osnet_ibn_x1_0':
'https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l'
}
##########
# Basic layers
##########
class ConvLayer(nn.Module):
"""Convolution layer (conv + bn + relu)."""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
groups=1,
IN=False
):
super(ConvLayer, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=False,
groups=groups
)
if IN:
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
else:
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Conv1x1(nn.Module):
"""1x1 convolution + bn + relu."""
def __init__(self, in_channels, out_channels, stride=1, groups=1):
super(Conv1x1, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
1,
stride=stride,
padding=0,
bias=False,
groups=groups
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Conv1x1Linear(nn.Module):
"""1x1 convolution + bn (w/o non-linearity)."""
def __init__(self, in_channels, out_channels, stride=1):
super(Conv1x1Linear, self).__init__()
self.conv = nn.Conv2d(
in_channels, out_channels, 1, stride=stride, padding=0, bias=False
)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class Conv3x3(nn.Module):
"""3x3 convolution + bn + relu."""
def __init__(self, in_channels, out_channels, stride=1, groups=1):
super(Conv3x3, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
3,
stride=stride,
padding=1,
bias=False,
groups=groups
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class LightConv3x3(nn.Module):
"""Lightweight 3x3 convolution.
1x1 (linear) + dw 3x3 (nonlinear).
"""
def __init__(self, in_channels, out_channels):
super(LightConv3x3, self).__init__()
self.conv1 = nn.Conv2d(
in_channels, out_channels, 1, stride=1, padding=0, bias=False
)
self.conv2 = nn.Conv2d(
out_channels,
out_channels,
3,
stride=1,
padding=1,
bias=False,
groups=out_channels
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.bn(x)
x = self.relu(x)
return x
##########
# Building blocks for omni-scale feature learning
##########
class ChannelGate(nn.Module):
"""A mini-network that generates channel-wise gates conditioned on input tensor."""
def __init__(
self,
in_channels,
num_gates=None,
return_gates=False,
gate_activation='sigmoid',
reduction=16,
layer_norm=False
):
super(ChannelGate, self).__init__()
if num_gates is None:
num_gates = in_channels
self.return_gates = return_gates
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(
in_channels,
in_channels // reduction,
kernel_size=1,
bias=True,
padding=0
)
self.norm1 = None
if layer_norm:
self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(
in_channels // reduction,
num_gates,
kernel_size=1,
bias=True,
padding=0
)
if gate_activation == 'sigmoid':
self.gate_activation = nn.Sigmoid()
elif gate_activation == 'relu':
self.gate_activation = nn.ReLU(inplace=True)
elif gate_activation == 'linear':
self.gate_activation = None
else:
raise RuntimeError(
"Unknown gate activation: {}".format(gate_activation)
)
def forward(self, x):
input = x
x = self.global_avgpool(x)
x = self.fc1(x)
if self.norm1 is not None:
x = self.norm1(x)
x = self.relu(x)
x = self.fc2(x)
if self.gate_activation is not None:
x = self.gate_activation(x)
if self.return_gates:
return x
return input * x
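
# Shape sketch for ChannelGate (illustrative comment): for an input of shape
# (b, C, H, W) the gate path is global_avgpool -> (b, C, 1, 1) -> fc1 ->
# (b, C // reduction, 1, 1) -> fc2 -> (b, C, 1, 1); the sigmoid output is then
# broadcast-multiplied back onto the (b, C, H, W) input, i.e. a
# squeeze-and-excitation style channel reweighting reused by each stream in
# OSBlock below.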
class OSBlock(nn.Module):
"""Omni-scale feature learning block."""
def __init__(
self,
in_channels,
out_channels,
IN=False,
bottleneck_reduction=4,
**kwargs
):
super(OSBlock, self).__init__()
mid_channels = out_channels // bottleneck_reduction
self.conv1 = Conv1x1(in_channels, mid_channels)
self.conv2a = LightConv3x3(mid_channels, mid_channels)
self.conv2b = nn.Sequential(
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
)
self.conv2c = nn.Sequential(
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
)
self.conv2d = nn.Sequential(
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
)
self.gate = ChannelGate(mid_channels)
self.conv3 = Conv1x1Linear(mid_channels, out_channels)
self.downsample = None
if in_channels != out_channels:
self.downsample = Conv1x1Linear(in_channels, out_channels)
self.IN = None
if IN:
self.IN = nn.InstanceNorm2d(out_channels, affine=True)
def forward(self, x):
identity = x
x1 = self.conv1(x)
x2a = self.conv2a(x1)
x2b = self.conv2b(x1)
x2c = self.conv2c(x1)
x2d = self.conv2d(x1)
x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
x3 = self.conv3(x2)
if self.downsample is not None:
identity = self.downsample(identity)
out = x3 + identity
if self.IN is not None:
out = self.IN(out)
return F.relu(out)
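
# Illustrative note: the four streams above stack 1-4 LightConv3x3 units, giving
# effective receptive fields of roughly 3x3, 5x5, 7x7 and 9x9; the shared
# ChannelGate produces per-stream channel weights, and the gated streams are
# summed to fuse the scales before the final 1x1 projection and residual add.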
##########
# Network architecture
##########
class OSNet(nn.Module):
"""Omni-Scale Network.
Reference:
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
- Zhou et al. Learning Generalisable Omni-Scale Representations
for Person Re-Identification. TPAMI, 2021.
"""
def __init__(
self,
num_classes,
blocks,
layers,
channels,
feature_dim=512,
loss='softmax',
IN=False,
**kwargs
):
super(OSNet, self).__init__()
num_blocks = len(blocks)
assert num_blocks == len(layers)
assert num_blocks == len(channels) - 1
self.loss = loss
self.feature_dim = feature_dim
# convolutional backbone
self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
self.conv2 = self._make_layer(
blocks[0],
layers[0],
channels[0],
channels[1],
reduce_spatial_size=True,
IN=IN
)
self.conv3 = self._make_layer(
blocks[1],
layers[1],
channels[1],
channels[2],
reduce_spatial_size=True
)
self.conv4 = self._make_layer(
blocks[2],
layers[2],
channels[2],
channels[3],
reduce_spatial_size=False
)
self.conv5 = Conv1x1(channels[3], channels[3])
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
# fully connected layer
self.fc = self._construct_fc_layer(
self.feature_dim, channels[3], dropout_p=None
)
# identity classification layer
self.classifier = nn.Linear(self.feature_dim, num_classes)
self._init_params()
def _make_layer(
self,
block,
layer,
in_channels,
out_channels,
reduce_spatial_size,
IN=False
):
layers = []
layers.append(block(in_channels, out_channels, IN=IN))
for i in range(1, layer):
layers.append(block(out_channels, out_channels, IN=IN))
if reduce_spatial_size:
layers.append(
nn.Sequential(
Conv1x1(out_channels, out_channels),
nn.AvgPool2d(2, stride=2)
)
)
return nn.Sequential(*layers)
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
if fc_dims is None or fc_dims < 0:
self.feature_dim = input_dim
return None
if isinstance(fc_dims, int):
fc_dims = [fc_dims]
layers = []
for dim in fc_dims:
layers.append(nn.Linear(input_dim, dim))
layers.append(nn.BatchNorm1d(dim))
layers.append(nn.ReLU(inplace=True))
if dropout_p is not None:
layers.append(nn.Dropout(p=dropout_p))
input_dim = dim
self.feature_dim = fc_dims[-1]
return nn.Sequential(*layers)
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu'
)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def featuremaps(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
return x
def forward(self, x, return_featuremaps=False):
x = self.featuremaps(x)
if return_featuremaps:
return x
v = self.global_avgpool(x)
v = v.view(v.size(0), -1)
if self.fc is not None:
v = self.fc(v)
if not self.training:
return v
y = self.classifier(v)
if self.loss == 'softmax':
return y
elif self.loss == 'triplet':
return y, v
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
def init_pretrained_weights(model, key=''):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
import os
import errno
import gdown
from collections import OrderedDict
def _get_torch_home():
ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
torch_home = os.path.expanduser(
os.getenv(
ENV_TORCH_HOME,
os.path.join(
os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
)
)
)
return torch_home
torch_home = _get_torch_home()
model_dir = os.path.join(torch_home, 'checkpoints')
try:
os.makedirs(model_dir)
except OSError as e:
if e.errno == errno.EEXIST:
# Directory already exists, ignore.
pass
else:
# Unexpected OSError, re-raise.
raise
filename = key + '_imagenet.pth'
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
gdown.download(pretrained_urls[key], cached_file, quiet=False)
state_dict = torch.load(cached_file)
model_dict = model.state_dict()
new_state_dict = OrderedDict()
matched_layers, discarded_layers = [], []
for k, v in state_dict.items():
if k.startswith('module.'):
k = k[7:] # discard module.
if k in model_dict and model_dict[k].size() == v.size():
new_state_dict[k] = v
matched_layers.append(k)
else:
discarded_layers.append(k)
model_dict.update(new_state_dict)
model.load_state_dict(model_dict)
if len(matched_layers) == 0:
warnings.warn(
'The pretrained weights from "{}" cannot be loaded, '
'please check the key names manually '
'(** ignored and continue **)'.format(cached_file)
)
else:
print(
'Successfully loaded imagenet pretrained weights from "{}"'.
format(cached_file)
)
if len(discarded_layers) > 0:
print(
'** The following layers are discarded '
'due to unmatched keys or layer size: {}'.
format(discarded_layers)
)
##########
# Instantiation
##########
def osnet_x1_0(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
# standard size (width x1.0)
model = OSNet(
num_classes,
blocks=[OSBlock, OSBlock, OSBlock],
layers=[2, 2, 2],
channels=[64, 256, 384, 512],
loss=loss,
**kwargs
)
if pretrained:
init_pretrained_weights(model, key='osnet_x1_0')
return model
def osnet_x0_75(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
# medium size (width x0.75)
model = OSNet(
num_classes,
blocks=[OSBlock, OSBlock, OSBlock],
layers=[2, 2, 2],
channels=[48, 192, 288, 384],
loss=loss,
**kwargs
)
if pretrained:
init_pretrained_weights(model, key='osnet_x0_75')
return model
def osnet_x0_5(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
# tiny size (width x0.5)
model = OSNet(
num_classes,
blocks=[OSBlock, OSBlock, OSBlock],
layers=[2, 2, 2],
channels=[32, 128, 192, 256],
loss=loss,
**kwargs
)
if pretrained:
init_pretrained_weights(model, key='osnet_x0_5')
return model
def osnet_x0_25(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
# very tiny size (width x0.25)
model = OSNet(
num_classes,
blocks=[OSBlock, OSBlock, OSBlock],
layers=[2, 2, 2],
channels=[16, 64, 96, 128],
loss=loss,
**kwargs
)
if pretrained:
init_pretrained_weights(model, key='osnet_x0_25')
return model
def osnet_ibn_x1_0(
num_classes=1000, pretrained=True, loss='softmax', **kwargs
):
# standard size (width x1.0) + IBN layer
# Ref: Pan et al. Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net. ECCV, 2018.
model = OSNet(
num_classes,
blocks=[OSBlock, OSBlock, OSBlock],
layers=[2, 2, 2],
channels=[64, 256, 384, 512],
loss=loss,
IN=True,
**kwargs
)
if pretrained:
init_pretrained_weights(model, key='osnet_ibn_x1_0')
return model
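
# Illustrative sketch (assumed example values): pretrained=False skips the gdown
# download of ImageNet weights; with loss='triplet' the training-mode forward
# returns (logits, 512-dim embedding).
if __name__ == '__main__':
    import torch
    model = osnet_x1_0(num_classes=751, pretrained=False, loss='triplet')
    model.train()
    logits, feats = model(torch.randn(2, 3, 256, 128))
    print(logits.shape, feats.shape)  # (2, 751) and (2, 512)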

View file

@@ -0,0 +1,609 @@
from __future__ import division, absolute_import
import warnings
import torch
from torch import nn
from torch.nn import functional as F
__all__ = [
'osnet_ain_x1_0', 'osnet_ain_x0_75', 'osnet_ain_x0_5', 'osnet_ain_x0_25'
]
pretrained_urls = {
'osnet_ain_x1_0':
'https://drive.google.com/uc?id=1-CaioD9NaqbHK_kzSMW8VE4_3KcsRjEo',
'osnet_ain_x0_75':
'https://drive.google.com/uc?id=1apy0hpsMypqstfencdH-jKIUEFOW4xoM',
'osnet_ain_x0_5':
'https://drive.google.com/uc?id=1KusKvEYyKGDTUBVRxRiz55G31wkihB6l',
'osnet_ain_x0_25':
'https://drive.google.com/uc?id=1SxQt2AvmEcgWNhaRb2xC4rP6ZwVDP0Wt'
}
##########
# Basic layers
##########
class ConvLayer(nn.Module):
"""Convolution layer (conv + bn + relu)."""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
groups=1,
IN=False
):
super(ConvLayer, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=False,
groups=groups
)
if IN:
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
else:
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return self.relu(x)
class Conv1x1(nn.Module):
"""1x1 convolution + bn + relu."""
def __init__(self, in_channels, out_channels, stride=1, groups=1):
super(Conv1x1, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
1,
stride=stride,
padding=0,
bias=False,
groups=groups
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return self.relu(x)
class Conv1x1Linear(nn.Module):
"""1x1 convolution + bn (w/o non-linearity)."""
def __init__(self, in_channels, out_channels, stride=1, bn=True):
super(Conv1x1Linear, self).__init__()
self.conv = nn.Conv2d(
in_channels, out_channels, 1, stride=stride, padding=0, bias=False
)
self.bn = None
if bn:
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
return x
class Conv3x3(nn.Module):
"""3x3 convolution + bn + relu."""
def __init__(self, in_channels, out_channels, stride=1, groups=1):
super(Conv3x3, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
3,
stride=stride,
padding=1,
bias=False,
groups=groups
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return self.relu(x)
class LightConv3x3(nn.Module):
"""Lightweight 3x3 convolution.
1x1 (linear) + dw 3x3 (nonlinear).
"""
def __init__(self, in_channels, out_channels):
super(LightConv3x3, self).__init__()
self.conv1 = nn.Conv2d(
in_channels, out_channels, 1, stride=1, padding=0, bias=False
)
self.conv2 = nn.Conv2d(
out_channels,
out_channels,
3,
stride=1,
padding=1,
bias=False,
groups=out_channels
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.bn(x)
return self.relu(x)
class LightConvStream(nn.Module):
"""Lightweight convolution stream."""
def __init__(self, in_channels, out_channels, depth):
super(LightConvStream, self).__init__()
assert depth >= 1, 'depth must be equal to or larger than 1, but got {}'.format(
depth
)
layers = []
layers += [LightConv3x3(in_channels, out_channels)]
for i in range(depth - 1):
layers += [LightConv3x3(out_channels, out_channels)]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
##########
# Building blocks for omni-scale feature learning
##########
class ChannelGate(nn.Module):
"""A mini-network that generates channel-wise gates conditioned on input tensor."""
def __init__(
self,
in_channels,
num_gates=None,
return_gates=False,
gate_activation='sigmoid',
reduction=16,
layer_norm=False
):
super(ChannelGate, self).__init__()
if num_gates is None:
num_gates = in_channels
self.return_gates = return_gates
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(
in_channels,
in_channels // reduction,
kernel_size=1,
bias=True,
padding=0
)
self.norm1 = None
if layer_norm:
self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
self.relu = nn.ReLU()
self.fc2 = nn.Conv2d(
in_channels // reduction,
num_gates,
kernel_size=1,
bias=True,
padding=0
)
if gate_activation == 'sigmoid':
self.gate_activation = nn.Sigmoid()
elif gate_activation == 'relu':
self.gate_activation = nn.ReLU()
elif gate_activation == 'linear':
self.gate_activation = None
else:
raise RuntimeError(
"Unknown gate activation: {}".format(gate_activation)
)
def forward(self, x):
input = x
x = self.global_avgpool(x)
x = self.fc1(x)
if self.norm1 is not None:
x = self.norm1(x)
x = self.relu(x)
x = self.fc2(x)
if self.gate_activation is not None:
x = self.gate_activation(x)
if self.return_gates:
return x
return input * x
class OSBlock(nn.Module):
"""Omni-scale feature learning block."""
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
super(OSBlock, self).__init__()
assert T >= 1
assert out_channels >= reduction and out_channels % reduction == 0
mid_channels = out_channels // reduction
self.conv1 = Conv1x1(in_channels, mid_channels)
self.conv2 = nn.ModuleList()
for t in range(1, T + 1):
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
self.gate = ChannelGate(mid_channels)
self.conv3 = Conv1x1Linear(mid_channels, out_channels)
self.downsample = None
if in_channels != out_channels:
self.downsample = Conv1x1Linear(in_channels, out_channels)
def forward(self, x):
identity = x
x1 = self.conv1(x)
x2 = 0
for conv2_t in self.conv2:
x2_t = conv2_t(x1)
x2 = x2 + self.gate(x2_t)
x3 = self.conv3(x2)
if self.downsample is not None:
identity = self.downsample(identity)
out = x3 + identity
return F.relu(out)
class OSBlockINin(nn.Module):
"""Omni-scale feature learning block with instance normalization."""
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
super(OSBlockINin, self).__init__()
assert T >= 1
assert out_channels >= reduction and out_channels % reduction == 0
mid_channels = out_channels // reduction
self.conv1 = Conv1x1(in_channels, mid_channels)
self.conv2 = nn.ModuleList()
for t in range(1, T + 1):
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
self.gate = ChannelGate(mid_channels)
self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False)
self.downsample = None
if in_channels != out_channels:
self.downsample = Conv1x1Linear(in_channels, out_channels)
self.IN = nn.InstanceNorm2d(out_channels, affine=True)
def forward(self, x):
identity = x
x1 = self.conv1(x)
x2 = 0
for conv2_t in self.conv2:
x2_t = conv2_t(x1)
x2 = x2 + self.gate(x2_t)
x3 = self.conv3(x2)
x3 = self.IN(x3) # IN inside residual
if self.downsample is not None:
identity = self.downsample(identity)
out = x3 + identity
return F.relu(out)
##########
# Network architecture
##########
class OSNet(nn.Module):
"""Omni-Scale Network.
Reference:
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
- Zhou et al. Learning Generalisable Omni-Scale Representations
for Person Re-Identification. TPAMI, 2021.
"""
def __init__(
self,
num_classes,
blocks,
layers,
channels,
feature_dim=512,
loss='softmax',
conv1_IN=False,
**kwargs
):
super(OSNet, self).__init__()
num_blocks = len(blocks)
assert num_blocks == len(layers)
assert num_blocks == len(channels) - 1
self.loss = loss
self.feature_dim = feature_dim
# convolutional backbone
self.conv1 = ConvLayer(
3, channels[0], 7, stride=2, padding=3, IN=conv1_IN
)
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
self.conv2 = self._make_layer(
blocks[0], layers[0], channels[0], channels[1]
)
self.pool2 = nn.Sequential(
Conv1x1(channels[1], channels[1]), nn.AvgPool2d(2, stride=2)
)
self.conv3 = self._make_layer(
blocks[1], layers[1], channels[1], channels[2]
)
self.pool3 = nn.Sequential(
Conv1x1(channels[2], channels[2]), nn.AvgPool2d(2, stride=2)
)
self.conv4 = self._make_layer(
blocks[2], layers[2], channels[2], channels[3]
)
self.conv5 = Conv1x1(channels[3], channels[3])
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
# fully connected layer
self.fc = self._construct_fc_layer(
self.feature_dim, channels[3], dropout_p=None
)
# identity classification layer
self.classifier = nn.Linear(self.feature_dim, num_classes)
self._init_params()
def _make_layer(self, blocks, layer, in_channels, out_channels):
layers = []
layers += [blocks[0](in_channels, out_channels)]
for i in range(1, len(blocks)):
layers += [blocks[i](out_channels, out_channels)]
return nn.Sequential(*layers)
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
if fc_dims is None or fc_dims < 0:
self.feature_dim = input_dim
return None
if isinstance(fc_dims, int):
fc_dims = [fc_dims]
layers = []
for dim in fc_dims:
layers.append(nn.Linear(input_dim, dim))
layers.append(nn.BatchNorm1d(dim))
layers.append(nn.ReLU())
if dropout_p is not None:
layers.append(nn.Dropout(p=dropout_p))
input_dim = dim
self.feature_dim = fc_dims[-1]
return nn.Sequential(*layers)
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu'
)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.InstanceNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def featuremaps(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.conv2(x)
x = self.pool2(x)
x = self.conv3(x)
x = self.pool3(x)
x = self.conv4(x)
x = self.conv5(x)
return x
def forward(self, x, return_featuremaps=False):
x = self.featuremaps(x)
if return_featuremaps:
return x
v = self.global_avgpool(x)
v = v.view(v.size(0), -1)
if self.fc is not None:
v = self.fc(v)
if not self.training:
return v
y = self.classifier(v)
if self.loss == 'softmax':
return y
elif self.loss == 'triplet':
return y, v
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
def init_pretrained_weights(model, key=''):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
import os
import errno
import gdown
from collections import OrderedDict
def _get_torch_home():
ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
torch_home = os.path.expanduser(
os.getenv(
ENV_TORCH_HOME,
os.path.join(
os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
)
)
)
return torch_home
torch_home = _get_torch_home()
model_dir = os.path.join(torch_home, 'checkpoints')
try:
os.makedirs(model_dir)
except OSError as e:
if e.errno == errno.EEXIST:
# Directory already exists, ignore.
pass
else:
# Unexpected OSError, re-raise.
raise
filename = key + '_imagenet.pth'
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
gdown.download(pretrained_urls[key], cached_file, quiet=False)
state_dict = torch.load(cached_file)
model_dict = model.state_dict()
new_state_dict = OrderedDict()
matched_layers, discarded_layers = [], []
for k, v in state_dict.items():
if k.startswith('module.'):
k = k[7:] # discard module.
if k in model_dict and model_dict[k].size() == v.size():
new_state_dict[k] = v
matched_layers.append(k)
else:
discarded_layers.append(k)
model_dict.update(new_state_dict)
model.load_state_dict(model_dict)
if len(matched_layers) == 0:
warnings.warn(
'The pretrained weights from "{}" cannot be loaded, '
'please check the key names manually '
'(** ignored and continue **)'.format(cached_file)
)
else:
print(
'Successfully loaded imagenet pretrained weights from "{}"'.
format(cached_file)
)
if len(discarded_layers) > 0:
print(
'** The following layers are discarded '
'due to unmatched keys or layer size: {}'.
format(discarded_layers)
)
##########
# Instantiation
##########
def osnet_ain_x1_0(
num_classes=1000, pretrained=True, loss='softmax', **kwargs
):
model = OSNet(
num_classes,
blocks=[
[OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin],
[OSBlockINin, OSBlock]
],
layers=[2, 2, 2],
channels=[64, 256, 384, 512],
loss=loss,
conv1_IN=True,
**kwargs
)
if pretrained:
init_pretrained_weights(model, key='osnet_ain_x1_0')
return model
def osnet_ain_x0_75(
num_classes=1000, pretrained=True, loss='softmax', **kwargs
):
model = OSNet(
num_classes,
blocks=[
[OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin],
[OSBlockINin, OSBlock]
],
layers=[2, 2, 2],
channels=[48, 192, 288, 384],
loss=loss,
conv1_IN=True,
**kwargs
)
if pretrained:
init_pretrained_weights(model, key='osnet_ain_x0_75')
return model
def osnet_ain_x0_5(
num_classes=1000, pretrained=True, loss='softmax', **kwargs
):
model = OSNet(
num_classes,
blocks=[
[OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin],
[OSBlockINin, OSBlock]
],
layers=[2, 2, 2],
channels=[32, 128, 192, 256],
loss=loss,
conv1_IN=True,
**kwargs
)
if pretrained:
init_pretrained_weights(model, key='osnet_ain_x0_5')
return model
def osnet_ain_x0_25(
num_classes=1000, pretrained=True, loss='softmax', **kwargs
):
model = OSNet(
num_classes,
blocks=[
[OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin],
[OSBlockINin, OSBlock]
],
layers=[2, 2, 2],
channels=[16, 64, 96, 128],
loss=loss,
conv1_IN=True,
**kwargs
)
if pretrained:
init_pretrained_weights(model, key='osnet_ain_x0_25')
return model

View file

@@ -0,0 +1,314 @@
from __future__ import division, absolute_import
import torch.utils.model_zoo as model_zoo
from torch import nn
from torch.nn import functional as F
__all__ = ['pcb_p6', 'pcb_p4']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, planes * self.expansion, kernel_size=1, bias=False
)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class DimReduceLayer(nn.Module):
def __init__(self, in_channels, out_channels, nonlinear):
super(DimReduceLayer, self).__init__()
layers = []
layers.append(
nn.Conv2d(
in_channels, out_channels, 1, stride=1, padding=0, bias=False
)
)
layers.append(nn.BatchNorm2d(out_channels))
if nonlinear == 'relu':
layers.append(nn.ReLU(inplace=True))
elif nonlinear == 'leakyrelu':
layers.append(nn.LeakyReLU(0.1))
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class PCB(nn.Module):
"""Part-based Convolutional Baseline.
Reference:
Sun et al. Beyond Part Models: Person Retrieval with Refined
Part Pooling (and A Strong Convolutional Baseline). ECCV 2018.
Public keys:
- ``pcb_p4``: PCB with 4-part strips.
- ``pcb_p6``: PCB with 6-part strips.
"""
def __init__(
self,
num_classes,
loss,
block,
layers,
parts=6,
reduced_dim=256,
nonlinear='relu',
**kwargs
):
self.inplanes = 64
super(PCB, self).__init__()
self.loss = loss
self.parts = parts
self.feature_dim = 512 * block.expansion
# backbone network
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=1)
# pcb layers
self.parts_avgpool = nn.AdaptiveAvgPool2d((self.parts, 1))
self.dropout = nn.Dropout(p=0.5)
self.conv5 = DimReduceLayer(
512 * block.expansion, reduced_dim, nonlinear=nonlinear
)
self.feature_dim = reduced_dim
self.classifier = nn.ModuleList(
[
nn.Linear(self.feature_dim, num_classes)
for _ in range(self.parts)
]
)
self._init_params()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False
),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu'
)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def featuremaps(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
def forward(self, x):
f = self.featuremaps(x)
v_g = self.parts_avgpool(f)
if not self.training:
v_g = F.normalize(v_g, p=2, dim=1)
return v_g.view(v_g.size(0), -1)
v_g = self.dropout(v_g)
v_h = self.conv5(v_g)
y = []
for i in range(self.parts):
v_h_i = v_h[:, :, i, :]
v_h_i = v_h_i.view(v_h_i.size(0), -1)
y_i = self.classifier[i](v_h_i)
y.append(y_i)
if self.loss == 'softmax':
return y
elif self.loss == 'triplet':
v_g = F.normalize(v_g, p=2, dim=1)
return y, v_g.view(v_g.size(0), -1)
else:
raise KeyError('Unsupported loss: {}'.format(self.loss))
def init_pretrained_weights(model, model_url):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {
k: v
for k, v in pretrain_dict.items()
if k in model_dict and model_dict[k].size() == v.size()
}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
def pcb_p6(num_classes, loss='softmax', pretrained=True, **kwargs):
model = PCB(
num_classes=num_classes,
loss=loss,
block=Bottleneck,
layers=[3, 4, 6, 3],
last_stride=1,
parts=6,
reduced_dim=256,
nonlinear='relu',
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['resnet50'])
return model
def pcb_p4(num_classes, loss='softmax', pretrained=True, **kwargs):
model = PCB(
num_classes=num_classes,
loss=loss,
block=Bottleneck,
layers=[3, 4, 6, 3],
last_stride=1,
parts=4,
reduced_dim=256,
nonlinear='relu',
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['resnet50'])
return model
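A minimal usage sketch for the PCB model above. The module path `pcb` and the class count of 751 are assumptions; 384x128 is the crop resolution PCB is commonly trained with:

import torch
from pcb import pcb_p6  # hypothetical import; adjust to the actual package path

model = pcb_p6(num_classes=751, loss='softmax', pretrained=False)

# Training mode: one softmax head per horizontal part strip.
model.train()
x = torch.randn(4, 3, 384, 128)
logits_per_part = model(x)        # list of 6 tensors, each of shape (4, 751)

# Eval mode: L2-normalised part-pooled features, flattened for retrieval.
model.eval()
with torch.no_grad():
    feats = model(x)              # shape (4, 2048 * 6) with the ResNet50 Bottleneck backbone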

View file

@@ -0,0 +1,530 @@
"""
Code source: https://github.com/pytorch/vision
"""
from __future__ import division, absolute_import
import torch.utils.model_zoo as model_zoo
from torch import nn
__all__ = [
'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
'resnext50_32x4d', 'resnext101_32x8d', 'resnet50_fc512'
]
model_urls = {
'resnet18':
'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34':
'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50':
'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101':
'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152':
'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
'resnext50_32x4d':
'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d':
'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=False,
dilation=dilation
)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(
in_planes, out_planes, kernel_size=1, stride=stride, bias=False
)
class BasicBlock(nn.Module):
expansion = 1
def __init__(
self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
norm_layer=None
):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError(
'BasicBlock only supports groups=1 and base_width=64'
)
if dilation > 1:
raise NotImplementedError(
"Dilation > 1 not supported in BasicBlock"
)
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(
self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
norm_layer=None
):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width/64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
"""Residual network.
Reference:
- He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
- Xie et al. Aggregated Residual Transformations for Deep Neural Networks. CVPR 2017.
Public keys:
- ``resnet18``: ResNet18.
- ``resnet34``: ResNet34.
- ``resnet50``: ResNet50.
- ``resnet101``: ResNet101.
- ``resnet152``: ResNet152.
- ``resnext50_32x4d``: ResNeXt50.
- ``resnext101_32x8d``: ResNeXt101.
- ``resnet50_fc512``: ResNet50 + FC.
"""
def __init__(
self,
num_classes,
loss,
block,
layers,
zero_init_residual=False,
groups=1,
width_per_group=64,
replace_stride_with_dilation=None,
norm_layer=None,
last_stride=2,
fc_dims=None,
dropout_p=None,
**kwargs
):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.loss = loss
self.feature_dim = 512 * block.expansion
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
"replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".
format(replace_stride_with_dilation)
)
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(
3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(
block,
128,
layers[1],
stride=2,
dilate=replace_stride_with_dilation[0]
)
self.layer3 = self._make_layer(
block,
256,
layers[2],
stride=2,
dilate=replace_stride_with_dilation[1]
)
self.layer4 = self._make_layer(
block,
512,
layers[3],
stride=last_stride,
dilate=replace_stride_with_dilation[2]
)
self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = self._construct_fc_layer(
fc_dims, 512 * block.expansion, dropout_p
)
self.classifier = nn.Linear(self.feature_dim, num_classes)
self._init_params()
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(
block(
self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer
)
)
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm_layer=norm_layer
)
)
return nn.Sequential(*layers)
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
"""Constructs fully connected layer
Args:
fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
input_dim (int): input dimension
dropout_p (float): dropout probability, if None, dropout is unused
"""
if fc_dims is None:
self.feature_dim = input_dim
return None
assert isinstance(
fc_dims, (list, tuple)
), 'fc_dims must be either list or tuple, but got {}'.format(
type(fc_dims)
)
layers = []
for dim in fc_dims:
layers.append(nn.Linear(input_dim, dim))
layers.append(nn.BatchNorm1d(dim))
layers.append(nn.ReLU(inplace=True))
if dropout_p is not None:
layers.append(nn.Dropout(p=dropout_p))
input_dim = dim
self.feature_dim = fc_dims[-1]
return nn.Sequential(*layers)
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu'
)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def featuremaps(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
def forward(self, x):
f = self.featuremaps(x)
v = self.global_avgpool(f)
v = v.view(v.size(0), -1)
if self.fc is not None:
v = self.fc(v)
if not self.training:
return v
y = self.classifier(v)
if self.loss == 'softmax':
return y
elif self.loss == 'triplet':
return y, v
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
def init_pretrained_weights(model, model_url):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {
k: v
for k, v in pretrain_dict.items()
if k in model_dict and model_dict[k].size() == v.size()
}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
"""ResNet"""
def resnet18(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ResNet(
num_classes=num_classes,
loss=loss,
block=BasicBlock,
layers=[2, 2, 2, 2],
last_stride=2,
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['resnet18'])
return model
def resnet34(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ResNet(
num_classes=num_classes,
loss=loss,
block=BasicBlock,
layers=[3, 4, 6, 3],
last_stride=2,
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['resnet34'])
return model
def resnet50(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ResNet(
num_classes=num_classes,
loss=loss,
block=Bottleneck,
layers=[3, 4, 6, 3],
last_stride=2,
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['resnet50'])
return model
def resnet101(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ResNet(
num_classes=num_classes,
loss=loss,
block=Bottleneck,
layers=[3, 4, 23, 3],
last_stride=2,
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['resnet101'])
return model
def resnet152(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ResNet(
num_classes=num_classes,
loss=loss,
block=Bottleneck,
layers=[3, 8, 36, 3],
last_stride=2,
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['resnet152'])
return model
"""ResNeXt"""
def resnext50_32x4d(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ResNet(
num_classes=num_classes,
loss=loss,
block=Bottleneck,
layers=[3, 4, 6, 3],
last_stride=2,
fc_dims=None,
dropout_p=None,
groups=32,
width_per_group=4,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['resnext50_32x4d'])
return model
def resnext101_32x8d(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ResNet(
num_classes=num_classes,
loss=loss,
block=Bottleneck,
layers=[3, 4, 23, 3],
last_stride=2,
fc_dims=None,
dropout_p=None,
groups=32,
width_per_group=8,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['resnext101_32x8d'])
return model
"""
ResNet + FC
"""
def resnet50_fc512(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ResNet(
num_classes=num_classes,
loss=loss,
block=Bottleneck,
layers=[3, 4, 6, 3],
last_stride=1,
fc_dims=[512],
dropout_p=None,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['resnet50'])
return model
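A minimal sketch showing the two forward-pass conventions of the ResNet wrapper above, using resnet50_fc512; the module path and the class count of 751 are assumptions:

import torch
from resnet import resnet50_fc512  # hypothetical import; adjust to the actual package path

model = resnet50_fc512(num_classes=751, loss='triplet', pretrained=False)

# Eval mode skips the classifier and returns the 512-d embedding from fc_dims=[512].
model.eval()
with torch.no_grad():
    emb = model(torch.randn(2, 3, 256, 128))   # torch.Size([2, 512])

# Training mode with loss='triplet' returns (logits, embeddings).
model.train()
logits, feats = model(torch.randn(2, 3, 256, 128))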

View file

@@ -0,0 +1,289 @@
"""
Credit to https://github.com/XingangPan/IBN-Net.
"""
from __future__ import division, absolute_import
import math
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
__all__ = ['resnet50_ibn_a']
model_urls = {
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class IBN(nn.Module):
def __init__(self, planes):
super(IBN, self).__init__()
half1 = int(planes / 2)
self.half = half1
half2 = planes - half1
self.IN = nn.InstanceNorm2d(half1, affine=True)
self.BN = nn.BatchNorm2d(half2)
def forward(self, x):
split = torch.split(x, self.half, 1)
out1 = self.IN(split[0].contiguous())
out2 = self.BN(split[1].contiguous())
out = torch.cat((out1, out2), 1)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
if ibn:
self.bn1 = IBN(planes)
else:
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, planes * self.expansion, kernel_size=1, bias=False
)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
"""Residual network + IBN layer.
Reference:
- He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
- Pan et al. Two at Once: Enhancing Learning and Generalization
Capacities via IBN-Net. ECCV 2018.
"""
def __init__(
self,
block,
layers,
num_classes=1000,
loss='softmax',
fc_dims=None,
dropout_p=None,
**kwargs
):
scale = 64
self.inplanes = scale
super(ResNet, self).__init__()
self.loss = loss
self.feature_dim = scale * 8 * block.expansion
self.conv1 = nn.Conv2d(
3, scale, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn1 = nn.BatchNorm2d(scale)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, scale, layers[0])
self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2)
self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2)
self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = self._construct_fc_layer(
fc_dims, scale * 8 * block.expansion, dropout_p
)
self.classifier = nn.Linear(self.feature_dim, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.InstanceNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False
),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
ibn = True
if planes == 512:
ibn = False
layers.append(block(self.inplanes, planes, ibn, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, ibn))
return nn.Sequential(*layers)
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
"""Constructs fully connected layer
Args:
fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
input_dim (int): input dimension
dropout_p (float): dropout probability, if None, dropout is unused
"""
if fc_dims is None:
self.feature_dim = input_dim
return None
assert isinstance(
fc_dims, (list, tuple)
), 'fc_dims must be either list or tuple, but got {}'.format(
type(fc_dims)
)
layers = []
for dim in fc_dims:
layers.append(nn.Linear(input_dim, dim))
layers.append(nn.BatchNorm1d(dim))
layers.append(nn.ReLU(inplace=True))
if dropout_p is not None:
layers.append(nn.Dropout(p=dropout_p))
input_dim = dim
self.feature_dim = fc_dims[-1]
return nn.Sequential(*layers)
def featuremaps(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
def forward(self, x):
f = self.featuremaps(x)
v = self.avgpool(f)
v = v.view(v.size(0), -1)
if self.fc is not None:
v = self.fc(v)
if not self.training:
return v
y = self.classifier(v)
if self.loss == 'softmax':
return y
elif self.loss == 'triplet':
return y, v
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
def init_pretrained_weights(model, model_url):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {
k: v
for k, v in pretrain_dict.items()
if k in model_dict and model_dict[k].size() == v.size()
}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
def resnet50_ibn_a(num_classes, loss='softmax', pretrained=False, **kwargs):
model = ResNet(
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, loss=loss, **kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['resnet50'])
return model
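The IBN-a variant above replaces half of the BatchNorm channels in the early bottlenecks with InstanceNorm, which is the property cross-domain re-ID pipelines usually rely on. A minimal sketch follows; the module path and the class count of 751 are assumptions, and pretrained defaults to False here, so no weights are downloaded:

import torch
from resnet_ibn_a import resnet50_ibn_a  # hypothetical import; adjust to the actual package path

model = resnet50_ibn_a(num_classes=751, loss='softmax')
model.eval()
with torch.no_grad():
    v = model(torch.randn(1, 3, 256, 128))
print(v.shape)   # torch.Size([1, 2048]), i.e. 512 * Bottleneck.expansion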

View file

@@ -0,0 +1,274 @@
"""
Credit to https://github.com/XingangPan/IBN-Net.
"""
from __future__ import division, absolute_import
import math
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
__all__ = ['resnet50_ibn_b']
model_urls = {
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, IN=False):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, planes * self.expansion, kernel_size=1, bias=False
)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.IN = None
if IN:
self.IN = nn.InstanceNorm2d(planes * 4, affine=True)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
if self.IN is not None:
out = self.IN(out)
out = self.relu(out)
return out
class ResNet(nn.Module):
"""Residual network + IBN layer.
Reference:
- He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
- Pan et al. Two at Once: Enhancing Learning and Generalization
Capacities via IBN-Net. ECCV 2018.
"""
def __init__(
self,
block,
layers,
num_classes=1000,
loss='softmax',
fc_dims=None,
dropout_p=None,
**kwargs
):
scale = 64
self.inplanes = scale
super(ResNet, self).__init__()
self.loss = loss
self.feature_dim = scale * 8 * block.expansion
self.conv1 = nn.Conv2d(
3, scale, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn1 = nn.InstanceNorm2d(scale, affine=True)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(
block, scale, layers[0], stride=1, IN=True
)
self.layer2 = self._make_layer(
block, scale * 2, layers[1], stride=2, IN=True
)
self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2)
self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = self._construct_fc_layer(
fc_dims, scale * 8 * block.expansion, dropout_p
)
self.classifier = nn.Linear(self.feature_dim, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.InstanceNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1, IN=False):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False
),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks - 1):
layers.append(block(self.inplanes, planes))
layers.append(block(self.inplanes, planes, IN=IN))
return nn.Sequential(*layers)
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
"""Constructs fully connected layer
Args:
fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
input_dim (int): input dimension
dropout_p (float): dropout probability, if None, dropout is unused
"""
if fc_dims is None:
self.feature_dim = input_dim
return None
assert isinstance(
fc_dims, (list, tuple)
), 'fc_dims must be either list or tuple, but got {}'.format(
type(fc_dims)
)
layers = []
for dim in fc_dims:
layers.append(nn.Linear(input_dim, dim))
layers.append(nn.BatchNorm1d(dim))
layers.append(nn.ReLU(inplace=True))
if dropout_p is not None:
layers.append(nn.Dropout(p=dropout_p))
input_dim = dim
self.feature_dim = fc_dims[-1]
return nn.Sequential(*layers)
def featuremaps(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
def forward(self, x):
f = self.featuremaps(x)
v = self.avgpool(f)
v = v.view(v.size(0), -1)
if self.fc is not None:
v = self.fc(v)
if not self.training:
return v
y = self.classifier(v)
if self.loss == 'softmax':
return y
elif self.loss == 'triplet':
return y, v
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
def init_pretrained_weights(model, model_url):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {
k: v
for k, v in pretrain_dict.items()
if k in model_dict and model_dict[k].size() == v.size()
}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
def resnet50_ibn_b(num_classes, loss='softmax', pretrained=False, **kwargs):
model = ResNet(
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, loss=loss, **kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['resnet50'])
return model
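The IBN-b variant instead uses InstanceNorm on the stem and after the residual addition at the end of the first two stages. A minimal sketch of its triplet-loss forward signature; the module path and the class count of 751 are assumptions:

import torch
from resnet_ibn_b import resnet50_ibn_b  # hypothetical import; adjust to the actual package path

model = resnet50_ibn_b(num_classes=751, loss='triplet')
model.train()
y, v = model(torch.randn(2, 3, 256, 128))
print(y.shape, v.shape)   # torch.Size([2, 751]) torch.Size([2, 2048])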

Some files were not shown because too many files have changed in this diff.