diff --git a/.claude/settings.local.json b/.claude/settings.local.json deleted file mode 100644 index 9e296ac..0000000 --- a/.claude/settings.local.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "permissions": { - "allow": [ - "Bash(dir:*)", - "WebSearch", - "Bash(mkdir:*)" - ], - "deny": [], - "ask": [] - } -} \ No newline at end of file diff --git a/.gitea/workflows/build.yml b/.gitea/workflows/build.yml index dc4f18d..585009f 100644 --- a/.gitea/workflows/build.yml +++ b/.gitea/workflows/build.yml @@ -51,7 +51,7 @@ jobs: registry: git.siwatsystem.com username: ${{ github.actor }} password: ${{ secrets.RUNNER_TOKEN }} - + - name: Build and push base Docker image uses: docker/build-push-action@v4 with: @@ -79,7 +79,7 @@ jobs: registry: git.siwatsystem.com username: ${{ github.actor }} password: ${{ secrets.RUNNER_TOKEN }} - + - name: Build and push Docker image uses: docker/build-push-action@v4 with: @@ -103,4 +103,10 @@ jobs: - name: Deploy stack run: | echo "Pulling and starting containers on server..." - ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.staging.yml -f docker-compose.production.yml pull && docker compose -f docker-compose.staging.yml -f docker-compose.production.yml up -d" + if [ "${{ github.ref_name }}" = "main" ]; then + echo "Deploying production stack..." + ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.production.yml pull && docker compose -f docker-compose.production.yml up -d" + else + echo "Deploying staging stack..." + ssh -i ~/.ssh/id_rsa ${{ vars.DEPLOY_USER_CMS }}@${{ vars.DEPLOY_HOST_CMS }} "cd ~/cms-system-k8s && docker compose -f docker-compose.staging.yml pull && docker compose -f docker-compose.staging.yml up -d" + fi \ No newline at end of file diff --git a/.gitignore b/.gitignore index 2da89cb..b36f421 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,11 @@ -/models +# Do not know how to use +archive/ +Dockerfile + +# /models app.log *.pt - -images +.venv/ # All pycache directories __pycache__/ @@ -12,3 +15,7 @@ mptas detector_worker.log .gitignore no_frame_debug.log + + +# Result from tracker +feeder/runs/ \ No newline at end of file diff --git a/Dockerfile.base b/Dockerfile.base index 9684325..3700920 100644 --- a/Dockerfile.base +++ b/Dockerfile.base @@ -1,130 +1,15 @@ -# Base image with complete ML and hardware acceleration stack -FROM pytorch/pytorch:2.8.0-cuda12.6-cudnn9-runtime +# Base image with all ML dependencies +FROM python:3.13-bookworm -# Install build dependencies and system libraries -RUN apt-get update && apt-get install -y \ - # Build tools - build-essential \ - cmake \ - git \ - pkg-config \ - wget \ - unzip \ - yasm \ - nasm \ - # Additional dependencies for FFmpeg/NVIDIA build - libtool \ - libc6 \ - libc6-dev \ - libnuma1 \ - libnuma-dev \ - # Essential compilation libraries - gcc \ - g++ \ - libc6-dev \ - linux-libc-dev \ - # System libraries - libgl1-mesa-glx \ - libglib2.0-0 \ - libgomp1 \ - # Core media libraries (essential ones only) - libjpeg-dev \ - libpng-dev \ - libx264-dev \ - libx265-dev \ - libvpx-dev \ - libmp3lame-dev \ - libv4l-dev \ - # TurboJPEG for fast JPEG encoding - libturbojpeg0-dev \ - # Python development - python3-dev \ - python3-numpy \ - && rm -rf /var/lib/apt/lists/* +# Install system dependencies +RUN apt update && apt install -y libgl1 && rm -rf /var/lib/apt/lists/* -# Add NVIDIA CUDA repository and install minimal development tools 
-RUN apt-get update && apt-get install -y wget gnupg && \ - wget -O - https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub | apt-key add - && \ - echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \ - apt-get update && \ - apt-get install -y \ - cuda-nvcc-12-6 \ - cuda-cudart-dev-12-6 \ - libnpp-dev-12-6 \ - && apt-get remove -y wget gnupg && \ - apt-get autoremove -y && \ - rm -rf /var/lib/apt/lists/* - -# Ensure CUDA paths are available -ENV PATH="/usr/local/cuda/bin:${PATH}" -ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" - -# Install NVIDIA Video Codec SDK headers (official method) -RUN cd /tmp && \ - git clone https://git.videolan.org/git/ffmpeg/nv-codec-headers.git && \ - cd nv-codec-headers && \ - make install && \ - cd / && rm -rf /tmp/* - -# Build FFmpeg from source with NVIDIA CUDA support -RUN cd /tmp && \ - echo "Building FFmpeg with NVIDIA CUDA support..." && \ - # Download FFmpeg source (official method) - git clone https://git.ffmpeg.org/ffmpeg.git ffmpeg/ && \ - cd ffmpeg && \ - # Configure with NVIDIA support (simplified to avoid configure issues) - ./configure \ - --prefix=/usr/local \ - --enable-shared \ - --disable-static \ - --enable-nonfree \ - --enable-gpl \ - --enable-cuda-nvcc \ - --enable-cuvid \ - --enable-nvdec \ - --enable-nvenc \ - --enable-libnpp \ - --extra-cflags=-I/usr/local/cuda/include \ - --extra-ldflags=-L/usr/local/cuda/lib64 \ - --enable-libx264 \ - --enable-libx265 \ - --enable-libvpx \ - --enable-libmp3lame && \ - # Build and install - make -j$(nproc) && \ - make install && \ - ldconfig && \ - # Verify CUVID decoders are available - echo "=== Verifying FFmpeg CUVID Support ===" && \ - (ffmpeg -hide_banner -decoders 2>/dev/null | grep cuvid || echo "No CUVID decoders found") && \ - echo "=== Verifying FFmpeg NVENC Support ===" && \ - (ffmpeg -hide_banner -encoders 2>/dev/null | grep nvenc || echo "No NVENC encoders found") && \ - cd / && rm -rf /tmp/* - -# Set environment variables for maximum hardware acceleration -ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/lib:${LD_LIBRARY_PATH}" -ENV PKG_CONFIG_PATH="/usr/local/lib/pkgconfig:${PKG_CONFIG_PATH}" -ENV PYTHONPATH="/usr/local/lib/python3.10/dist-packages:${PYTHONPATH}" - -# Optimized environment variables for hardware acceleration -ENV OPENCV_FFMPEG_CAPTURE_OPTIONS="rtsp_transport;tcp|hwaccel;cuda|hwaccel_device;0|video_codec;h264_cuvid|hwaccel_output_format;cuda" -ENV OPENCV_FFMPEG_WRITER_OPTIONS="video_codec;h264_nvenc|preset;fast|tune;zerolatency|gpu;0" -ENV CUDA_VISIBLE_DEVICES=0 -ENV NVIDIA_VISIBLE_DEVICES=all -ENV NVIDIA_DRIVER_CAPABILITIES=compute,video,utility - -# Copy and install base requirements (exclude opencv-python since we built from source) +# Copy and install base requirements (ML dependencies that rarely change) COPY requirements.base.txt . 
-RUN grep -v opencv-python requirements.base.txt > requirements.tmp && \ - mv requirements.tmp requirements.base.txt && \ - pip install --no-cache-dir -r requirements.base.txt +RUN pip install --no-cache-dir -r requirements.base.txt # Set working directory WORKDIR /app -# Create images directory for bind mount -RUN mkdir -p /app/images && \ - chmod 755 /app/images - # This base image will be reused for all worker builds CMD ["python3", "-m", "fastapi", "run", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/REFACTOR_PLAN.md b/REFACTOR_PLAN.md deleted file mode 100644 index e940ffd..0000000 --- a/REFACTOR_PLAN.md +++ /dev/null @@ -1,545 +0,0 @@ -# Detector Worker Refactoring Plan - -## Project Overview - -Transform the current monolithic structure (~4000 lines across `app.py` and `siwatsystem/pympta.py`) into a modular, maintainable system with clear separation of concerns. The goal is to make the sophisticated computer vision pipeline easily understandable for other engineers while maintaining all existing functionality. - -## Current System Flow Understanding - -### Validated System Flow -1. **WebSocket Connection** → Backend connects and sends `setSubscriptionList` -2. **Model Management** → Download unique `.mpta` files to `models/` and extract -3. **Tracking Phase** → Continuous tracking with `front_rear_detection_v1.pt` -4. **Validation Phase** → Validate stable car (not just passing by) -5. **Pipeline Execution** → - - Detect car with `yolo11m.pt` - - **Branch 1**: Front/rear detection → crop frontal → save to Redis + brand classification - - **Branch 2**: Body type classification from car crop -6. **Communication** → Send `imageDetection` → Backend generates `sessionId` → Fueling starts -7. **Post-Fueling** → Backend clears `sessionId` → Continue tracking same car to avoid re-pipeline - -### Core Responsibilities Identified -1. **WebSocket Communication** - Message handling and protocol compliance -2. **Stream Management** - RTSP/HTTP frame processing and buffering -3. **Model Management** - MPTA download, extraction, and loading -4. **Pipeline Configuration** - Parse `pipeline.json` and setup execution flow -5. **Vehicle Tracking** - Continuous tracking and car identification -6. **Validation Logic** - Stable car detection vs. passing-by cars -7. **Detection Pipeline** - Main ML pipeline with parallel branches -8. **Data Persistence** - Redis/PostgreSQL operations -9. 
**Session Management** - Handle session IDs and lifecycle - -## Proposed Directory Structure - -``` -core/ -├── communication/ -│ ├── __init__.py -│ ├── websocket.py # WebSocket message handling & protocol -│ ├── messages.py # Message types and validation -│ ├── models.py # Message data structures -│ └── state.py # Worker state management -├── streaming/ -│ ├── __init__.py -│ ├── manager.py # Stream coordination and lifecycle -│ ├── readers.py # RTSP/HTTP frame readers -│ └── buffers.py # Frame buffering and caching -├── models/ -│ ├── __init__.py -│ ├── manager.py # MPTA download and model loading -│ ├── pipeline.py # Pipeline.json parser and config -│ └── inference.py # YOLO model wrapper and optimization -├── tracking/ -│ ├── __init__.py -│ ├── tracker.py # Vehicle tracking with front_rear_detection_v1 -│ ├── validator.py # Stable car validation logic -│ └── integration.py # Tracking-pipeline integration -├── detection/ -│ ├── __init__.py -│ ├── pipeline.py # Main detection pipeline orchestration -│ └── branches.py # Parallel branch processing (brand/bodytype) -└── storage/ - ├── __init__.py - ├── redis.py # Redis operations and image storage - └── database.py # PostgreSQL operations (existing - will be moved) -``` - -## Implementation Strategy (Feature-by-Feature Testing) - -### Phase 1: Communication Layer -- WebSocket message handling (setSubscriptionList, sessionId management) -- HTTP API endpoints (camera image retrieval) -- Worker state reporting - -### Phase 2: Pipeline Configuration Reader -- Parse `pipeline.json` -- Model dependency resolution -- Branch configuration setup - -### Phase 3: Tracking System -- Continuous vehicle tracking -- Car identification and persistence - -### Phase 4: Tracking Validator -- Stable car detection logic -- Passing-by vs. fueling car differentiation - -### Phase 5: Model Pipeline Execution -- Main detection pipeline -- Parallel branch processing -- Redis/DB integration - -### Phase 6: Post-Session Tracking Validation -- Same car validation after sessionId cleared -- Prevent duplicate pipeline execution - -## Key Preservation Requirements -- **HTTP Endpoint**: `/camera/{camera_id}/image` must remain unchanged -- **WebSocket Protocol**: Full compliance with `worker.md` specification -- **MPTA Format**: Maintain compatibility with existing model archives -- **Database Schema**: Keep existing PostgreSQL structure -- **Redis Integration**: Preserve image storage and pub/sub functionality -- **Configuration**: Maintain `config.json` compatibility -- **Logging**: Preserve structured logging format - -## Expected Benefits -- **Maintainability**: Single responsibility modules (~200-400 lines each) -- **Testability**: Independent testing of each component -- **Readability**: Clear separation of concerns -- **Scalability**: Easy to extend and modify individual components -- **Documentation**: Self-documenting code structure - ---- - -# Comprehensive TODO List - -## ✅ Phase 1: Project Setup & Communication Layer - COMPLETED - -### 1.1 Project Structure Setup -- ✅ Create `core/` directory structure -- ✅ Create all module directories and `__init__.py` files -- ✅ Set up logging configuration for new modules -- ✅ Update imports in existing files to prepare for migration - -### 1.2 Communication Module (`core/communication/`) -- ✅ **Create `models.py`** - Message data structures - - ✅ Define WebSocket message models (SubscriptionList, StateReport, etc.) 
- - ✅ Add validation schemas for incoming messages - - ✅ Create response models for outgoing messages - -- ✅ **Create `messages.py`** - Message types and validation - - ✅ Implement message type constants - - ✅ Add message validation functions - - ✅ Create message builders for common responses - -- ✅ **Create `websocket.py`** - WebSocket message handling - - ✅ Extract WebSocket connection management from `app.py` - - ✅ Implement message routing and dispatching - - ✅ Add connection lifecycle management (connect, disconnect, reconnect) - - ✅ Handle `setSubscriptionList` message processing - - ✅ Handle `setSessionId` and `setProgressionStage` messages - - ✅ Handle `requestState` and `patchSessionResult` messages - -- ✅ **Create `state.py`** - Worker state management - - ✅ Extract state reporting logic from `app.py` - - ✅ Implement system metrics collection (CPU, memory, GPU) - - ✅ Manage active subscriptions state - - ✅ Handle session ID mapping and storage - -### 1.3 HTTP API Preservation -- ✅ **Preserve `/camera/{camera_id}/image` endpoint** - - ✅ Extract REST API logic from `app.py` - - ✅ Ensure frame caching mechanism works with new structure - - ✅ Maintain exact same response format and error handling - -### 1.4 Testing Phase 1 -- ✅ Test WebSocket connection and message handling -- ✅ Test HTTP API endpoint functionality -- ✅ Verify state reporting works correctly -- ✅ Test session management functionality - -### 1.5 Phase 1 Results -- ✅ **Modular Architecture**: Transformed ~900 lines into 4 focused modules (~200 lines each) -- ✅ **WebSocket Protocol**: Full compliance with worker.md specification -- ✅ **System Metrics**: Real-time CPU, memory, GPU monitoring -- ✅ **State Management**: Thread-safe subscription and session tracking -- ✅ **Backward Compatibility**: All existing endpoints preserved -- ✅ **Modern FastAPI**: Lifespan events, Pydantic v2 compatibility - -## ✅ Phase 2: Pipeline Configuration & Model Management - COMPLETED - -### 2.1 Models Module (`core/models/`) -- ✅ **Create `pipeline.py`** - Pipeline.json parser - - ✅ Extract pipeline configuration parsing from `pympta.py` - - ✅ Implement pipeline validation - - ✅ Add configuration schema validation - - ✅ Handle Redis and PostgreSQL configuration parsing - -- ✅ **Create `manager.py`** - MPTA download and model loading - - ✅ Extract MPTA download logic from `pympta.py` - - ✅ Implement ZIP extraction and validation - - ✅ Add model file management and caching - - ✅ Handle model loading with GPU optimization - - ✅ Implement model dependency resolution - -- ✅ **Create `inference.py`** - YOLO model wrapper - - ✅ Create unified YOLO model interface - - ✅ Add inference optimization and caching - - ✅ Implement batch processing capabilities - - ✅ Handle model switching and memory management - -### 2.2 Testing Phase 2 -- ✅ Test MPTA file download and extraction -- ✅ Test pipeline.json parsing and validation -- ✅ Test model loading with different configurations -- ✅ Verify GPU optimization works correctly - -### 2.3 Phase 2 Results -- ✅ **ModelManager**: Downloads, extracts, and manages MPTA files with model ID-based directory structure -- ✅ **PipelineParser**: Parses and validates pipeline.json with full support for Redis, PostgreSQL, tracking, and branches -- ✅ **YOLOWrapper**: Unified interface for YOLO models with caching, tracking, and classification support -- ✅ **Model Caching**: Shared model cache across instances to optimize memory usage -- ✅ **Dependency Resolution**: Automatically identifies and tracks all model file 
dependencies - -## ✅ Phase 3: Streaming System - COMPLETED - -### 3.1 Streaming Module (`core/streaming/`) -- ✅ **Create `readers.py`** - RTSP/HTTP frame readers - - ✅ Extract `frame_reader` function from `app.py` - - ✅ Extract `snapshot_reader` function from `app.py` - - ✅ Add connection management and retry logic - - ✅ Implement frame rate control and optimization - -- ✅ **Create `buffers.py`** - Frame buffering and caching - - ✅ Extract frame buffer management from `app.py` - - ✅ Implement efficient frame caching for REST API - - ✅ Add buffer size management and memory optimization - -- ✅ **Create `manager.py`** - Stream coordination - - ✅ Extract stream lifecycle management from `app.py` - - ✅ Implement shared stream optimization - - ✅ Add subscription reconciliation logic - - ✅ Handle stream sharing across multiple subscriptions - -### 3.2 Testing Phase 3 -- ✅ Test RTSP stream reading and buffering -- ✅ Test HTTP snapshot capture functionality -- ✅ Test shared stream optimization -- ✅ Verify frame caching for REST API access - -### 3.3 Phase 3 Results -- ✅ **RTSPReader**: OpenCV-based RTSP stream reader with automatic reconnection and frame callbacks -- ✅ **HTTPSnapshotReader**: Periodic HTTP snapshot capture with HTTPBasicAuth and HTTPDigestAuth support -- ✅ **FrameBuffer**: Thread-safe frame storage with automatic aging and cleanup -- ✅ **CacheBuffer**: Enhanced frame cache with cropping support and highest quality JPEG encoding (default quality=100) -- ✅ **StreamManager**: Complete stream lifecycle management with shared optimization and subscription reconciliation -- ✅ **Authentication Support**: Proper handling of credentials in URLs with automatic auth type detection -- ✅ **Real Camera Testing**: Verified with authenticated RTSP (1280x720) and HTTP snapshot (2688x1520) cameras -- ✅ **Production Ready**: Stable concurrent streaming from multiple camera sources -- ✅ **Dependencies**: Added opencv-python, numpy, and requests to requirements.txt - -### 3.4 Recent Streaming Enhancements (Post-Phase 3) -- ✅ **Format-Specific Optimization**: Tailored for 1280x720@6fps RTSP streams and 2560x1440 HTTP snapshots -- ✅ **H.264 Error Recovery**: Enhanced error handling for corrupted frames with automatic stream recovery -- ✅ **Frame Validation**: Implemented corruption detection using edge density analysis -- ✅ **Buffer Size Optimization**: Adjusted buffer limits to 3MB for RTSP frames (1280x720x3 bytes) -- ✅ **FFMPEG Integration**: Added environment variables to suppress verbose H.264 decoder errors -- ✅ **URL Preservation**: Maintained clean RTSP URLs without parameter injection -- ✅ **Type Detection**: Automatic stream type detection based on frame dimensions -- ✅ **Quality Settings**: Format-specific JPEG quality (90% for RTSP, 95% for HTTP) - -## ✅ Phase 4: Vehicle Tracking System - COMPLETED - -### 4.1 Tracking Module (`core/tracking/`) -- ✅ **Create `tracker.py`** - Vehicle tracking implementation (305 lines) - - ✅ Implement continuous tracking with configurable model (front_rear_detection_v1.pt) - - ✅ Add vehicle identification and persistence with TrackedVehicle dataclass - - ✅ Implement tracking state management with thread-safe operations - - ✅ Add bounding box tracking and motion analysis with position history - - ✅ Multi-class tracking support for complex detection scenarios - -- ✅ **Create `validator.py`** - Stable car validation (417 lines) - - ✅ Implement stable car detection algorithm with multiple validation criteria - - ✅ Add passing-by vs. 
fueling car differentiation using velocity and position analysis - - ✅ Implement validation thresholds and timing with configurable parameters - - ✅ Add confidence scoring for validation decisions with state history - - ✅ Advanced motion analysis with velocity smoothing and position variance - -- ✅ **Create `integration.py`** - Tracking-pipeline integration (547 lines) - - ✅ Connect tracking system with main pipeline through TrackingPipelineIntegration - - ✅ Handle tracking state transitions and session management - - ✅ Implement post-session tracking validation with cooldown periods - - ✅ Add same-car validation after sessionId cleared with 30-second cooldown - - ✅ Car abandonment detection with automatic timeout monitoring - - ✅ Mock detection system for backend communication - - ✅ Async pipeline execution with proper error handling - -### 4.2 Testing Phase 4 -- ✅ Test continuous vehicle tracking functionality -- ✅ Test stable car validation logic -- ✅ Test integration with existing pipeline -- ✅ Verify tracking performance and accuracy -- ✅ Test car abandonment detection with null detection messages -- ✅ Verify session management and progression stage handling - -### 4.3 Phase 4 Results -- ✅ **VehicleTracker**: Complete tracking implementation with YOLO tracking integration, position history, and stability calculations -- ✅ **StableCarValidator**: Sophisticated validation logic using velocity, position variance, and state consistency -- ✅ **TrackingPipelineIntegration**: Full integration with pipeline system including session management and async processing -- ✅ **StreamManager Integration**: Updated streaming manager to process tracking on every frame with proper threading -- ✅ **Thread-Safe Operations**: All tracking operations are thread-safe with proper locking mechanisms -- ✅ **Configurable Parameters**: All tracking parameters are configurable through pipeline.json -- ✅ **Session Management**: Complete session lifecycle management with post-fueling validation -- ✅ **Statistics and Monitoring**: Comprehensive statistics collection for tracking performance -- ✅ **Car Abandonment Detection**: Automatic detection when cars leave without fueling, sends `detection: null` to backend -- ✅ **Message Protocol**: Fixed JSON serialization to include `detection: null` for abandonment notifications -- ✅ **Streaming Optimization**: Enhanced RTSP/HTTP readers for 1280x720@6fps RTSP and 2560x1440 HTTP snapshots -- ✅ **Error Recovery**: Improved H.264 error handling and corrupted frame detection - -## ✅ Phase 5: Detection Pipeline System - COMPLETED - -### 5.1 Detection Module (`core/detection/`) ✅ -- ✅ **Create `pipeline.py`** - Main detection orchestration (574 lines) - - ✅ Extracted main pipeline execution from `pympta.py` with full orchestration - - ✅ Implemented detection flow coordination with async execution - - ✅ Added pipeline state management with comprehensive statistics - - ✅ Handled pipeline result aggregation with branch synchronization - - ✅ Redis and database integration with error handling - - ✅ Immediate and parallel action execution with template resolution - -- ✅ **Create `branches.py`** - Parallel branch processing (442 lines) - - ✅ Extracted parallel branch execution from `pympta.py` - - ✅ Implemented ThreadPoolExecutor-based parallel processing - - ✅ Added branch synchronization and result collection - - ✅ Handled branch failure and retry logic with graceful degradation - - ✅ Support for nested branches and model caching - - ✅ Both detection and classification model support - 
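A minimal, hypothetical sketch of the ThreadPoolExecutor-based parallel branch execution that section 5.1 above describes for `core/detection/branches.py`; the names (`BranchResult`, `run_branches`) and the callable-per-branch interface are illustrative assumptions, not the module's actual API:

```python
# Illustrative sketch only: BranchResult and run_branches are hypothetical names,
# not the real core/detection/branches.py interface. Shows parallel branch
# execution with result collection and graceful degradation on branch failure.
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Any, Callable, Dict

@dataclass
class BranchResult:
    branch_id: str
    ok: bool
    data: Dict[str, Any] = field(default_factory=dict)
    error: str = ""

def run_branches(frame: Any,
                 branches: Dict[str, Callable[[Any], Dict[str, Any]]],
                 max_workers: int = 4) -> Dict[str, BranchResult]:
    """Run all branches in parallel on one frame and collect their results."""
    results: Dict[str, BranchResult] = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Submit every branch callable; remember which future belongs to which branch.
        futures = {pool.submit(fn, frame): bid for bid, fn in branches.items()}
        for future in as_completed(futures):
            bid = futures[future]
            try:
                results[bid] = BranchResult(bid, True, future.result())
            except Exception as exc:
                # Graceful degradation: a failing branch records its error
                # instead of aborting the whole pipeline.
                results[bid] = BranchResult(bid, False, error=str(exc))
    return results

if __name__ == "__main__":
    # Example: brand and body-type classification running side by side on one crop.
    demo_branches = {
        "car_brand_cls": lambda f: {"brand": "unknown", "confidence": 0.0},
        "car_bodytype_cls": lambda f: {"body_type": "unknown", "confidence": 0.0},
    }
    print(run_branches(frame=None, branches=demo_branches))
```

Treating each branch as an independent callable keeps the synchronization point small (the `as_completed` loop) and matches the "branch failure and retry logic with graceful degradation" behaviour noted in the checklist above.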
-### 5.2 Storage Module (`core/storage/`) ✅ -- ✅ **Create `redis.py`** - Redis operations (410 lines) - - ✅ Extracted Redis action execution from `pympta.py` - - ✅ Implemented async image storage with region cropping - - ✅ Added pub/sub messaging functionality with JSON support - - ✅ Handled Redis connection management and retry logic - - ✅ Added statistics tracking and health monitoring - - ✅ Support for various image formats (JPEG, PNG) with quality control - -- ✅ **Move `database.py`** - PostgreSQL operations (339 lines) - - ✅ Moved existing `archive/siwatsystem/database.py` to `core/storage/` - - ✅ Updated imports and integration points - - ✅ Ensured compatibility with new module structure - - ✅ Added session management and statistics methods - - ✅ Enhanced error handling and connection management - -### 5.3 Integration Updates ✅ -- ✅ **Updated `core/tracking/integration.py`** - - ✅ Added DetectionPipeline integration - - ✅ Replaced placeholder `_execute_pipeline` with real implementation - - ✅ Added detection pipeline initialization and cleanup - - ✅ Integrated with existing tracking system flow - - ✅ Maintained backward compatibility with test mode - -### 5.4 Testing Phase 5 ✅ -- ✅ Verified module imports work correctly -- ✅ All new modules follow established coding patterns -- ✅ Integration points properly connected -- ✅ Error handling and cleanup methods implemented -- ✅ Statistics and monitoring capabilities added - -### 5.5 Phase 5 Results ✅ -- ✅ **DetectionPipeline**: Complete detection orchestration with Redis/PostgreSQL integration, async execution, and comprehensive error handling -- ✅ **BranchProcessor**: Parallel branch execution with ThreadPoolExecutor, model caching, and nested branch support -- ✅ **RedisManager**: Async Redis operations with image storage, pub/sub messaging, and connection management -- ✅ **DatabaseManager**: Enhanced PostgreSQL operations with session management and statistics -- ✅ **Module Integration**: Seamless integration with existing tracking system while maintaining compatibility -- ✅ **Error Handling**: Comprehensive error handling and graceful degradation throughout all components -- ✅ **Performance**: Optimized parallel processing and caching for high-performance pipeline execution - -## ✅ Additional Implemented Features (Not in Original Plan) - -### License Plate Recognition Integration (`core/storage/license_plate.py`) ✅ -- ✅ **LicensePlateManager**: Subscribes to Redis channel `license_results` for external LPR service -- ✅ **Multi-format Support**: Handles various message formats from LPR service -- ✅ **Result Caching**: 5-minute TTL for license plate results -- ✅ **WebSocket Integration**: Sends combined `imageDetection` messages with license data -- ✅ **Asynchronous Processing**: Non-blocking Redis pub/sub listener - -### Advanced Session State Management (`core/communication/state.py`) ✅ -- ✅ **Session ID Mapping**: Per-display session identifier tracking -- ✅ **Progression Stage Tracking**: Workflow state per display (welcome, car_wait_staff, finished, cleared) -- ✅ **Thread-Safe Operations**: RLock-based synchronization for concurrent access -- ✅ **Comprehensive State Reporting**: Full system state for debugging - -### Car Abandonment Detection (`core/tracking/integration.py`) ✅ -- ✅ **Abandonment Monitoring**: Detects cars leaving without completing fueling -- ✅ **Timeout Configuration**: 3-second abandonment timeout -- ✅ **Null Detection Messages**: Sends `detection: null` to backend for abandoned cars -- ✅ **Automatic Cleanup**: 
Removes abandoned sessions from tracking - -### Enhanced Message Protocol (`core/communication/models.py`) ✅ -- ✅ **PatchSessionResult**: Session data patching support -- ✅ **SetProgressionStage**: Workflow stage management messages -- ✅ **Null Detection Handling**: Support for abandonment notifications -- ✅ **Complex Detection Structure**: Supports both classification and null states - -### Comprehensive Timeout and Cooldown Systems ✅ -- ✅ **Post-Session Cooldown**: 30-second cooldown after session clearing -- ✅ **Processing Cooldown**: 10-second cooldown for repeated processing -- ✅ **Abandonment Timeout**: 3-second timeout for car abandonment detection -- ✅ **Vehicle Expiration**: 2-second timeout for tracking cleanup -- ✅ **Stream Timeouts**: 30-second connection timeout management - -## 📋 Phase 6: Integration & Final Testing - -### 6.1 Main Application Refactoring -- [ ] **Refactor `app.py`** - - [ ] Remove extracted functionality - - [ ] Update to use new modular structure - - [ ] Maintain FastAPI application structure - - [ ] Update imports and dependencies - -- [ ] **Clean up `siwatsystem/pympta.py`** - - [ ] Remove extracted functionality - - [ ] Keep only necessary legacy compatibility code - - [ ] Update imports to use new modules - -### 6.2 Post-Session Tracking Validation -- [ ] Implement same-car validation after sessionId cleared -- [ ] Add logic to prevent duplicate pipeline execution -- [ ] Test tracking persistence through session lifecycle -- [ ] Verify correct behavior during edge cases - -### 6.3 Configuration & Documentation -- [ ] Update configuration handling for new structure -- [ ] Ensure `config.json` compatibility maintained -- [ ] Update logging configuration for all modules -- [ ] Add module-level documentation - -### 6.4 Comprehensive Testing -- [ ] **Integration Testing** - - [ ] Test complete system flow end-to-end - - [ ] Test all WebSocket message types - - [ ] Test HTTP API endpoints - - [ ] Test error handling and recovery - -- [ ] **Performance Testing** - - [ ] Verify system performance is maintained - - [ ] Test memory usage optimization - - [ ] Test GPU utilization efficiency - - [ ] Benchmark against original implementation - -- [ ] **Edge Case Testing** - - [ ] Test connection failures and reconnection - - [ ] Test model loading failures - - [ ] Test stream interruption handling - - [ ] Test concurrent subscription management - -### 6.5 Logging Optimization & Cleanup ✅ -- ✅ **Removed Debug Frame Saving** - - ✅ Removed hard-coded debug frame saving in `core/detection/pipeline.py` - - ✅ Removed hard-coded debug frame saving in `core/detection/branches.py` - - ✅ Eliminated absolute debug paths for production use - -- ✅ **Eliminated Test/Mock Functionality** - - ✅ Removed `save_frame_for_testing` function from `core/streaming/buffers.py` - - ✅ Removed `save_test_frames` configuration from `StreamConfig` - - ✅ Cleaned up test frame saving calls in stream manager - - ✅ Updated module exports to remove test functions - -- ✅ **Reduced Verbose Logging** - - ✅ Commented out verbose frame storage logging (every frame) - - ✅ Converted debug-level info logs to proper debug level - - ✅ Reduced repetitive frame dimension logging - - ✅ Maintained important model results and detection confidence logging - - ✅ Kept critical pipeline execution and error messages - -- ✅ **Production-Ready Logging** - - ✅ Clean startup and initialization messages - - ✅ Clear model loading and pipeline status - - ✅ Preserved detection results with confidence scores - - ✅ Maintained 
session management and tracking messages - - ✅ Kept important error and warning messages - -### 6.6 Final Cleanup -- [ ] Remove any remaining duplicate code -- [ ] Optimize imports across all modules -- [ ] Clean up temporary files and debugging code -- [ ] Update project documentation - -## 📋 Post-Refactoring Tasks - -### Documentation Updates -- [ ] Update `CLAUDE.md` with new architecture -- [ ] Create module-specific documentation -- [ ] Update installation and deployment guides -- [ ] Add troubleshooting guide for new structure - -### Code Quality -- [ ] Add type hints to all new modules -- [ ] Implement proper error handling patterns -- [ ] Add logging consistency across modules -- [ ] Ensure proper resource cleanup - -### Future Enhancements (Optional) -- [ ] Add unit tests for each module -- [ ] Implement monitoring and metrics collection -- [ ] Add configuration validation -- [ ] Consider adding dependency injection container - ---- - -## Success Criteria - -✅ **Modularity**: Each module has a single, clear responsibility -✅ **Testability**: Each phase can be tested independently -✅ **Maintainability**: Code is easy to understand and modify -✅ **Compatibility**: All existing functionality preserved -✅ **Performance**: System performance is maintained or improved -✅ **Documentation**: Clear documentation for new architecture - -## Risk Mitigation - -- **Feature-by-feature testing** ensures functionality is preserved at each step -- **Gradual migration** minimizes risk of breaking existing functionality -- **Preserve critical interfaces** (WebSocket protocol, HTTP endpoints) -- **Maintain backward compatibility** with existing configurations -- **Comprehensive testing** at each phase before proceeding - ---- - -## 🎯 Current Status Summary - -### ✅ Completed Phases (95% Complete) -- **Phase 1**: Communication Layer - ✅ COMPLETED -- **Phase 2**: Pipeline Configuration & Model Management - ✅ COMPLETED -- **Phase 3**: Streaming System - ✅ COMPLETED -- **Phase 4**: Vehicle Tracking System - ✅ COMPLETED -- **Phase 5**: Detection Pipeline System - ✅ COMPLETED -- **Additional Features**: License Plate Recognition, Car Abandonment, Session Management - ✅ COMPLETED - -### 📋 Remaining Work (5%) -- **Phase 6**: Final Integration & Testing - - Main application cleanup (`app.py` and `pympta.py`) - - Comprehensive integration testing - - Performance benchmarking - - Documentation updates - -### 🚀 Production Ready Features -- ✅ **Modular Architecture**: ~4000 lines refactored into 20+ focused modules -- ✅ **WebSocket Protocol**: Full compliance with all message types -- ✅ **License Plate Recognition**: External LPR service integration via Redis -- ✅ **Car Abandonment Detection**: Automatic detection and notification -- ✅ **Session Management**: Complete lifecycle with progression stages -- ✅ **Parallel Processing**: ThreadPoolExecutor for branch execution -- ✅ **Redis Integration**: Pub/sub, image storage, LPR subscription -- ✅ **PostgreSQL Integration**: Automatic schema management, combined updates -- ✅ **Stream Optimization**: Shared streams, format-specific handling -- ✅ **Error Recovery**: H.264 corruption detection, automatic reconnection -- ✅ **Production Logging**: Clean, informative logging without debug clutter - -### 📊 Metrics -- **Modules Created**: 20+ specialized modules -- **Lines Per Module**: ~200-500 (highly maintainable) -- **Test Coverage**: Feature-by-feature validation completed -- **Performance**: Maintained or improved from original implementation -- **Backward 
Compatibility**: 100% preserved \ No newline at end of file diff --git a/app.py b/app.py index 8e17400..09cb227 100644 --- a/app.py +++ b/app.py @@ -1,586 +1,903 @@ -""" -Detector Worker - Main FastAPI Application -Refactored modular architecture for computer vision pipeline processing. -""" -import json -import logging +from typing import Any, Dict import os +import json import time +import queue +import torch import cv2 -from contextlib import asynccontextmanager -from typing import Dict, Any +import numpy as np +import base64 +import logging +import threading +import requests +import asyncio +import psutil +import zipfile +from urllib.parse import urlparse from fastapi import FastAPI, WebSocket, HTTPException +from fastapi.websockets import WebSocketDisconnect from fastapi.responses import Response +from websockets.exceptions import ConnectionClosedError +from ultralytics import YOLO -# Import new modular communication system -from core.communication.websocket import websocket_endpoint -from core.communication.state import worker_state +# Import shared pipeline functions +from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline + +app = FastAPI() + +# Global dictionaries to keep track of models and streams +# "models" now holds a nested dict: { camera_id: { modelId: model_tree } } +models: Dict[str, Dict[str, Any]] = {} +streams: Dict[str, Dict[str, Any]] = {} +# Store session IDs per display +session_ids: Dict[str, int] = {} +# Track shared camera streams by camera URL +camera_streams: Dict[str, Dict[str, Any]] = {} +# Map subscriptions to their camera URL +subscription_to_camera: Dict[str, str] = {} +# Store latest frames for REST API access (separate from processing buffer) +latest_frames: Dict[str, Any] = {} + +with open("config.json", "r") as f: + config = json.load(f) + +poll_interval = config.get("poll_interval_ms", 100) +reconnect_interval = config.get("reconnect_interval_sec", 5) +TARGET_FPS = config.get("target_fps", 10) +poll_interval = 1000 / TARGET_FPS +logging.info(f"Poll interval: {poll_interval}ms") +max_streams = config.get("max_streams", 5) +max_retries = config.get("max_retries", 3) # Configure logging logging.basicConfig( - level=logging.DEBUG, + level=logging.INFO, # Set to INFO level for less verbose output format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", handlers=[ - logging.FileHandler("detector_worker.log"), - logging.StreamHandler() + logging.FileHandler("detector_worker.log"), # Write logs to a file + logging.StreamHandler() # Also output to console ] ) +# Create a logger specifically for this application logger = logging.getLogger("detector_worker") -logger.setLevel(logging.DEBUG) +logger.setLevel(logging.DEBUG) # Set app-specific logger to DEBUG level -# Frames are now stored in the shared cache buffer from core.streaming.buffers -# latest_frames = {} # Deprecated - using shared_cache_buffer instead +# Ensure all other libraries (including root) use at least INFO level +logging.getLogger().setLevel(logging.INFO) +logger.info("Starting detector worker application") +logger.info(f"Configuration: Target FPS: {TARGET_FPS}, Max streams: {max_streams}, Max retries: {max_retries}") -# Health monitoring recovery handlers -def _handle_stream_restart_recovery(component: str, details: Dict[str, Any]) -> bool: - """Handle stream restart recovery at the application level.""" - try: - from core.streaming.manager import shared_stream_manager - - # Extract camera ID from component name (e.g., "stream_cam-001" -> "cam-001") - if 
component.startswith("stream_"): - camera_id = component[7:] # Remove "stream_" prefix - else: - camera_id = component - - logger.info(f"Attempting stream restart recovery for {camera_id}") - - # Find and restart the subscription - subscriptions = shared_stream_manager.get_all_subscriptions() - for sub_info in subscriptions: - if sub_info.camera_id == camera_id: - # Remove and re-add the subscription - shared_stream_manager.remove_subscription(sub_info.subscription_id) - time.sleep(1.0) # Brief delay - - # Re-add subscription - success = shared_stream_manager.add_subscription( - sub_info.subscription_id, - sub_info.stream_config, - sub_info.crop_coords, - sub_info.model_id, - sub_info.model_url, - sub_info.tracking_integration - ) - - if success: - logger.info(f"Stream restart recovery successful for {camera_id}") - return True - else: - logger.error(f"Stream restart recovery failed for {camera_id}") - return False - - logger.warning(f"No subscription found for camera {camera_id} during recovery") - return False - - except Exception as e: - logger.error(f"Error in stream restart recovery for {component}: {e}") - return False - - -def _handle_stream_reconnect_recovery(component: str, details: Dict[str, Any]) -> bool: - """Handle stream reconnect recovery at the application level.""" - try: - from core.streaming.manager import shared_stream_manager - - # Extract camera ID from component name - if component.startswith("stream_"): - camera_id = component[7:] - else: - camera_id = component - - logger.info(f"Attempting stream reconnect recovery for {camera_id}") - - # For reconnect, we just need to trigger the stream's internal reconnect - # The stream readers handle their own reconnection logic - active_cameras = shared_stream_manager.get_active_cameras() - - if camera_id in active_cameras: - logger.info(f"Stream reconnect recovery triggered for {camera_id}") - return True - else: - logger.warning(f"Camera {camera_id} not found in active cameras during reconnect recovery") - return False - - except Exception as e: - logger.error(f"Error in stream reconnect recovery for {component}: {e}") - return False - -# Lifespan event handler (modern FastAPI approach) -@asynccontextmanager -async def lifespan(app: FastAPI): - """Application lifespan management.""" - # Startup - logger.info("Detector Worker started successfully") - - # Initialize health monitoring system - try: - from core.monitoring.health import health_monitor - from core.monitoring.stream_health import stream_health_tracker - from core.monitoring.thread_health import thread_health_monitor - from core.monitoring.recovery import recovery_manager - - # Start health monitoring - health_monitor.start() - logger.info("Health monitoring system started") - - # Register recovery handlers for stream management - from core.streaming.manager import shared_stream_manager - recovery_manager.register_recovery_handler( - "restart_stream", - _handle_stream_restart_recovery - ) - recovery_manager.register_recovery_handler( - "reconnect", - _handle_stream_reconnect_recovery - ) - - logger.info("Recovery handlers registered") - - except Exception as e: - logger.error(f"Failed to initialize health monitoring: {e}") - - logger.info("WebSocket endpoint available at: ws://0.0.0.0:8001/") - logger.info("HTTP camera endpoint available at: http://0.0.0.0:8001/camera/{camera_id}/image") - logger.info("Health check available at: http://0.0.0.0:8001/health") - logger.info("Detailed health monitoring available at: http://0.0.0.0:8001/health/detailed") - 
logger.info("Ready and waiting for backend WebSocket connections") - - yield - - # Shutdown - logger.info("Detector Worker shutting down...") - - # Stop health monitoring - try: - from core.monitoring.health import health_monitor - health_monitor.stop() - logger.info("Health monitoring system stopped") - except Exception as e: - logger.error(f"Error stopping health monitoring: {e}") - - # Clear all state - worker_state.set_subscriptions([]) - worker_state.session_ids.clear() - worker_state.progression_stages.clear() - # latest_frames.clear() # No longer needed - frames are in shared_cache_buffer - logger.info("Detector Worker shutdown complete") - -# Create FastAPI application with detailed WebSocket logging -app = FastAPI(title="Detector Worker", version="2.0.0", lifespan=lifespan) - -# Add middleware to log all requests -@app.middleware("http") -async def log_requests(request, call_next): - start_time = time.time() - response = await call_next(request) - process_time = time.time() - start_time - logger.debug(f"HTTP {request.method} {request.url} - {response.status_code} ({process_time:.3f}s)") - return response - -# Load configuration -config_path = "config.json" -if os.path.exists(config_path): - with open(config_path, "r") as f: - config = json.load(f) - logger.info(f"Loaded configuration from {config_path}") -else: - # Default configuration - config = { - "poll_interval_ms": 100, - "reconnect_interval_sec": 5, - "target_fps": 10, - "max_streams": 20, - "max_retries": 3 - } - logger.warning(f"Configuration file {config_path} not found, using defaults") - -# Ensure models directory exists +# Ensure the models directory exists os.makedirs("models", exist_ok=True) logger.info("Ensured models directory exists") -# Stream manager already initialized at module level with max_streams=20 -# Calling initialize_stream_manager() creates a NEW instance, breaking references -# from core.streaming import initialize_stream_manager -# initialize_stream_manager(max_streams=config.get('max_streams', 10)) -logger.info(f"Using stream manager with max_streams=20 (module-level initialization)") +# Constants for heartbeat and timeouts +HEARTBEAT_INTERVAL = 2 # seconds +WORKER_TIMEOUT_MS = 10000 +logger.debug(f"Heartbeat interval set to {HEARTBEAT_INTERVAL} seconds") -# Frames are now stored in the shared cache buffer from core.streaming.buffers -# latest_frames = {} # Deprecated - using shared_cache_buffer instead - -logger.info("Starting detector worker application (refactored)") -logger.info(f"Configuration: Target FPS: {config.get('target_fps', 10)}, " - f"Max streams: {config.get('max_streams', 5)}, " - f"Max retries: {config.get('max_retries', 3)}") - - -@app.websocket("/") -async def websocket_handler(websocket: WebSocket): - """ - Main WebSocket endpoint for backend communication. - Handles all protocol messages according to worker.md specification. 
- """ - client_info = f"{websocket.client.host}:{websocket.client.port}" if websocket.client else "unknown" - logger.info(f"[RX ← Backend] New WebSocket connection request from {client_info}") +# Locks for thread-safe operations +streams_lock = threading.Lock() +models_lock = threading.Lock() +logger.debug("Initialized thread locks") +# Add helper to download mpta ZIP file from a remote URL +def download_mpta(url: str, dest_path: str) -> str: try: - await websocket_endpoint(websocket) + logger.info(f"Starting download of model from {url} to {dest_path}") + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + response = requests.get(url, stream=True) + if response.status_code == 200: + file_size = int(response.headers.get('content-length', 0)) + logger.info(f"Model file size: {file_size/1024/1024:.2f} MB") + downloaded = 0 + with open(dest_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + downloaded += len(chunk) + if file_size > 0 and downloaded % (file_size // 10) < 8192: # Log approximately every 10% + logger.debug(f"Download progress: {downloaded/file_size*100:.1f}%") + logger.info(f"Successfully downloaded mpta file from {url} to {dest_path}") + return dest_path + else: + logger.error(f"Failed to download mpta file (status code {response.status_code}): {response.text}") + return None except Exception as e: - logger.error(f"WebSocket handler error for {client_info}: {e}", exc_info=True) + logger.error(f"Exception downloading mpta file from {url}: {str(e)}", exc_info=True) + return None +# Add helper to fetch snapshot image from HTTP/HTTPS URL +def fetch_snapshot(url: str): + try: + from requests.auth import HTTPBasicAuth, HTTPDigestAuth + + # Parse URL to extract credentials + parsed = urlparse(url) + + # Prepare headers - some cameras require User-Agent + headers = { + 'User-Agent': 'Mozilla/5.0 (compatible; DetectorWorker/1.0)' + } + + # Reconstruct URL without credentials + clean_url = f"{parsed.scheme}://{parsed.hostname}" + if parsed.port: + clean_url += f":{parsed.port}" + clean_url += parsed.path + if parsed.query: + clean_url += f"?{parsed.query}" + + auth = None + if parsed.username and parsed.password: + # Try HTTP Digest authentication first (common for IP cameras) + try: + auth = HTTPDigestAuth(parsed.username, parsed.password) + response = requests.get(clean_url, auth=auth, headers=headers, timeout=10) + if response.status_code == 200: + logger.debug(f"Successfully authenticated using HTTP Digest for {clean_url}") + elif response.status_code == 401: + # If Digest fails, try Basic auth + logger.debug(f"HTTP Digest failed, trying Basic auth for {clean_url}") + auth = HTTPBasicAuth(parsed.username, parsed.password) + response = requests.get(clean_url, auth=auth, headers=headers, timeout=10) + if response.status_code == 200: + logger.debug(f"Successfully authenticated using HTTP Basic for {clean_url}") + except Exception as auth_error: + logger.debug(f"Authentication setup error: {auth_error}") + # Fallback to original URL with embedded credentials + response = requests.get(url, headers=headers, timeout=10) + else: + # No credentials in URL, make request as-is + response = requests.get(url, headers=headers, timeout=10) + + if response.status_code == 200: + # Convert response content to numpy array + nparr = np.frombuffer(response.content, np.uint8) + # Decode image + frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + if frame is not None: + logger.debug(f"Successfully fetched snapshot from {clean_url}, shape: {frame.shape}") + return 
frame + else: + logger.error(f"Failed to decode image from snapshot URL: {clean_url}") + return None + else: + logger.error(f"Failed to fetch snapshot (status code {response.status_code}): {clean_url}") + return None + except Exception as e: + logger.error(f"Exception fetching snapshot from {url}: {str(e)}") + return None +# Helper to get crop coordinates from stream +def get_crop_coords(stream): + return { + "cropX1": stream.get("cropX1"), + "cropY1": stream.get("cropY1"), + "cropX2": stream.get("cropX2"), + "cropY2": stream.get("cropY2") + } + +#################################################### +# REST API endpoint for image retrieval +#################################################### @app.get("/camera/{camera_id}/image") async def get_camera_image(camera_id: str): """ - HTTP endpoint to retrieve the latest frame from a camera as JPEG image. - - This endpoint is preserved for backward compatibility with existing systems. - - Args: - camera_id: The subscription identifier (e.g., "display-001;cam-001") - - Returns: - JPEG image as binary response - - Raises: - HTTPException: 404 if camera not found or no frame available - HTTPException: 500 if encoding fails + Get the current frame from a camera as JPEG image """ try: + # URL decode the camera_id to handle encoded characters like %3B for semicolon from urllib.parse import unquote - - # URL decode the camera_id to handle encoded characters original_camera_id = camera_id camera_id = unquote(camera_id) logger.debug(f"REST API request: original='{original_camera_id}', decoded='{camera_id}'") - - # Check if camera is in active subscriptions - subscription = worker_state.get_subscription(camera_id) - if not subscription: - logger.warning(f"Camera ID '{camera_id}' not found in active subscriptions") - available_cameras = list(worker_state.subscriptions.keys()) - logger.debug(f"Available cameras: {available_cameras}") - raise HTTPException( - status_code=404, - detail=f"Camera {camera_id} not found or not active" - ) - - # Extract actual camera_id from subscription identifier (displayId;cameraId) - # Frames are stored using just the camera_id part - actual_camera_id = camera_id.split(';')[-1] if ';' in camera_id else camera_id - - # Get frame from the shared cache buffer - from core.streaming.buffers import shared_cache_buffer - - # Only show buffer debug info if camera not found (to reduce log spam) - available_cameras = shared_cache_buffer.frame_buffer.get_camera_list() - - frame = shared_cache_buffer.get_frame(actual_camera_id) - if frame is None: - logger.warning(f"\033[93m[API] No frame for '{actual_camera_id}' - Available: {available_cameras}\033[0m") - raise HTTPException( - status_code=404, - detail=f"No frame available for camera {actual_camera_id}" - ) - - # Successful frame retrieval - log only occasionally to avoid spam - + + with streams_lock: + if camera_id not in streams: + logger.warning(f"Camera ID '{camera_id}' not found in streams. 
Current streams: {list(streams.keys())}") + raise HTTPException(status_code=404, detail=f"Camera {camera_id} not found or not active") + + # Check if we have a cached frame for this camera + if camera_id not in latest_frames: + logger.warning(f"No cached frame available for camera '{camera_id}'.") + raise HTTPException(status_code=404, detail=f"No frame available for camera {camera_id}") + + frame = latest_frames[camera_id] + logger.debug(f"Retrieved cached frame for camera '{camera_id}', frame shape: {frame.shape}") # Encode frame as JPEG success, buffer_img = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 85]) if not success: raise HTTPException(status_code=500, detail="Failed to encode image as JPEG") - + # Return image as binary response return Response(content=buffer_img.tobytes(), media_type="image/jpeg") - + except HTTPException: raise except Exception as e: logger.error(f"Error retrieving image for camera {camera_id}: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") +#################################################### +# Detection and frame processing functions +#################################################### +@app.websocket("/") +async def detect(websocket: WebSocket): + logger.info("WebSocket connection accepted") + persistent_data_dict = {} -@app.get("/session-image/{session_id}") -async def get_session_image(session_id: int): - """ - HTTP endpoint to retrieve the saved session image by session ID. - - Args: - session_id: The session ID to retrieve the image for - - Returns: - JPEG image as binary response - - Raises: - HTTPException: 404 if no image found for the session - HTTPException: 500 if reading image fails - """ - try: - from pathlib import Path - - # Images directory - images_dir = Path("images") - - if not images_dir.exists(): - logger.warning(f"Images directory does not exist") - raise HTTPException( - status_code=404, - detail=f"No images directory found" - ) - - # Use os.scandir() for efficient file searching (3-5x faster than glob.glob) - # Filter files matching session ID pattern: {session_id}_*.jpg - prefix = f"{session_id}_" - most_recent_file = None - most_recent_mtime = 0 - - with os.scandir(images_dir) as entries: - for entry in entries: - # Filter: must be a file, start with session_id prefix, and end with .jpg - if entry.is_file() and entry.name.startswith(prefix) and entry.name.endswith('.jpg'): - # Use cached stat info from DirEntry (much faster than separate stat calls) - entry_stat = entry.stat() - if entry_stat.st_mtime > most_recent_mtime: - most_recent_mtime = entry_stat.st_mtime - most_recent_file = entry.path - - if not most_recent_file: - logger.warning(f"No image found for session {session_id}") - raise HTTPException( - status_code=404, - detail=f"No image found for session {session_id}" - ) - - logger.info(f"Found session image for session {session_id}: {most_recent_file}") - - # Read the image file - with open(most_recent_file, 'rb') as f: - image_data = f.read() - - # Return image as binary response - return Response(content=image_data, media_type="image/jpeg") - - except HTTPException: - raise - except Exception as e: - logger.error(f"Error retrieving session image for session {session_id}: {str(e)}", exc_info=True) - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -@app.get("/health") -async def health_check(): - """Health check endpoint for monitoring.""" - return { - "status": "healthy", - "version": "2.0.0", - 
"active_subscriptions": len(worker_state.subscriptions), - "active_sessions": len(worker_state.session_ids) - } - - -@app.get("/health/detailed") -async def detailed_health_check(): - """Comprehensive health status with detailed monitoring data.""" - try: - from core.monitoring.health import health_monitor - from core.monitoring.stream_health import stream_health_tracker - from core.monitoring.thread_health import thread_health_monitor - from core.monitoring.recovery import recovery_manager - - # Get comprehensive health status - overall_health = health_monitor.get_health_status() - stream_metrics = stream_health_tracker.get_all_metrics() - thread_info = thread_health_monitor.get_all_thread_info() - recovery_stats = recovery_manager.get_recovery_stats() - - return { - "timestamp": time.time(), - "overall_health": overall_health, - "stream_metrics": stream_metrics, - "thread_health": thread_info, - "recovery_stats": recovery_stats, - "system_info": { - "active_subscriptions": len(worker_state.subscriptions), - "active_sessions": len(worker_state.session_ids), - "version": "2.0.0" - } - } - - except Exception as e: - logger.error(f"Error generating detailed health report: {e}") - raise HTTPException(status_code=500, detail=f"Health monitoring error: {str(e)}") - - -@app.get("/health/streams") -async def stream_health_status(): - """Stream-specific health monitoring.""" - try: - from core.monitoring.stream_health import stream_health_tracker - from core.streaming.buffers import shared_cache_buffer - - stream_metrics = stream_health_tracker.get_all_metrics() - buffer_stats = shared_cache_buffer.get_stats() - - return { - "timestamp": time.time(), - "stream_count": len(stream_metrics), - "stream_metrics": stream_metrics, - "buffer_stats": buffer_stats, - "frame_ages": { - camera_id: { - "age_seconds": time.time() - info["last_frame_time"] if info and info.get("last_frame_time") else None, - "total_frames": info.get("frame_count", 0) if info else 0 - } - for camera_id, info in stream_metrics.items() - } - } - - except Exception as e: - logger.error(f"Error generating stream health report: {e}") - raise HTTPException(status_code=500, detail=f"Stream health error: {str(e)}") - - -@app.get("/health/threads") -async def thread_health_status(): - """Thread-specific health monitoring.""" - try: - from core.monitoring.thread_health import thread_health_monitor - - thread_info = thread_health_monitor.get_all_thread_info() - deadlocks = thread_health_monitor.detect_deadlocks() - - return { - "timestamp": time.time(), - "thread_count": len(thread_info), - "thread_info": thread_info, - "potential_deadlocks": deadlocks, - "summary": { - "responsive_threads": sum(1 for info in thread_info.values() if info.get("is_responsive", False)), - "unresponsive_threads": sum(1 for info in thread_info.values() if not info.get("is_responsive", True)), - "deadlock_count": len(deadlocks) - } - } - - except Exception as e: - logger.error(f"Error generating thread health report: {e}") - raise HTTPException(status_code=500, detail=f"Thread health error: {str(e)}") - - -@app.get("/health/recovery") -async def recovery_status(): - """Recovery system status and history.""" - try: - from core.monitoring.recovery import recovery_manager - - recovery_stats = recovery_manager.get_recovery_stats() - - return { - "timestamp": time.time(), - "recovery_stats": recovery_stats, - "summary": { - "total_recoveries_last_hour": recovery_stats.get("total_recoveries_last_hour", 0), - "components_with_recovery_state": 
len(recovery_stats.get("recovery_states", {})), - "total_recovery_failures": sum( - state.get("failure_count", 0) - for state in recovery_stats.get("recovery_states", {}).values() - ), - "total_recovery_successes": sum( - state.get("success_count", 0) - for state in recovery_stats.get("recovery_states", {}).values() - ) - } - } - - except Exception as e: - logger.error(f"Error generating recovery status report: {e}") - raise HTTPException(status_code=500, detail=f"Recovery status error: {str(e)}") - - -@app.post("/health/recovery/force/{component}") -async def force_recovery(component: str, action: str = "restart_stream"): - """Force recovery action for a specific component.""" - try: - from core.monitoring.recovery import recovery_manager, RecoveryAction - - # Validate action + async def handle_detection(camera_id, stream, frame, websocket, model_tree, persistent_data): try: - recovery_action = RecoveryAction(action) - except ValueError: - raise HTTPException( - status_code=400, - detail=f"Invalid recovery action: {action}. Valid actions: {[a.value for a in RecoveryAction]}" - ) + # Apply crop if specified + cropped_frame = frame + if all(coord is not None for coord in [stream.get("cropX1"), stream.get("cropY1"), stream.get("cropX2"), stream.get("cropY2")]): + cropX1, cropY1, cropX2, cropY2 = stream["cropX1"], stream["cropY1"], stream["cropX2"], stream["cropY2"] + cropped_frame = frame[cropY1:cropY2, cropX1:cropX2] + logger.debug(f"Applied crop coordinates ({cropX1}, {cropY1}, {cropX2}, {cropY2}) to frame for camera {camera_id}") + + logger.debug(f"Processing frame for camera {camera_id} with model {stream['modelId']}") + start_time = time.time() + + # Extract display identifier for session ID lookup + subscription_parts = stream["subscriptionIdentifier"].split(';') + display_identifier = subscription_parts[0] if subscription_parts else None + session_id = session_ids.get(display_identifier) if display_identifier else None + + # Create context for pipeline execution + pipeline_context = { + "camera_id": camera_id, + "display_id": display_identifier, + "session_id": session_id + } + + detection_result = run_pipeline(cropped_frame, model_tree, context=pipeline_context) + process_time = (time.time() - start_time) * 1000 + logger.debug(f"Detection for camera {camera_id} completed in {process_time:.2f}ms") + + # Log the raw detection result for debugging + logger.debug(f"Raw detection result for camera {camera_id}:\n{json.dumps(detection_result, indent=2, default=str)}") + + # Direct class result (no detections/classifications structure) + if detection_result and isinstance(detection_result, dict) and "class" in detection_result and "confidence" in detection_result: + highest_confidence_detection = { + "class": detection_result.get("class", "none"), + "confidence": detection_result.get("confidence", 1.0), + "box": [0, 0, 0, 0] # Empty bounding box for classifications + } + # Handle case when no detections found or result is empty + elif not detection_result or not detection_result.get("detections"): + # Check if we have classification results + if detection_result and detection_result.get("classifications"): + # Get the highest confidence classification + classifications = detection_result.get("classifications", []) + highest_confidence_class = max(classifications, key=lambda x: x.get("confidence", 0)) if classifications else None + + if highest_confidence_class: + highest_confidence_detection = { + "class": highest_confidence_class.get("class", "none"), + "confidence": 
highest_confidence_class.get("confidence", 1.0), + "box": [0, 0, 0, 0] # Empty bounding box for classifications + } + else: + highest_confidence_detection = { + "class": "none", + "confidence": 1.0, + "box": [0, 0, 0, 0] + } + else: + highest_confidence_detection = { + "class": "none", + "confidence": 1.0, + "box": [0, 0, 0, 0] + } + else: + # Find detection with highest confidence + detections = detection_result.get("detections", []) + highest_confidence_detection = max(detections, key=lambda x: x.get("confidence", 0)) if detections else { + "class": "none", + "confidence": 1.0, + "box": [0, 0, 0, 0] + } + + # Convert detection format to match protocol - flatten detection attributes + detection_dict = {} + + # Handle different detection result formats + if isinstance(highest_confidence_detection, dict): + # Copy all fields from the detection result + for key, value in highest_confidence_detection.items(): + if key not in ["box", "id"]: # Skip internal fields + detection_dict[key] = value + + detection_data = { + "type": "imageDetection", + "subscriptionIdentifier": stream["subscriptionIdentifier"], + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.%fZ", time.gmtime()), + "data": { + "detection": detection_dict, + "modelId": stream["modelId"], + "modelName": stream["modelName"] + } + } + + # Add session ID if available + if session_id is not None: + detection_data["sessionId"] = session_id + + if highest_confidence_detection["class"] != "none": + logger.info(f"Camera {camera_id}: Detected {highest_confidence_detection['class']} with confidence {highest_confidence_detection['confidence']:.2f} using model {stream['modelName']}") + + # Log session ID if available + if session_id: + logger.debug(f"Detection associated with session ID: {session_id}") + + await websocket.send_json(detection_data) + logger.debug(f"Sent detection data to client for camera {camera_id}") + return persistent_data + except Exception as e: + logger.error(f"Error in handle_detection for camera {camera_id}: {str(e)}", exc_info=True) + return persistent_data - # Force recovery - success = recovery_manager.force_recovery(component, recovery_action, "manual_api_request") + def frame_reader(camera_id, cap, buffer, stop_event): + retries = 0 + logger.info(f"Starting frame reader thread for camera {camera_id}") + frame_count = 0 + last_log_time = time.time() + + try: + # Log initial camera status and properties + if cap.isOpened(): + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = cap.get(cv2.CAP_PROP_FPS) + logger.info(f"Camera {camera_id} opened successfully with resolution {width}x{height}, FPS: {fps}") + else: + logger.error(f"Camera {camera_id} failed to open initially") + + while not stop_event.is_set(): + try: + if not cap.isOpened(): + logger.error(f"Camera {camera_id} is not open before trying to read") + # Attempt to reopen + cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"]) + time.sleep(reconnect_interval) + continue + + logger.debug(f"Attempting to read frame from camera {camera_id}") + ret, frame = cap.read() + + if not ret: + logger.warning(f"Connection lost for camera: {camera_id}, retry {retries+1}/{max_retries}") + cap.release() + time.sleep(reconnect_interval) + retries += 1 + if retries > max_retries and max_retries != -1: + logger.error(f"Max retries reached for camera: {camera_id}, stopping frame reader") + break + # Re-open + logger.info(f"Attempting to reopen RTSP stream for camera: {camera_id}") + cap = 
cv2.VideoCapture(streams[camera_id]["rtsp_url"]) + if not cap.isOpened(): + logger.error(f"Failed to reopen RTSP stream for camera: {camera_id}") + continue + logger.info(f"Successfully reopened RTSP stream for camera: {camera_id}") + continue + + # Successfully read a frame + frame_count += 1 + current_time = time.time() + # Log frame stats every 5 seconds + if current_time - last_log_time > 5: + logger.info(f"Camera {camera_id}: Read {frame_count} frames in the last {current_time - last_log_time:.1f} seconds") + frame_count = 0 + last_log_time = current_time + + logger.debug(f"Successfully read frame from camera {camera_id}, shape: {frame.shape}") + retries = 0 + + # Overwrite old frame if buffer is full + if not buffer.empty(): + try: + buffer.get_nowait() + logger.debug(f"[frame_reader] Removed old frame from buffer for camera {camera_id}") + except queue.Empty: + pass + buffer.put(frame) + logger.debug(f"[frame_reader] Added new frame to buffer for camera {camera_id}. Buffer size: {buffer.qsize()}") + + # Short sleep to avoid CPU overuse + time.sleep(0.01) + + except cv2.error as e: + logger.error(f"OpenCV error for camera {camera_id}: {e}", exc_info=True) + cap.release() + time.sleep(reconnect_interval) + retries += 1 + if retries > max_retries and max_retries != -1: + logger.error(f"Max retries reached after OpenCV error for camera {camera_id}") + break + logger.info(f"Attempting to reopen RTSP stream after OpenCV error for camera: {camera_id}") + cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"]) + if not cap.isOpened(): + logger.error(f"Failed to reopen RTSP stream for camera {camera_id} after OpenCV error") + continue + logger.info(f"Successfully reopened RTSP stream after OpenCV error for camera: {camera_id}") + except Exception as e: + logger.error(f"Unexpected error for camera {camera_id}: {str(e)}", exc_info=True) + cap.release() + break + except Exception as e: + logger.error(f"Error in frame_reader thread for camera {camera_id}: {str(e)}", exc_info=True) + finally: + logger.info(f"Frame reader thread for camera {camera_id} is exiting") + if cap and cap.isOpened(): + cap.release() - return { - "timestamp": time.time(), - "component": component, - "action": action, - "success": success, - "message": f"Recovery {'successful' if success else 'failed'} for component {component}" - } + def snapshot_reader(camera_id, snapshot_url, snapshot_interval, buffer, stop_event): + """Frame reader that fetches snapshots from HTTP/HTTPS URL at specified intervals""" + retries = 0 + logger.info(f"Starting snapshot reader thread for camera {camera_id} from {snapshot_url}") + frame_count = 0 + last_log_time = time.time() + + try: + interval_seconds = snapshot_interval / 1000.0 # Convert milliseconds to seconds + logger.info(f"Snapshot interval for camera {camera_id}: {interval_seconds}s") + + while not stop_event.is_set(): + try: + start_time = time.time() + frame = fetch_snapshot(snapshot_url) + + if frame is None: + logger.warning(f"Failed to fetch snapshot for camera: {camera_id}, retry {retries+1}/{max_retries}") + retries += 1 + if retries > max_retries and max_retries != -1: + logger.error(f"Max retries reached for snapshot camera: {camera_id}, stopping reader") + break + time.sleep(min(interval_seconds, reconnect_interval)) + continue + + # Successfully fetched a frame + frame_count += 1 + current_time = time.time() + # Log frame stats every 5 seconds + if current_time - last_log_time > 5: + logger.info(f"Camera {camera_id}: Fetched {frame_count} snapshots in the last {current_time - 
last_log_time:.1f} seconds") + frame_count = 0 + last_log_time = current_time + + logger.debug(f"Successfully fetched snapshot from camera {camera_id}, shape: {frame.shape}") + retries = 0 + + # Overwrite old frame if buffer is full + if not buffer.empty(): + try: + buffer.get_nowait() + logger.debug(f"[snapshot_reader] Removed old snapshot from buffer for camera {camera_id}") + except queue.Empty: + pass + buffer.put(frame) + logger.debug(f"[snapshot_reader] Added new snapshot to buffer for camera {camera_id}. Buffer size: {buffer.qsize()}") + + # Wait for the specified interval + elapsed = time.time() - start_time + sleep_time = max(interval_seconds - elapsed, 0) + if sleep_time > 0: + time.sleep(sleep_time) + + except Exception as e: + logger.error(f"Unexpected error fetching snapshot for camera {camera_id}: {str(e)}", exc_info=True) + retries += 1 + if retries > max_retries and max_retries != -1: + logger.error(f"Max retries reached after error for snapshot camera {camera_id}") + break + time.sleep(min(interval_seconds, reconnect_interval)) + except Exception as e: + logger.error(f"Error in snapshot_reader thread for camera {camera_id}: {str(e)}", exc_info=True) + finally: + logger.info(f"Snapshot reader thread for camera {camera_id} is exiting") - except HTTPException: - raise - except Exception as e: - logger.error(f"Error forcing recovery for {component}: {e}") - raise HTTPException(status_code=500, detail=f"Recovery error: {str(e)}") + async def process_streams(): + logger.info("Started processing streams") + try: + while True: + start_time = time.time() + with streams_lock: + current_streams = list(streams.items()) + if current_streams: + logger.debug(f"Processing {len(current_streams)} active streams") + else: + logger.debug("No active streams to process") + + for camera_id, stream in current_streams: + buffer = stream["buffer"] + if buffer.empty(): + logger.debug(f"Frame buffer is empty for camera {camera_id}") + continue + + logger.debug(f"Got frame from buffer for camera {camera_id}") + frame = buffer.get() + + # Cache the frame for REST API access + latest_frames[camera_id] = frame.copy() + logger.debug(f"Cached frame for REST API access for camera {camera_id}") + + with models_lock: + model_tree = models.get(camera_id, {}).get(stream["modelId"]) + if not model_tree: + logger.warning(f"Model not found for camera {camera_id}, modelId {stream['modelId']}") + continue + logger.debug(f"Found model tree for camera {camera_id}, modelId {stream['modelId']}") + + key = (camera_id, stream["modelId"]) + persistent_data = persistent_data_dict.get(key, {}) + logger.debug(f"Starting detection for camera {camera_id} with modelId {stream['modelId']}") + updated_persistent_data = await handle_detection( + camera_id, stream, frame, websocket, model_tree, persistent_data + ) + persistent_data_dict[key] = updated_persistent_data + + elapsed_time = (time.time() - start_time) * 1000 # ms + sleep_time = max(poll_interval - elapsed_time, 0) + logger.debug(f"Frame processing cycle: {elapsed_time:.2f}ms, sleeping for: {sleep_time:.2f}ms") + await asyncio.sleep(sleep_time / 1000.0) + except asyncio.CancelledError: + logger.info("Stream processing task cancelled") + except Exception as e: + logger.error(f"Error in process_streams: {str(e)}", exc_info=True) + async def send_heartbeat(): + while True: + try: + cpu_usage = psutil.cpu_percent() + memory_usage = psutil.virtual_memory().percent + if torch.cuda.is_available(): + gpu_usage = torch.cuda.utilization() if hasattr(torch.cuda, 'utilization') else 
None + gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) + else: + gpu_usage = None + gpu_memory_usage = None -@app.get("/health/metrics") -async def health_metrics(): - """Performance and health metrics in a format suitable for monitoring systems.""" + camera_connections = [ + { + "subscriptionIdentifier": stream["subscriptionIdentifier"], + "modelId": stream["modelId"], + "modelName": stream["modelName"], + "online": True, + **{k: v for k, v in get_crop_coords(stream).items() if v is not None} + } + for camera_id, stream in streams.items() + ] + + state_report = { + "type": "stateReport", + "cpuUsage": cpu_usage, + "memoryUsage": memory_usage, + "gpuUsage": gpu_usage, + "gpuMemoryUsage": gpu_memory_usage, + "cameraConnections": camera_connections + } + await websocket.send_text(json.dumps(state_report)) + logger.debug(f"Sent stateReport as heartbeat: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, {len(camera_connections)} active cameras") + await asyncio.sleep(HEARTBEAT_INTERVAL) + except Exception as e: + logger.error(f"Error sending stateReport heartbeat: {e}") + break + + async def on_message(): + while True: + try: + msg = await websocket.receive_text() + logger.debug(f"Received message: {msg}") + data = json.loads(msg) + msg_type = data.get("type") + + if msg_type == "subscribe": + payload = data.get("payload", {}) + subscriptionIdentifier = payload.get("subscriptionIdentifier") + rtsp_url = payload.get("rtspUrl") + snapshot_url = payload.get("snapshotUrl") + snapshot_interval = payload.get("snapshotInterval") + model_url = payload.get("modelUrl") + modelId = payload.get("modelId") + modelName = payload.get("modelName") + cropX1 = payload.get("cropX1") + cropY1 = payload.get("cropY1") + cropX2 = payload.get("cropX2") + cropY2 = payload.get("cropY2") + + # Extract camera_id from subscriptionIdentifier (format: displayIdentifier;cameraIdentifier) + parts = subscriptionIdentifier.split(';') + if len(parts) != 2: + logger.error(f"Invalid subscriptionIdentifier format: {subscriptionIdentifier}") + continue + + display_identifier, camera_identifier = parts + camera_id = subscriptionIdentifier # Use full subscriptionIdentifier as camera_id for mapping + + if model_url: + with models_lock: + if (camera_id not in models) or (modelId not in models[camera_id]): + logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}") + extraction_dir = os.path.join("models", camera_identifier, str(modelId)) + os.makedirs(extraction_dir, exist_ok=True) + # If model_url is remote, download it first. 
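+                            # Note (illustrative comment): model_url may be either a remote archive
+                            # (e.g. https://example.com/models/42.mpta, hypothetical URL) or a local .mpta path.
+                            # Remote files are first downloaded via download_mpta() into extraction_dir; in both
+                            # cases the archive is then unpacked and loaded with load_pipeline_from_zip().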
+ parsed = urlparse(model_url) + if parsed.scheme in ("http", "https"): + logger.info(f"Downloading remote .mpta file from {model_url}") + filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta" + local_mpta = os.path.join(extraction_dir, filename) + logger.debug(f"Download destination: {local_mpta}") + local_path = download_mpta(model_url, local_mpta) + if not local_path: + logger.error(f"Failed to download the remote .mpta file from {model_url}") + error_response = { + "type": "error", + "subscriptionIdentifier": subscriptionIdentifier, + "error": f"Failed to download model from {model_url}" + } + await websocket.send_json(error_response) + continue + model_tree = load_pipeline_from_zip(local_path, extraction_dir) + else: + logger.info(f"Loading local .mpta file from {model_url}") + # Check if file exists before attempting to load + if not os.path.exists(model_url): + logger.error(f"Local .mpta file not found: {model_url}") + logger.debug(f"Current working directory: {os.getcwd()}") + error_response = { + "type": "error", + "subscriptionIdentifier": subscriptionIdentifier, + "error": f"Model file not found: {model_url}" + } + await websocket.send_json(error_response) + continue + model_tree = load_pipeline_from_zip(model_url, extraction_dir) + if model_tree is None: + logger.error(f"Failed to load model {modelId} from .mpta file for camera {camera_id}") + error_response = { + "type": "error", + "subscriptionIdentifier": subscriptionIdentifier, + "error": f"Failed to load model {modelId}" + } + await websocket.send_json(error_response) + continue + if camera_id not in models: + models[camera_id] = {} + models[camera_id][modelId] = model_tree + logger.info(f"Successfully loaded model {modelId} for camera {camera_id}") + logger.debug(f"Model extraction directory: {extraction_dir}") + if camera_id and (rtsp_url or snapshot_url): + with streams_lock: + # Determine camera URL for shared stream management + camera_url = snapshot_url if snapshot_url else rtsp_url + + if camera_id not in streams and len(streams) < max_streams: + # Check if we already have a stream for this camera URL + shared_stream = camera_streams.get(camera_url) + + if shared_stream: + # Reuse existing stream + logger.info(f"Reusing existing stream for camera URL: {camera_url}") + buffer = shared_stream["buffer"] + stop_event = shared_stream["stop_event"] + thread = shared_stream["thread"] + mode = shared_stream["mode"] + + # Increment reference count + shared_stream["ref_count"] = shared_stream.get("ref_count", 0) + 1 + else: + # Create new stream + buffer = queue.Queue(maxsize=1) + stop_event = threading.Event() + + if snapshot_url and snapshot_interval: + logger.info(f"Creating new snapshot stream for camera {camera_id}: {snapshot_url}") + thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event)) + thread.daemon = True + thread.start() + mode = "snapshot" + + # Store shared stream info + shared_stream = { + "buffer": buffer, + "thread": thread, + "stop_event": stop_event, + "mode": mode, + "url": snapshot_url, + "snapshot_interval": snapshot_interval, + "ref_count": 1 + } + camera_streams[camera_url] = shared_stream + + elif rtsp_url: + logger.info(f"Creating new RTSP stream for camera {camera_id}: {rtsp_url}") + cap = cv2.VideoCapture(rtsp_url) + if not cap.isOpened(): + logger.error(f"Failed to open RTSP stream for camera {camera_id}") + continue + thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event)) + 
thread.daemon = True + thread.start() + mode = "rtsp" + + # Store shared stream info + shared_stream = { + "buffer": buffer, + "thread": thread, + "stop_event": stop_event, + "mode": mode, + "url": rtsp_url, + "cap": cap, + "ref_count": 1 + } + camera_streams[camera_url] = shared_stream + else: + logger.error(f"No valid URL provided for camera {camera_id}") + continue + + # Create stream info for this subscription + stream_info = { + "buffer": buffer, + "thread": thread, + "stop_event": stop_event, + "modelId": modelId, + "modelName": modelName, + "subscriptionIdentifier": subscriptionIdentifier, + "cropX1": cropX1, + "cropY1": cropY1, + "cropX2": cropX2, + "cropY2": cropY2, + "mode": mode, + "camera_url": camera_url + } + + if mode == "snapshot": + stream_info["snapshot_url"] = snapshot_url + stream_info["snapshot_interval"] = snapshot_interval + elif mode == "rtsp": + stream_info["rtsp_url"] = rtsp_url + stream_info["cap"] = shared_stream["cap"] + + streams[camera_id] = stream_info + subscription_to_camera[camera_id] = camera_url + + elif camera_id and camera_id in streams: + # If already subscribed, unsubscribe first + logger.info(f"Resubscribing to camera {camera_id}") + # Note: Keep models in memory for reuse across subscriptions + elif msg_type == "unsubscribe": + payload = data.get("payload", {}) + subscriptionIdentifier = payload.get("subscriptionIdentifier") + camera_id = subscriptionIdentifier + with streams_lock: + if camera_id and camera_id in streams: + stream = streams.pop(camera_id) + camera_url = subscription_to_camera.pop(camera_id, None) + + if camera_url and camera_url in camera_streams: + shared_stream = camera_streams[camera_url] + shared_stream["ref_count"] -= 1 + + # If no more references, stop the shared stream + if shared_stream["ref_count"] <= 0: + logger.info(f"Stopping shared stream for camera URL: {camera_url}") + shared_stream["stop_event"].set() + shared_stream["thread"].join() + if "cap" in shared_stream: + shared_stream["cap"].release() + del camera_streams[camera_url] + else: + logger.info(f"Shared stream for {camera_url} still has {shared_stream['ref_count']} references") + + # Clean up cached frame + latest_frames.pop(camera_id, None) + logger.info(f"Unsubscribed from camera {camera_id}") + # Note: Keep models in memory for potential reuse + elif msg_type == "requestState": + cpu_usage = psutil.cpu_percent() + memory_usage = psutil.virtual_memory().percent + if torch.cuda.is_available(): + gpu_usage = torch.cuda.utilization() if hasattr(torch.cuda, 'utilization') else None + gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) + else: + gpu_usage = None + gpu_memory_usage = None + + camera_connections = [ + { + "subscriptionIdentifier": stream["subscriptionIdentifier"], + "modelId": stream["modelId"], + "modelName": stream["modelName"], + "online": True, + **{k: v for k, v in get_crop_coords(stream).items() if v is not None} + } + for camera_id, stream in streams.items() + ] + + state_report = { + "type": "stateReport", + "cpuUsage": cpu_usage, + "memoryUsage": memory_usage, + "gpuUsage": gpu_usage, + "gpuMemoryUsage": gpu_memory_usage, + "cameraConnections": camera_connections + } + await websocket.send_text(json.dumps(state_report)) + + elif msg_type == "setSessionId": + payload = data.get("payload", {}) + display_identifier = payload.get("displayIdentifier") + session_id = payload.get("sessionId") + + if display_identifier: + # Store session ID for this display + if session_id is None: + session_ids.pop(display_identifier, None) + 
logger.info(f"Cleared session ID for display {display_identifier}") + else: + session_ids[display_identifier] = session_id + logger.info(f"Set session ID {session_id} for display {display_identifier}") + + elif msg_type == "patchSession": + session_id = data.get("sessionId") + patch_data = data.get("data", {}) + + # For now, just acknowledge the patch - actual implementation depends on backend requirements + response = { + "type": "patchSessionResult", + "payload": { + "sessionId": session_id, + "success": True, + "message": "Session patch acknowledged" + } + } + await websocket.send_json(response) + logger.info(f"Acknowledged patch for session {session_id}") + + else: + logger.error(f"Unknown message type: {msg_type}") + except json.JSONDecodeError: + logger.error("Received invalid JSON message") + except (WebSocketDisconnect, ConnectionClosedError) as e: + logger.warning(f"WebSocket disconnected: {e}") + break + except Exception as e: + logger.error(f"Error handling message: {e}") + break try: - from core.monitoring.health import health_monitor - from core.monitoring.stream_health import stream_health_tracker - from core.streaming.buffers import shared_cache_buffer - - # Get basic metrics - overall_health = health_monitor.get_health_status() - stream_metrics = stream_health_tracker.get_all_metrics() - buffer_stats = shared_cache_buffer.get_stats() - - # Format for monitoring systems (Prometheus-style) - metrics = { - "detector_worker_up": 1, - "detector_worker_streams_total": len(stream_metrics), - "detector_worker_subscriptions_total": len(worker_state.subscriptions), - "detector_worker_sessions_total": len(worker_state.session_ids), - "detector_worker_memory_mb": buffer_stats.get("total_memory_mb", 0), - "detector_worker_health_status": { - "healthy": 1, - "warning": 2, - "critical": 3, - "unknown": 4 - }.get(overall_health.get("overall_status", "unknown"), 4) - } - - # Add per-stream metrics - for camera_id, stream_info in stream_metrics.items(): - safe_camera_id = camera_id.replace("-", "_").replace(".", "_") - metrics.update({ - f"detector_worker_stream_frames_total{{camera=\"{safe_camera_id}\"}}": stream_info.get("frame_count", 0), - f"detector_worker_stream_errors_total{{camera=\"{safe_camera_id}\"}}": stream_info.get("error_count", 0), - f"detector_worker_stream_fps{{camera=\"{safe_camera_id}\"}}": stream_info.get("frames_per_second", 0), - f"detector_worker_stream_frame_age_seconds{{camera=\"{safe_camera_id}\"}}": stream_info.get("last_frame_age_seconds") or 0 - }) - - return { - "timestamp": time.time(), - "metrics": metrics - } - + await websocket.accept() + stream_task = asyncio.create_task(process_streams()) + heartbeat_task = asyncio.create_task(send_heartbeat()) + message_task = asyncio.create_task(on_message()) + await asyncio.gather(heartbeat_task, message_task) except Exception as e: - logger.error(f"Error generating health metrics: {e}") - raise HTTPException(status_code=500, detail=f"Metrics error: {str(e)}") - - - - -if __name__ == "__main__": - import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8001) \ No newline at end of file + logger.error(f"Error in detect websocket: {e}") + finally: + stream_task.cancel() + await stream_task + with streams_lock: + # Clean up shared camera streams + for camera_url, shared_stream in camera_streams.items(): + shared_stream["stop_event"].set() + shared_stream["thread"].join() + if "cap" in shared_stream: + shared_stream["cap"].release() + while not shared_stream["buffer"].empty(): + try: + shared_stream["buffer"].get_nowait() + 
except queue.Empty: + pass + logger.info(f"Released shared camera stream for {camera_url}") + + streams.clear() + camera_streams.clear() + subscription_to_camera.clear() + with models_lock: + models.clear() + latest_frames.clear() + session_ids.clear() + logger.info("WebSocket connection closed") diff --git a/archive/app.py b/archive/app.py deleted file mode 100644 index 09cb227..0000000 --- a/archive/app.py +++ /dev/null @@ -1,903 +0,0 @@ -from typing import Any, Dict -import os -import json -import time -import queue -import torch -import cv2 -import numpy as np -import base64 -import logging -import threading -import requests -import asyncio -import psutil -import zipfile -from urllib.parse import urlparse -from fastapi import FastAPI, WebSocket, HTTPException -from fastapi.websockets import WebSocketDisconnect -from fastapi.responses import Response -from websockets.exceptions import ConnectionClosedError -from ultralytics import YOLO - -# Import shared pipeline functions -from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline - -app = FastAPI() - -# Global dictionaries to keep track of models and streams -# "models" now holds a nested dict: { camera_id: { modelId: model_tree } } -models: Dict[str, Dict[str, Any]] = {} -streams: Dict[str, Dict[str, Any]] = {} -# Store session IDs per display -session_ids: Dict[str, int] = {} -# Track shared camera streams by camera URL -camera_streams: Dict[str, Dict[str, Any]] = {} -# Map subscriptions to their camera URL -subscription_to_camera: Dict[str, str] = {} -# Store latest frames for REST API access (separate from processing buffer) -latest_frames: Dict[str, Any] = {} - -with open("config.json", "r") as f: - config = json.load(f) - -poll_interval = config.get("poll_interval_ms", 100) -reconnect_interval = config.get("reconnect_interval_sec", 5) -TARGET_FPS = config.get("target_fps", 10) -poll_interval = 1000 / TARGET_FPS -logging.info(f"Poll interval: {poll_interval}ms") -max_streams = config.get("max_streams", 5) -max_retries = config.get("max_retries", 3) - -# Configure logging -logging.basicConfig( - level=logging.INFO, # Set to INFO level for less verbose output - format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", - handlers=[ - logging.FileHandler("detector_worker.log"), # Write logs to a file - logging.StreamHandler() # Also output to console - ] -) - -# Create a logger specifically for this application -logger = logging.getLogger("detector_worker") -logger.setLevel(logging.DEBUG) # Set app-specific logger to DEBUG level - -# Ensure all other libraries (including root) use at least INFO level -logging.getLogger().setLevel(logging.INFO) - -logger.info("Starting detector worker application") -logger.info(f"Configuration: Target FPS: {TARGET_FPS}, Max streams: {max_streams}, Max retries: {max_retries}") - -# Ensure the models directory exists -os.makedirs("models", exist_ok=True) -logger.info("Ensured models directory exists") - -# Constants for heartbeat and timeouts -HEARTBEAT_INTERVAL = 2 # seconds -WORKER_TIMEOUT_MS = 10000 -logger.debug(f"Heartbeat interval set to {HEARTBEAT_INTERVAL} seconds") - -# Locks for thread-safe operations -streams_lock = threading.Lock() -models_lock = threading.Lock() -logger.debug("Initialized thread locks") - -# Add helper to download mpta ZIP file from a remote URL -def download_mpta(url: str, dest_path: str) -> str: - try: - logger.info(f"Starting download of model from {url} to {dest_path}") - os.makedirs(os.path.dirname(dest_path), exist_ok=True) - response = requests.get(url, 
stream=True) - if response.status_code == 200: - file_size = int(response.headers.get('content-length', 0)) - logger.info(f"Model file size: {file_size/1024/1024:.2f} MB") - downloaded = 0 - with open(dest_path, "wb") as f: - for chunk in response.iter_content(chunk_size=8192): - f.write(chunk) - downloaded += len(chunk) - if file_size > 0 and downloaded % (file_size // 10) < 8192: # Log approximately every 10% - logger.debug(f"Download progress: {downloaded/file_size*100:.1f}%") - logger.info(f"Successfully downloaded mpta file from {url} to {dest_path}") - return dest_path - else: - logger.error(f"Failed to download mpta file (status code {response.status_code}): {response.text}") - return None - except Exception as e: - logger.error(f"Exception downloading mpta file from {url}: {str(e)}", exc_info=True) - return None - -# Add helper to fetch snapshot image from HTTP/HTTPS URL -def fetch_snapshot(url: str): - try: - from requests.auth import HTTPBasicAuth, HTTPDigestAuth - - # Parse URL to extract credentials - parsed = urlparse(url) - - # Prepare headers - some cameras require User-Agent - headers = { - 'User-Agent': 'Mozilla/5.0 (compatible; DetectorWorker/1.0)' - } - - # Reconstruct URL without credentials - clean_url = f"{parsed.scheme}://{parsed.hostname}" - if parsed.port: - clean_url += f":{parsed.port}" - clean_url += parsed.path - if parsed.query: - clean_url += f"?{parsed.query}" - - auth = None - if parsed.username and parsed.password: - # Try HTTP Digest authentication first (common for IP cameras) - try: - auth = HTTPDigestAuth(parsed.username, parsed.password) - response = requests.get(clean_url, auth=auth, headers=headers, timeout=10) - if response.status_code == 200: - logger.debug(f"Successfully authenticated using HTTP Digest for {clean_url}") - elif response.status_code == 401: - # If Digest fails, try Basic auth - logger.debug(f"HTTP Digest failed, trying Basic auth for {clean_url}") - auth = HTTPBasicAuth(parsed.username, parsed.password) - response = requests.get(clean_url, auth=auth, headers=headers, timeout=10) - if response.status_code == 200: - logger.debug(f"Successfully authenticated using HTTP Basic for {clean_url}") - except Exception as auth_error: - logger.debug(f"Authentication setup error: {auth_error}") - # Fallback to original URL with embedded credentials - response = requests.get(url, headers=headers, timeout=10) - else: - # No credentials in URL, make request as-is - response = requests.get(url, headers=headers, timeout=10) - - if response.status_code == 200: - # Convert response content to numpy array - nparr = np.frombuffer(response.content, np.uint8) - # Decode image - frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR) - if frame is not None: - logger.debug(f"Successfully fetched snapshot from {clean_url}, shape: {frame.shape}") - return frame - else: - logger.error(f"Failed to decode image from snapshot URL: {clean_url}") - return None - else: - logger.error(f"Failed to fetch snapshot (status code {response.status_code}): {clean_url}") - return None - except Exception as e: - logger.error(f"Exception fetching snapshot from {url}: {str(e)}") - return None - -# Helper to get crop coordinates from stream -def get_crop_coords(stream): - return { - "cropX1": stream.get("cropX1"), - "cropY1": stream.get("cropY1"), - "cropX2": stream.get("cropX2"), - "cropY2": stream.get("cropY2") - } - -#################################################### -# REST API endpoint for image retrieval -#################################################### 
-@app.get("/camera/{camera_id}/image") -async def get_camera_image(camera_id: str): - """ - Get the current frame from a camera as JPEG image - """ - try: - # URL decode the camera_id to handle encoded characters like %3B for semicolon - from urllib.parse import unquote - original_camera_id = camera_id - camera_id = unquote(camera_id) - logger.debug(f"REST API request: original='{original_camera_id}', decoded='{camera_id}'") - - with streams_lock: - if camera_id not in streams: - logger.warning(f"Camera ID '{camera_id}' not found in streams. Current streams: {list(streams.keys())}") - raise HTTPException(status_code=404, detail=f"Camera {camera_id} not found or not active") - - # Check if we have a cached frame for this camera - if camera_id not in latest_frames: - logger.warning(f"No cached frame available for camera '{camera_id}'.") - raise HTTPException(status_code=404, detail=f"No frame available for camera {camera_id}") - - frame = latest_frames[camera_id] - logger.debug(f"Retrieved cached frame for camera '{camera_id}', frame shape: {frame.shape}") - # Encode frame as JPEG - success, buffer_img = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 85]) - if not success: - raise HTTPException(status_code=500, detail="Failed to encode image as JPEG") - - # Return image as binary response - return Response(content=buffer_img.tobytes(), media_type="image/jpeg") - - except HTTPException: - raise - except Exception as e: - logger.error(f"Error retrieving image for camera {camera_id}: {str(e)}", exc_info=True) - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - -#################################################### -# Detection and frame processing functions -#################################################### -@app.websocket("/") -async def detect(websocket: WebSocket): - logger.info("WebSocket connection accepted") - persistent_data_dict = {} - - async def handle_detection(camera_id, stream, frame, websocket, model_tree, persistent_data): - try: - # Apply crop if specified - cropped_frame = frame - if all(coord is not None for coord in [stream.get("cropX1"), stream.get("cropY1"), stream.get("cropX2"), stream.get("cropY2")]): - cropX1, cropY1, cropX2, cropY2 = stream["cropX1"], stream["cropY1"], stream["cropX2"], stream["cropY2"] - cropped_frame = frame[cropY1:cropY2, cropX1:cropX2] - logger.debug(f"Applied crop coordinates ({cropX1}, {cropY1}, {cropX2}, {cropY2}) to frame for camera {camera_id}") - - logger.debug(f"Processing frame for camera {camera_id} with model {stream['modelId']}") - start_time = time.time() - - # Extract display identifier for session ID lookup - subscription_parts = stream["subscriptionIdentifier"].split(';') - display_identifier = subscription_parts[0] if subscription_parts else None - session_id = session_ids.get(display_identifier) if display_identifier else None - - # Create context for pipeline execution - pipeline_context = { - "camera_id": camera_id, - "display_id": display_identifier, - "session_id": session_id - } - - detection_result = run_pipeline(cropped_frame, model_tree, context=pipeline_context) - process_time = (time.time() - start_time) * 1000 - logger.debug(f"Detection for camera {camera_id} completed in {process_time:.2f}ms") - - # Log the raw detection result for debugging - logger.debug(f"Raw detection result for camera {camera_id}:\n{json.dumps(detection_result, indent=2, default=str)}") - - # Direct class result (no detections/classifications structure) - if detection_result and isinstance(detection_result, 
dict) and "class" in detection_result and "confidence" in detection_result: - highest_confidence_detection = { - "class": detection_result.get("class", "none"), - "confidence": detection_result.get("confidence", 1.0), - "box": [0, 0, 0, 0] # Empty bounding box for classifications - } - # Handle case when no detections found or result is empty - elif not detection_result or not detection_result.get("detections"): - # Check if we have classification results - if detection_result and detection_result.get("classifications"): - # Get the highest confidence classification - classifications = detection_result.get("classifications", []) - highest_confidence_class = max(classifications, key=lambda x: x.get("confidence", 0)) if classifications else None - - if highest_confidence_class: - highest_confidence_detection = { - "class": highest_confidence_class.get("class", "none"), - "confidence": highest_confidence_class.get("confidence", 1.0), - "box": [0, 0, 0, 0] # Empty bounding box for classifications - } - else: - highest_confidence_detection = { - "class": "none", - "confidence": 1.0, - "box": [0, 0, 0, 0] - } - else: - highest_confidence_detection = { - "class": "none", - "confidence": 1.0, - "box": [0, 0, 0, 0] - } - else: - # Find detection with highest confidence - detections = detection_result.get("detections", []) - highest_confidence_detection = max(detections, key=lambda x: x.get("confidence", 0)) if detections else { - "class": "none", - "confidence": 1.0, - "box": [0, 0, 0, 0] - } - - # Convert detection format to match protocol - flatten detection attributes - detection_dict = {} - - # Handle different detection result formats - if isinstance(highest_confidence_detection, dict): - # Copy all fields from the detection result - for key, value in highest_confidence_detection.items(): - if key not in ["box", "id"]: # Skip internal fields - detection_dict[key] = value - - detection_data = { - "type": "imageDetection", - "subscriptionIdentifier": stream["subscriptionIdentifier"], - "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.%fZ", time.gmtime()), - "data": { - "detection": detection_dict, - "modelId": stream["modelId"], - "modelName": stream["modelName"] - } - } - - # Add session ID if available - if session_id is not None: - detection_data["sessionId"] = session_id - - if highest_confidence_detection["class"] != "none": - logger.info(f"Camera {camera_id}: Detected {highest_confidence_detection['class']} with confidence {highest_confidence_detection['confidence']:.2f} using model {stream['modelName']}") - - # Log session ID if available - if session_id: - logger.debug(f"Detection associated with session ID: {session_id}") - - await websocket.send_json(detection_data) - logger.debug(f"Sent detection data to client for camera {camera_id}") - return persistent_data - except Exception as e: - logger.error(f"Error in handle_detection for camera {camera_id}: {str(e)}", exc_info=True) - return persistent_data - - def frame_reader(camera_id, cap, buffer, stop_event): - retries = 0 - logger.info(f"Starting frame reader thread for camera {camera_id}") - frame_count = 0 - last_log_time = time.time() - - try: - # Log initial camera status and properties - if cap.isOpened(): - width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - fps = cap.get(cv2.CAP_PROP_FPS) - logger.info(f"Camera {camera_id} opened successfully with resolution {width}x{height}, FPS: {fps}") - else: - logger.error(f"Camera {camera_id} failed to open initially") - - while not 
stop_event.is_set(): - try: - if not cap.isOpened(): - logger.error(f"Camera {camera_id} is not open before trying to read") - # Attempt to reopen - cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"]) - time.sleep(reconnect_interval) - continue - - logger.debug(f"Attempting to read frame from camera {camera_id}") - ret, frame = cap.read() - - if not ret: - logger.warning(f"Connection lost for camera: {camera_id}, retry {retries+1}/{max_retries}") - cap.release() - time.sleep(reconnect_interval) - retries += 1 - if retries > max_retries and max_retries != -1: - logger.error(f"Max retries reached for camera: {camera_id}, stopping frame reader") - break - # Re-open - logger.info(f"Attempting to reopen RTSP stream for camera: {camera_id}") - cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"]) - if not cap.isOpened(): - logger.error(f"Failed to reopen RTSP stream for camera: {camera_id}") - continue - logger.info(f"Successfully reopened RTSP stream for camera: {camera_id}") - continue - - # Successfully read a frame - frame_count += 1 - current_time = time.time() - # Log frame stats every 5 seconds - if current_time - last_log_time > 5: - logger.info(f"Camera {camera_id}: Read {frame_count} frames in the last {current_time - last_log_time:.1f} seconds") - frame_count = 0 - last_log_time = current_time - - logger.debug(f"Successfully read frame from camera {camera_id}, shape: {frame.shape}") - retries = 0 - - # Overwrite old frame if buffer is full - if not buffer.empty(): - try: - buffer.get_nowait() - logger.debug(f"[frame_reader] Removed old frame from buffer for camera {camera_id}") - except queue.Empty: - pass - buffer.put(frame) - logger.debug(f"[frame_reader] Added new frame to buffer for camera {camera_id}. Buffer size: {buffer.qsize()}") - - # Short sleep to avoid CPU overuse - time.sleep(0.01) - - except cv2.error as e: - logger.error(f"OpenCV error for camera {camera_id}: {e}", exc_info=True) - cap.release() - time.sleep(reconnect_interval) - retries += 1 - if retries > max_retries and max_retries != -1: - logger.error(f"Max retries reached after OpenCV error for camera {camera_id}") - break - logger.info(f"Attempting to reopen RTSP stream after OpenCV error for camera: {camera_id}") - cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"]) - if not cap.isOpened(): - logger.error(f"Failed to reopen RTSP stream for camera {camera_id} after OpenCV error") - continue - logger.info(f"Successfully reopened RTSP stream after OpenCV error for camera: {camera_id}") - except Exception as e: - logger.error(f"Unexpected error for camera {camera_id}: {str(e)}", exc_info=True) - cap.release() - break - except Exception as e: - logger.error(f"Error in frame_reader thread for camera {camera_id}: {str(e)}", exc_info=True) - finally: - logger.info(f"Frame reader thread for camera {camera_id} is exiting") - if cap and cap.isOpened(): - cap.release() - - def snapshot_reader(camera_id, snapshot_url, snapshot_interval, buffer, stop_event): - """Frame reader that fetches snapshots from HTTP/HTTPS URL at specified intervals""" - retries = 0 - logger.info(f"Starting snapshot reader thread for camera {camera_id} from {snapshot_url}") - frame_count = 0 - last_log_time = time.time() - - try: - interval_seconds = snapshot_interval / 1000.0 # Convert milliseconds to seconds - logger.info(f"Snapshot interval for camera {camera_id}: {interval_seconds}s") - - while not stop_event.is_set(): - try: - start_time = time.time() - frame = fetch_snapshot(snapshot_url) - - if frame is None: - logger.warning(f"Failed to 
fetch snapshot for camera: {camera_id}, retry {retries+1}/{max_retries}") - retries += 1 - if retries > max_retries and max_retries != -1: - logger.error(f"Max retries reached for snapshot camera: {camera_id}, stopping reader") - break - time.sleep(min(interval_seconds, reconnect_interval)) - continue - - # Successfully fetched a frame - frame_count += 1 - current_time = time.time() - # Log frame stats every 5 seconds - if current_time - last_log_time > 5: - logger.info(f"Camera {camera_id}: Fetched {frame_count} snapshots in the last {current_time - last_log_time:.1f} seconds") - frame_count = 0 - last_log_time = current_time - - logger.debug(f"Successfully fetched snapshot from camera {camera_id}, shape: {frame.shape}") - retries = 0 - - # Overwrite old frame if buffer is full - if not buffer.empty(): - try: - buffer.get_nowait() - logger.debug(f"[snapshot_reader] Removed old snapshot from buffer for camera {camera_id}") - except queue.Empty: - pass - buffer.put(frame) - logger.debug(f"[snapshot_reader] Added new snapshot to buffer for camera {camera_id}. Buffer size: {buffer.qsize()}") - - # Wait for the specified interval - elapsed = time.time() - start_time - sleep_time = max(interval_seconds - elapsed, 0) - if sleep_time > 0: - time.sleep(sleep_time) - - except Exception as e: - logger.error(f"Unexpected error fetching snapshot for camera {camera_id}: {str(e)}", exc_info=True) - retries += 1 - if retries > max_retries and max_retries != -1: - logger.error(f"Max retries reached after error for snapshot camera {camera_id}") - break - time.sleep(min(interval_seconds, reconnect_interval)) - except Exception as e: - logger.error(f"Error in snapshot_reader thread for camera {camera_id}: {str(e)}", exc_info=True) - finally: - logger.info(f"Snapshot reader thread for camera {camera_id} is exiting") - - async def process_streams(): - logger.info("Started processing streams") - try: - while True: - start_time = time.time() - with streams_lock: - current_streams = list(streams.items()) - if current_streams: - logger.debug(f"Processing {len(current_streams)} active streams") - else: - logger.debug("No active streams to process") - - for camera_id, stream in current_streams: - buffer = stream["buffer"] - if buffer.empty(): - logger.debug(f"Frame buffer is empty for camera {camera_id}") - continue - - logger.debug(f"Got frame from buffer for camera {camera_id}") - frame = buffer.get() - - # Cache the frame for REST API access - latest_frames[camera_id] = frame.copy() - logger.debug(f"Cached frame for REST API access for camera {camera_id}") - - with models_lock: - model_tree = models.get(camera_id, {}).get(stream["modelId"]) - if not model_tree: - logger.warning(f"Model not found for camera {camera_id}, modelId {stream['modelId']}") - continue - logger.debug(f"Found model tree for camera {camera_id}, modelId {stream['modelId']}") - - key = (camera_id, stream["modelId"]) - persistent_data = persistent_data_dict.get(key, {}) - logger.debug(f"Starting detection for camera {camera_id} with modelId {stream['modelId']}") - updated_persistent_data = await handle_detection( - camera_id, stream, frame, websocket, model_tree, persistent_data - ) - persistent_data_dict[key] = updated_persistent_data - - elapsed_time = (time.time() - start_time) * 1000 # ms - sleep_time = max(poll_interval - elapsed_time, 0) - logger.debug(f"Frame processing cycle: {elapsed_time:.2f}ms, sleeping for: {sleep_time:.2f}ms") - await asyncio.sleep(sleep_time / 1000.0) - except asyncio.CancelledError: - logger.info("Stream 
processing task cancelled") - except Exception as e: - logger.error(f"Error in process_streams: {str(e)}", exc_info=True) - - async def send_heartbeat(): - while True: - try: - cpu_usage = psutil.cpu_percent() - memory_usage = psutil.virtual_memory().percent - if torch.cuda.is_available(): - gpu_usage = torch.cuda.utilization() if hasattr(torch.cuda, 'utilization') else None - gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) - else: - gpu_usage = None - gpu_memory_usage = None - - camera_connections = [ - { - "subscriptionIdentifier": stream["subscriptionIdentifier"], - "modelId": stream["modelId"], - "modelName": stream["modelName"], - "online": True, - **{k: v for k, v in get_crop_coords(stream).items() if v is not None} - } - for camera_id, stream in streams.items() - ] - - state_report = { - "type": "stateReport", - "cpuUsage": cpu_usage, - "memoryUsage": memory_usage, - "gpuUsage": gpu_usage, - "gpuMemoryUsage": gpu_memory_usage, - "cameraConnections": camera_connections - } - await websocket.send_text(json.dumps(state_report)) - logger.debug(f"Sent stateReport as heartbeat: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, {len(camera_connections)} active cameras") - await asyncio.sleep(HEARTBEAT_INTERVAL) - except Exception as e: - logger.error(f"Error sending stateReport heartbeat: {e}") - break - - async def on_message(): - while True: - try: - msg = await websocket.receive_text() - logger.debug(f"Received message: {msg}") - data = json.loads(msg) - msg_type = data.get("type") - - if msg_type == "subscribe": - payload = data.get("payload", {}) - subscriptionIdentifier = payload.get("subscriptionIdentifier") - rtsp_url = payload.get("rtspUrl") - snapshot_url = payload.get("snapshotUrl") - snapshot_interval = payload.get("snapshotInterval") - model_url = payload.get("modelUrl") - modelId = payload.get("modelId") - modelName = payload.get("modelName") - cropX1 = payload.get("cropX1") - cropY1 = payload.get("cropY1") - cropX2 = payload.get("cropX2") - cropY2 = payload.get("cropY2") - - # Extract camera_id from subscriptionIdentifier (format: displayIdentifier;cameraIdentifier) - parts = subscriptionIdentifier.split(';') - if len(parts) != 2: - logger.error(f"Invalid subscriptionIdentifier format: {subscriptionIdentifier}") - continue - - display_identifier, camera_identifier = parts - camera_id = subscriptionIdentifier # Use full subscriptionIdentifier as camera_id for mapping - - if model_url: - with models_lock: - if (camera_id not in models) or (modelId not in models[camera_id]): - logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}") - extraction_dir = os.path.join("models", camera_identifier, str(modelId)) - os.makedirs(extraction_dir, exist_ok=True) - # If model_url is remote, download it first. 
- parsed = urlparse(model_url) - if parsed.scheme in ("http", "https"): - logger.info(f"Downloading remote .mpta file from {model_url}") - filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta" - local_mpta = os.path.join(extraction_dir, filename) - logger.debug(f"Download destination: {local_mpta}") - local_path = download_mpta(model_url, local_mpta) - if not local_path: - logger.error(f"Failed to download the remote .mpta file from {model_url}") - error_response = { - "type": "error", - "subscriptionIdentifier": subscriptionIdentifier, - "error": f"Failed to download model from {model_url}" - } - await websocket.send_json(error_response) - continue - model_tree = load_pipeline_from_zip(local_path, extraction_dir) - else: - logger.info(f"Loading local .mpta file from {model_url}") - # Check if file exists before attempting to load - if not os.path.exists(model_url): - logger.error(f"Local .mpta file not found: {model_url}") - logger.debug(f"Current working directory: {os.getcwd()}") - error_response = { - "type": "error", - "subscriptionIdentifier": subscriptionIdentifier, - "error": f"Model file not found: {model_url}" - } - await websocket.send_json(error_response) - continue - model_tree = load_pipeline_from_zip(model_url, extraction_dir) - if model_tree is None: - logger.error(f"Failed to load model {modelId} from .mpta file for camera {camera_id}") - error_response = { - "type": "error", - "subscriptionIdentifier": subscriptionIdentifier, - "error": f"Failed to load model {modelId}" - } - await websocket.send_json(error_response) - continue - if camera_id not in models: - models[camera_id] = {} - models[camera_id][modelId] = model_tree - logger.info(f"Successfully loaded model {modelId} for camera {camera_id}") - logger.debug(f"Model extraction directory: {extraction_dir}") - if camera_id and (rtsp_url or snapshot_url): - with streams_lock: - # Determine camera URL for shared stream management - camera_url = snapshot_url if snapshot_url else rtsp_url - - if camera_id not in streams and len(streams) < max_streams: - # Check if we already have a stream for this camera URL - shared_stream = camera_streams.get(camera_url) - - if shared_stream: - # Reuse existing stream - logger.info(f"Reusing existing stream for camera URL: {camera_url}") - buffer = shared_stream["buffer"] - stop_event = shared_stream["stop_event"] - thread = shared_stream["thread"] - mode = shared_stream["mode"] - - # Increment reference count - shared_stream["ref_count"] = shared_stream.get("ref_count", 0) + 1 - else: - # Create new stream - buffer = queue.Queue(maxsize=1) - stop_event = threading.Event() - - if snapshot_url and snapshot_interval: - logger.info(f"Creating new snapshot stream for camera {camera_id}: {snapshot_url}") - thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event)) - thread.daemon = True - thread.start() - mode = "snapshot" - - # Store shared stream info - shared_stream = { - "buffer": buffer, - "thread": thread, - "stop_event": stop_event, - "mode": mode, - "url": snapshot_url, - "snapshot_interval": snapshot_interval, - "ref_count": 1 - } - camera_streams[camera_url] = shared_stream - - elif rtsp_url: - logger.info(f"Creating new RTSP stream for camera {camera_id}: {rtsp_url}") - cap = cv2.VideoCapture(rtsp_url) - if not cap.isOpened(): - logger.error(f"Failed to open RTSP stream for camera {camera_id}") - continue - thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event)) - 
thread.daemon = True - thread.start() - mode = "rtsp" - - # Store shared stream info - shared_stream = { - "buffer": buffer, - "thread": thread, - "stop_event": stop_event, - "mode": mode, - "url": rtsp_url, - "cap": cap, - "ref_count": 1 - } - camera_streams[camera_url] = shared_stream - else: - logger.error(f"No valid URL provided for camera {camera_id}") - continue - - # Create stream info for this subscription - stream_info = { - "buffer": buffer, - "thread": thread, - "stop_event": stop_event, - "modelId": modelId, - "modelName": modelName, - "subscriptionIdentifier": subscriptionIdentifier, - "cropX1": cropX1, - "cropY1": cropY1, - "cropX2": cropX2, - "cropY2": cropY2, - "mode": mode, - "camera_url": camera_url - } - - if mode == "snapshot": - stream_info["snapshot_url"] = snapshot_url - stream_info["snapshot_interval"] = snapshot_interval - elif mode == "rtsp": - stream_info["rtsp_url"] = rtsp_url - stream_info["cap"] = shared_stream["cap"] - - streams[camera_id] = stream_info - subscription_to_camera[camera_id] = camera_url - - elif camera_id and camera_id in streams: - # If already subscribed, unsubscribe first - logger.info(f"Resubscribing to camera {camera_id}") - # Note: Keep models in memory for reuse across subscriptions - elif msg_type == "unsubscribe": - payload = data.get("payload", {}) - subscriptionIdentifier = payload.get("subscriptionIdentifier") - camera_id = subscriptionIdentifier - with streams_lock: - if camera_id and camera_id in streams: - stream = streams.pop(camera_id) - camera_url = subscription_to_camera.pop(camera_id, None) - - if camera_url and camera_url in camera_streams: - shared_stream = camera_streams[camera_url] - shared_stream["ref_count"] -= 1 - - # If no more references, stop the shared stream - if shared_stream["ref_count"] <= 0: - logger.info(f"Stopping shared stream for camera URL: {camera_url}") - shared_stream["stop_event"].set() - shared_stream["thread"].join() - if "cap" in shared_stream: - shared_stream["cap"].release() - del camera_streams[camera_url] - else: - logger.info(f"Shared stream for {camera_url} still has {shared_stream['ref_count']} references") - - # Clean up cached frame - latest_frames.pop(camera_id, None) - logger.info(f"Unsubscribed from camera {camera_id}") - # Note: Keep models in memory for potential reuse - elif msg_type == "requestState": - cpu_usage = psutil.cpu_percent() - memory_usage = psutil.virtual_memory().percent - if torch.cuda.is_available(): - gpu_usage = torch.cuda.utilization() if hasattr(torch.cuda, 'utilization') else None - gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) - else: - gpu_usage = None - gpu_memory_usage = None - - camera_connections = [ - { - "subscriptionIdentifier": stream["subscriptionIdentifier"], - "modelId": stream["modelId"], - "modelName": stream["modelName"], - "online": True, - **{k: v for k, v in get_crop_coords(stream).items() if v is not None} - } - for camera_id, stream in streams.items() - ] - - state_report = { - "type": "stateReport", - "cpuUsage": cpu_usage, - "memoryUsage": memory_usage, - "gpuUsage": gpu_usage, - "gpuMemoryUsage": gpu_memory_usage, - "cameraConnections": camera_connections - } - await websocket.send_text(json.dumps(state_report)) - - elif msg_type == "setSessionId": - payload = data.get("payload", {}) - display_identifier = payload.get("displayIdentifier") - session_id = payload.get("sessionId") - - if display_identifier: - # Store session ID for this display - if session_id is None: - session_ids.pop(display_identifier, None) - 
logger.info(f"Cleared session ID for display {display_identifier}") - else: - session_ids[display_identifier] = session_id - logger.info(f"Set session ID {session_id} for display {display_identifier}") - - elif msg_type == "patchSession": - session_id = data.get("sessionId") - patch_data = data.get("data", {}) - - # For now, just acknowledge the patch - actual implementation depends on backend requirements - response = { - "type": "patchSessionResult", - "payload": { - "sessionId": session_id, - "success": True, - "message": "Session patch acknowledged" - } - } - await websocket.send_json(response) - logger.info(f"Acknowledged patch for session {session_id}") - - else: - logger.error(f"Unknown message type: {msg_type}") - except json.JSONDecodeError: - logger.error("Received invalid JSON message") - except (WebSocketDisconnect, ConnectionClosedError) as e: - logger.warning(f"WebSocket disconnected: {e}") - break - except Exception as e: - logger.error(f"Error handling message: {e}") - break - try: - await websocket.accept() - stream_task = asyncio.create_task(process_streams()) - heartbeat_task = asyncio.create_task(send_heartbeat()) - message_task = asyncio.create_task(on_message()) - await asyncio.gather(heartbeat_task, message_task) - except Exception as e: - logger.error(f"Error in detect websocket: {e}") - finally: - stream_task.cancel() - await stream_task - with streams_lock: - # Clean up shared camera streams - for camera_url, shared_stream in camera_streams.items(): - shared_stream["stop_event"].set() - shared_stream["thread"].join() - if "cap" in shared_stream: - shared_stream["cap"].release() - while not shared_stream["buffer"].empty(): - try: - shared_stream["buffer"].get_nowait() - except queue.Empty: - pass - logger.info(f"Released shared camera stream for {camera_url}") - - streams.clear() - camera_streams.clear() - subscription_to_camera.clear() - with models_lock: - models.clear() - latest_frames.clear() - session_ids.clear() - logger.info("WebSocket connection closed") diff --git a/config.json b/config.json index 0d061f9..311bbf4 100644 --- a/config.json +++ b/config.json @@ -1,9 +1,7 @@ { "poll_interval_ms": 100, - "max_streams": 20, + "max_streams": 5, "target_fps": 2, - "reconnect_interval_sec": 10, - "max_retries": -1, - "rtsp_buffer_size": 3, - "rtsp_tcp_transport": true + "reconnect_interval_sec": 5, + "max_retries": -1 } diff --git a/core/__init__.py b/core/__init__.py deleted file mode 100644 index e697cb2..0000000 --- a/core/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Core package for detector worker \ No newline at end of file diff --git a/core/communication/__init__.py b/core/communication/__init__.py deleted file mode 100644 index 73145a1..0000000 --- a/core/communication/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Communication module for WebSocket and HTTP handling \ No newline at end of file diff --git a/core/communication/messages.py b/core/communication/messages.py deleted file mode 100644 index 98cc9e5..0000000 --- a/core/communication/messages.py +++ /dev/null @@ -1,212 +0,0 @@ -""" -Message types, constants, and validation functions for WebSocket communication. 
-""" -import json -import logging -from typing import Dict, Any, Optional, Union -from .models import ( - IncomingMessage, OutgoingMessage, - SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage, - RequestStateMessage, PatchSessionResultMessage, - StateReportMessage, ImageDetectionMessage, PatchSessionMessage -) - -logger = logging.getLogger(__name__) - -# Message type constants -class MessageTypes: - """WebSocket message type constants.""" - - # Incoming from backend - SET_SUBSCRIPTION_LIST = "setSubscriptionList" - SET_SESSION_ID = "setSessionId" - SET_PROGRESSION_STAGE = "setProgressionStage" - REQUEST_STATE = "requestState" - PATCH_SESSION_RESULT = "patchSessionResult" - - # Outgoing to backend - STATE_REPORT = "stateReport" - IMAGE_DETECTION = "imageDetection" - PATCH_SESSION = "patchSession" - - -def parse_incoming_message(raw_message: str) -> Optional[IncomingMessage]: - """ - Parse incoming WebSocket message and validate against known types. - - Args: - raw_message: Raw JSON string from WebSocket - - Returns: - Parsed message object or None if invalid - """ - try: - data = json.loads(raw_message) - message_type = data.get("type") - - if not message_type: - logger.error("Message missing 'type' field") - return None - - # Route to appropriate message class - if message_type == MessageTypes.SET_SUBSCRIPTION_LIST: - return SetSubscriptionListMessage(**data) - elif message_type == MessageTypes.SET_SESSION_ID: - return SetSessionIdMessage(**data) - elif message_type == MessageTypes.SET_PROGRESSION_STAGE: - return SetProgressionStageMessage(**data) - elif message_type == MessageTypes.REQUEST_STATE: - return RequestStateMessage(**data) - elif message_type == MessageTypes.PATCH_SESSION_RESULT: - return PatchSessionResultMessage(**data) - else: - logger.warning(f"Unknown message type: {message_type}") - return None - - except json.JSONDecodeError as e: - logger.error(f"Failed to decode JSON message: {e}") - return None - except Exception as e: - logger.error(f"Failed to parse incoming message: {e}") - return None - - -def serialize_outgoing_message(message: OutgoingMessage) -> str: - """ - Serialize outgoing message to JSON string. - - Args: - message: Message object to serialize - - Returns: - JSON string representation - """ - try: - # For ImageDetectionMessage, we need to include None values for abandonment detection - from .models import ImageDetectionMessage - if isinstance(message, ImageDetectionMessage): - return message.model_dump_json(exclude_none=False) - else: - return message.model_dump_json(exclude_none=True) - except Exception as e: - logger.error(f"Failed to serialize outgoing message: {e}") - raise - - -def validate_subscription_identifier(identifier: str) -> bool: - """ - Validate subscription identifier format (displayId;cameraId). - - Args: - identifier: Subscription identifier to validate - - Returns: - True if valid format, False otherwise - """ - if not identifier or not isinstance(identifier, str): - return False - - parts = identifier.split(';') - if len(parts) != 2: - logger.error(f"Invalid subscription identifier format: {identifier}") - return False - - display_id, camera_id = parts - if not display_id or not camera_id: - logger.error(f"Empty display or camera ID in identifier: {identifier}") - return False - - return True - - -def extract_display_identifier(subscription_identifier: str) -> Optional[str]: - """ - Extract display identifier from subscription identifier. 
- - Args: - subscription_identifier: Full subscription identifier (displayId;cameraId) - - Returns: - Display identifier or None if invalid format - """ - if not validate_subscription_identifier(subscription_identifier): - return None - - return subscription_identifier.split(';')[0] - - -def create_state_report(cpu_usage: float, memory_usage: float, - gpu_usage: Optional[float] = None, - gpu_memory_usage: Optional[float] = None, - camera_connections: Optional[list] = None) -> StateReportMessage: - """ - Create a state report message with system metrics. - - Args: - cpu_usage: CPU usage percentage - memory_usage: Memory usage percentage - gpu_usage: GPU usage percentage (optional) - gpu_memory_usage: GPU memory usage in MB (optional) - camera_connections: List of active camera connections - - Returns: - StateReportMessage object - """ - return StateReportMessage( - cpuUsage=cpu_usage, - memoryUsage=memory_usage, - gpuUsage=gpu_usage, - gpuMemoryUsage=gpu_memory_usage, - cameraConnections=camera_connections or [] - ) - - -def create_image_detection(subscription_identifier: str, detection_data: Union[Dict[str, Any], None], - model_id: int, model_name: str) -> ImageDetectionMessage: - """ - Create an image detection message. - - Args: - subscription_identifier: Camera subscription identifier - detection_data: Detection results - Dict for data, {} for empty, None for abandonment - model_id: Model identifier - model_name: Model name - - Returns: - ImageDetectionMessage object - """ - from .models import DetectionData - from typing import Union - - # Handle three cases: - # 1. None = car abandonment (detection: null) - # 2. {} = empty detection (triggers session creation) - # 3. {...} = full detection data (updates session) - - data = DetectionData( - detection=detection_data, - modelId=model_id, - modelName=model_name - ) - - return ImageDetectionMessage( - subscriptionIdentifier=subscription_identifier, - data=data - ) - - -def create_patch_session(session_id: int, patch_data: Dict[str, Any]) -> PatchSessionMessage: - """ - Create a patch session message. - - Args: - session_id: Session ID to patch - patch_data: Partial session data to update - - Returns: - PatchSessionMessage object - """ - return PatchSessionMessage( - sessionId=session_id, - data=patch_data - ) \ No newline at end of file diff --git a/core/communication/models.py b/core/communication/models.py deleted file mode 100644 index 7214472..0000000 --- a/core/communication/models.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Message data structures for WebSocket communication. -Based on worker.md protocol specification. 
-""" -from typing import Dict, Any, List, Optional, Union, Literal -from pydantic import BaseModel, Field -from datetime import datetime - - -class SubscriptionObject(BaseModel): - """Individual camera subscription configuration.""" - subscriptionIdentifier: str = Field(..., description="Format: displayId;cameraId") - rtspUrl: Optional[str] = Field(None, description="RTSP stream URL") - snapshotUrl: Optional[str] = Field(None, description="HTTP snapshot URL") - snapshotInterval: Optional[int] = Field(None, description="Snapshot interval in milliseconds") - modelUrl: str = Field(..., description="Pre-signed URL to .mpta file") - modelId: int = Field(..., description="Unique model identifier") - modelName: str = Field(..., description="Human-readable model name") - cropX1: Optional[int] = Field(None, description="Crop region X1 coordinate") - cropY1: Optional[int] = Field(None, description="Crop region Y1 coordinate") - cropX2: Optional[int] = Field(None, description="Crop region X2 coordinate") - cropY2: Optional[int] = Field(None, description="Crop region Y2 coordinate") - - -class CameraConnection(BaseModel): - """Camera connection status for state reporting.""" - subscriptionIdentifier: str - modelId: int - modelName: str - online: bool - cropX1: Optional[int] = None - cropY1: Optional[int] = None - cropX2: Optional[int] = None - cropY2: Optional[int] = None - - -class DetectionData(BaseModel): - """ - Detection result data structure. - - Supports three cases: - 1. Empty detection: detection = {} (triggers session creation) - 2. Full detection: detection = {"carBrand": "Honda", ...} (updates session) - 3. Null detection: detection = None (car abandonment) - """ - model_config = { - "json_encoders": {type(None): lambda v: None}, - "arbitrary_types_allowed": True - } - - detection: Union[Dict[str, Any], None] = Field( - default_factory=dict, - description="Detection results: {} for empty, {...} for data, None/null for abandonment" - ) - modelId: int - modelName: str - - -# Incoming Messages from Backend to Worker - -class SetSubscriptionListMessage(BaseModel): - """Complete subscription list for declarative state management.""" - type: Literal["setSubscriptionList"] = "setSubscriptionList" - subscriptions: List[SubscriptionObject] - - -class SetSessionIdPayload(BaseModel): - """Session ID association payload.""" - displayIdentifier: str - sessionId: Optional[int] = None - - -class SetSessionIdMessage(BaseModel): - """Associate session ID with display.""" - type: Literal["setSessionId"] = "setSessionId" - payload: SetSessionIdPayload - - -class SetProgressionStagePayload(BaseModel): - """Progression stage payload.""" - displayIdentifier: str - progressionStage: Optional[str] = None - - -class SetProgressionStageMessage(BaseModel): - """Set progression stage for display.""" - type: Literal["setProgressionStage"] = "setProgressionStage" - payload: SetProgressionStagePayload - - -class RequestStateMessage(BaseModel): - """Request current worker state.""" - type: Literal["requestState"] = "requestState" - - -class PatchSessionResultPayload(BaseModel): - """Patch session result payload.""" - sessionId: int - success: bool - message: str - - -class PatchSessionResultMessage(BaseModel): - """Response to patch session request.""" - type: Literal["patchSessionResult"] = "patchSessionResult" - payload: PatchSessionResultPayload - - -# Outgoing Messages from Worker to Backend - -class StateReportMessage(BaseModel): - """Periodic heartbeat with system metrics.""" - type: Literal["stateReport"] = 
"stateReport" - cpuUsage: float - memoryUsage: float - gpuUsage: Optional[float] = None - gpuMemoryUsage: Optional[float] = None - cameraConnections: List[CameraConnection] - - -class ImageDetectionMessage(BaseModel): - """Detection event message.""" - type: Literal["imageDetection"] = "imageDetection" - subscriptionIdentifier: str - timestamp: str = Field(default_factory=lambda: datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ")) - data: DetectionData - - -class PatchSessionMessage(BaseModel): - """Request to modify session data.""" - type: Literal["patchSession"] = "patchSession" - sessionId: int - data: Dict[str, Any] = Field(..., description="Partial DisplayPersistentData structure") - - -# Union type for all incoming messages -IncomingMessage = Union[ - SetSubscriptionListMessage, - SetSessionIdMessage, - SetProgressionStageMessage, - RequestStateMessage, - PatchSessionResultMessage -] - -# Union type for all outgoing messages -OutgoingMessage = Union[ - StateReportMessage, - ImageDetectionMessage, - PatchSessionMessage -] \ No newline at end of file diff --git a/core/communication/state.py b/core/communication/state.py deleted file mode 100644 index 9016c07..0000000 --- a/core/communication/state.py +++ /dev/null @@ -1,234 +0,0 @@ -""" -Worker state management for system metrics and subscription tracking. -""" -import logging -import psutil -import threading -from typing import Dict, Set, Optional, List -from dataclasses import dataclass, field -from .models import CameraConnection, SubscriptionObject - -logger = logging.getLogger(__name__) - -# Try to import torch and pynvml for GPU monitoring -try: - import torch - TORCH_AVAILABLE = True -except ImportError: - TORCH_AVAILABLE = False - logger.warning("PyTorch not available, GPU metrics will not be collected") - -try: - import pynvml - PYNVML_AVAILABLE = True - pynvml.nvmlInit() - logger.info("NVIDIA ML Python (pynvml) initialized successfully") -except ImportError: - PYNVML_AVAILABLE = False - logger.debug("pynvml not available, falling back to PyTorch GPU monitoring") -except Exception as e: - PYNVML_AVAILABLE = False - logger.warning(f"Failed to initialize pynvml: {e}") - - -@dataclass -class WorkerState: - """Central state management for the detector worker.""" - - # Active subscriptions indexed by subscription identifier - subscriptions: Dict[str, SubscriptionObject] = field(default_factory=dict) - - # Session ID mapping: display_identifier -> session_id - session_ids: Dict[str, int] = field(default_factory=dict) - - # Progression stage mapping: display_identifier -> stage - progression_stages: Dict[str, str] = field(default_factory=dict) - - # Active camera connections for state reporting - camera_connections: List[CameraConnection] = field(default_factory=list) - - # Thread lock for state synchronization - _lock: threading.RLock = field(default_factory=threading.RLock) - - def set_subscriptions(self, new_subscriptions: List[SubscriptionObject]) -> None: - """ - Update active subscriptions with declarative list from backend. 
- - Args: - new_subscriptions: Complete list of desired subscriptions - """ - with self._lock: - # Convert to dict for easy lookup - new_sub_dict = {sub.subscriptionIdentifier: sub for sub in new_subscriptions} - - # Log changes for debugging - current_ids = set(self.subscriptions.keys()) - new_ids = set(new_sub_dict.keys()) - - added = new_ids - current_ids - removed = current_ids - new_ids - updated = current_ids & new_ids - - if added: - logger.info(f"[State Update] Adding subscriptions: {added}") - if removed: - logger.info(f"[State Update] Removing subscriptions: {removed}") - if updated: - logger.info(f"[State Update] Updating subscriptions: {updated}") - - # Replace entire subscription dict - self.subscriptions = new_sub_dict - - # Update camera connections for state reporting - self._update_camera_connections() - - def get_subscription(self, subscription_identifier: str) -> Optional[SubscriptionObject]: - """Get subscription by identifier.""" - with self._lock: - return self.subscriptions.get(subscription_identifier) - - def get_all_subscriptions(self) -> List[SubscriptionObject]: - """Get all active subscriptions.""" - with self._lock: - return list(self.subscriptions.values()) - - def set_session_id(self, display_identifier: str, session_id: Optional[int]) -> None: - """ - Set or clear session ID for a display. - - Args: - display_identifier: Display identifier - session_id: Session ID to set, or None to clear - """ - with self._lock: - if session_id is None: - self.session_ids.pop(display_identifier, None) - logger.info(f"[State Update] Cleared session ID for display {display_identifier}") - else: - self.session_ids[display_identifier] = session_id - logger.info(f"[State Update] Set session ID {session_id} for display {display_identifier}") - - def get_session_id(self, display_identifier: str) -> Optional[int]: - """Get session ID for display identifier.""" - with self._lock: - return self.session_ids.get(display_identifier) - - def get_session_id_for_subscription(self, subscription_identifier: str) -> Optional[int]: - """Get session ID for subscription by extracting display identifier.""" - from .messages import extract_display_identifier - - display_id = extract_display_identifier(subscription_identifier) - if display_id: - return self.get_session_id(display_id) - return None - - def set_progression_stage(self, display_identifier: str, stage: Optional[str]) -> None: - """ - Set or clear progression stage for a display. 
- - Args: - display_identifier: Display identifier - stage: Progression stage to set, or None to clear - """ - with self._lock: - if stage is None: - self.progression_stages.pop(display_identifier, None) - logger.info(f"[State Update] Cleared progression stage for display {display_identifier}") - else: - self.progression_stages[display_identifier] = stage - logger.info(f"[State Update] Set progression stage '{stage}' for display {display_identifier}") - - def get_progression_stage(self, display_identifier: str) -> Optional[str]: - """Get progression stage for display identifier.""" - with self._lock: - return self.progression_stages.get(display_identifier) - - def _update_camera_connections(self) -> None: - """Update camera connections list for state reporting.""" - connections = [] - - for sub in self.subscriptions.values(): - connection = CameraConnection( - subscriptionIdentifier=sub.subscriptionIdentifier, - modelId=sub.modelId, - modelName=sub.modelName, - online=True, # TODO: Add actual online status tracking - cropX1=sub.cropX1, - cropY1=sub.cropY1, - cropX2=sub.cropX2, - cropY2=sub.cropY2 - ) - connections.append(connection) - - self.camera_connections = connections - - def get_camera_connections(self) -> List[CameraConnection]: - """Get current camera connections for state reporting.""" - with self._lock: - return self.camera_connections.copy() - - -class SystemMetrics: - """System metrics collection for state reporting.""" - - @staticmethod - def get_cpu_usage() -> float: - """Get current CPU usage percentage.""" - try: - return psutil.cpu_percent(interval=0.1) - except Exception as e: - logger.error(f"Failed to get CPU usage: {e}") - return 0.0 - - @staticmethod - def get_memory_usage() -> float: - """Get current memory usage percentage.""" - try: - return psutil.virtual_memory().percent - except Exception as e: - logger.error(f"Failed to get memory usage: {e}") - return 0.0 - - @staticmethod - def get_gpu_usage() -> Optional[float]: - """Get current GPU usage percentage.""" - try: - # Prefer pynvml for accurate GPU utilization - if PYNVML_AVAILABLE: - handle = pynvml.nvmlDeviceGetHandleByIndex(0) # First GPU - utilization = pynvml.nvmlDeviceGetUtilizationRates(handle) - return float(utilization.gpu) - - # Fallback to PyTorch memory-based estimation - elif TORCH_AVAILABLE and torch.cuda.is_available(): - if hasattr(torch.cuda, 'utilization'): - return torch.cuda.utilization() - else: - # Estimate based on memory usage - allocated = torch.cuda.memory_allocated() - reserved = torch.cuda.memory_reserved() - if reserved > 0: - return (allocated / reserved) * 100 - - return None - except Exception as e: - logger.error(f"Failed to get GPU usage: {e}") - return None - - @staticmethod - def get_gpu_memory_usage() -> Optional[float]: - """Get current GPU memory usage in MB.""" - if not TORCH_AVAILABLE: - return None - - try: - if torch.cuda.is_available(): - return torch.cuda.memory_reserved() / (1024 ** 2) # Convert to MB - return None - except Exception as e: - logger.error(f"Failed to get GPU memory usage: {e}") - return None - - -# Global worker state instance -worker_state = WorkerState() \ No newline at end of file diff --git a/core/communication/websocket.py b/core/communication/websocket.py deleted file mode 100644 index e53096a..0000000 --- a/core/communication/websocket.py +++ /dev/null @@ -1,677 +0,0 @@ -""" -WebSocket message handling and protocol implementation. 
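Back in state.py, session IDs are keyed by the display half of the "displayId;cameraId" identifier. Against the pre-refactor tree, the round trip looked roughly like this (identifiers invented):

from core.communication.state import worker_state  # module removed in this change

sub_id = "display-001;cam-001"

worker_state.set_session_id("display-001", 1234)
assert worker_state.get_session_id_for_subscription(sub_id) == 1234

worker_state.set_session_id("display-001", None)   # clearing is just a pop
assert worker_state.get_session_id_for_subscription(sub_id) is None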
-""" -import asyncio -import json -import logging -import os -import cv2 -from datetime import datetime, timezone, timedelta -from pathlib import Path -from typing import Optional -from fastapi import WebSocket, WebSocketDisconnect -from websockets.exceptions import ConnectionClosedError - -from .messages import ( - parse_incoming_message, serialize_outgoing_message, - MessageTypes, create_state_report -) -from .models import ( - SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage, - RequestStateMessage, PatchSessionResultMessage -) -from .state import worker_state, SystemMetrics -from ..models import ModelManager -from ..streaming.manager import shared_stream_manager -from ..tracking.integration import TrackingPipelineIntegration - -logger = logging.getLogger(__name__) - -# Constants -HEARTBEAT_INTERVAL = 2.0 # seconds -WORKER_TIMEOUT_MS = 10000 - -# Global model manager instance -model_manager = ModelManager() - - -class WebSocketHandler: - """ - Handles WebSocket connection lifecycle and message processing. - """ - - def __init__(self, websocket: WebSocket): - self.websocket = websocket - self.connected = False - self._heartbeat_task: Optional[asyncio.Task] = None - self._message_task: Optional[asyncio.Task] = None - self._heartbeat_count = 0 - self._last_processed_models: set = set() # Cache of last processed model IDs - - async def handle_connection(self) -> None: - """ - Main connection handler that manages the WebSocket lifecycle. - Based on the original architecture from archive/app.py - """ - client_info = f"{self.websocket.client.host}:{self.websocket.client.port}" if self.websocket.client else "unknown" - logger.info(f"Starting WebSocket handler for {client_info}") - - stream_task = None - try: - logger.info(f"Accepting WebSocket connection from {client_info}") - await self.websocket.accept() - self.connected = True - logger.info(f"WebSocket connection accepted and established for {client_info}") - - # Send immediate heartbeat to show connection is alive - await self._send_immediate_heartbeat() - - # Start background tasks (matching original architecture) - stream_task = asyncio.create_task(self._process_streams()) - heartbeat_task = asyncio.create_task(self._send_heartbeat()) - message_task = asyncio.create_task(self._handle_messages()) - - logger.info(f"WebSocket background tasks started for {client_info} (stream + heartbeat + message handler)") - - # Wait for heartbeat and message tasks (stream runs independently) - await asyncio.gather(heartbeat_task, message_task) - - except Exception as e: - logger.error(f"Error in WebSocket connection for {client_info}: {e}", exc_info=True) - finally: - logger.info(f"Cleaning up connection for {client_info}") - # Cancel stream task - if stream_task and not stream_task.done(): - stream_task.cancel() - try: - await stream_task - except asyncio.CancelledError: - logger.debug(f"Stream task cancelled for {client_info}") - await self._cleanup() - - async def _send_immediate_heartbeat(self) -> None: - """Send immediate heartbeat on connection to show we're alive.""" - try: - cpu_usage = SystemMetrics.get_cpu_usage() - memory_usage = SystemMetrics.get_memory_usage() - gpu_usage = SystemMetrics.get_gpu_usage() - gpu_memory_usage = SystemMetrics.get_gpu_memory_usage() - camera_connections = worker_state.get_camera_connections() - - state_report = create_state_report( - cpu_usage=cpu_usage, - memory_usage=memory_usage, - gpu_usage=gpu_usage, - gpu_memory_usage=gpu_memory_usage, - camera_connections=camera_connections - ) - - 
await self._send_message(state_report) - logger.info(f"[TX → Backend] Initial stateReport: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, " - f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras") - - except Exception as e: - logger.error(f"Error sending immediate heartbeat: {e}") - - async def _send_heartbeat(self) -> None: - """Send periodic state reports as heartbeat.""" - while self.connected: - try: - # Collect system metrics - cpu_usage = SystemMetrics.get_cpu_usage() - memory_usage = SystemMetrics.get_memory_usage() - gpu_usage = SystemMetrics.get_gpu_usage() - gpu_memory_usage = SystemMetrics.get_gpu_memory_usage() - camera_connections = worker_state.get_camera_connections() - - # Create and send state report - state_report = create_state_report( - cpu_usage=cpu_usage, - memory_usage=memory_usage, - gpu_usage=gpu_usage, - gpu_memory_usage=gpu_memory_usage, - camera_connections=camera_connections - ) - - await self._send_message(state_report) - - # Only log full details every 10th heartbeat, otherwise just show a dot - self._heartbeat_count += 1 - if self._heartbeat_count % 10 == 0: - logger.info(f"[TX → Backend] Heartbeat #{self._heartbeat_count}: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, " - f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras") - else: - print(".", end="", flush=True) # Just show a dot to indicate heartbeat activity - - await asyncio.sleep(HEARTBEAT_INTERVAL) - - except Exception as e: - logger.error(f"Error sending heartbeat: {e}") - break - - async def _handle_messages(self) -> None: - """Handle incoming WebSocket messages.""" - while self.connected: - try: - raw_message = await self.websocket.receive_text() - logger.info(f"[RX ← Backend] {raw_message}") - - # Parse incoming message - message = parse_incoming_message(raw_message) - if not message: - logger.warning("Failed to parse incoming message") - continue - - # Route message to appropriate handler - await self._route_message(message) - - except (WebSocketDisconnect, ConnectionClosedError) as e: - logger.warning(f"WebSocket disconnected: {e}") - break - except json.JSONDecodeError: - logger.error("Received invalid JSON message") - except Exception as e: - logger.error(f"Error handling message: {e}") - break - - async def _route_message(self, message) -> None: - """Route parsed message to appropriate handler.""" - message_type = message.type - - try: - if message_type == MessageTypes.SET_SUBSCRIPTION_LIST: - await self._handle_set_subscription_list(message) - elif message_type == MessageTypes.SET_SESSION_ID: - await self._handle_set_session_id(message) - elif message_type == MessageTypes.SET_PROGRESSION_STAGE: - await self._handle_set_progression_stage(message) - elif message_type == MessageTypes.REQUEST_STATE: - await self._handle_request_state(message) - elif message_type == MessageTypes.PATCH_SESSION_RESULT: - await self._handle_patch_session_result(message) - else: - logger.warning(f"Unknown message type: {message_type}") - - except Exception as e: - logger.error(f"Error handling {message_type} message: {e}") - - async def _handle_set_subscription_list(self, message: SetSubscriptionListMessage) -> None: - """Handle setSubscriptionList message for declarative subscription management.""" - logger.info(f"[RX Processing] setSubscriptionList with {len(message.subscriptions)} subscriptions") - - # Update worker state with new subscriptions - worker_state.set_subscriptions(message.subscriptions) - - # Phase 2: Download and manage models - await 
self._ensure_models(message.subscriptions) - - # Phase 3 & 4: Integrate with streaming management and tracking - await self._update_stream_subscriptions(message.subscriptions) - - logger.info("Subscription list updated successfully") - - async def _ensure_models(self, subscriptions) -> None: - """Ensure all required models are downloaded and available.""" - # Extract unique model requirements - unique_models = {} - for subscription in subscriptions: - model_id = subscription.modelId - if model_id not in unique_models: - unique_models[model_id] = { - 'model_url': subscription.modelUrl, - 'model_name': subscription.modelName - } - - # Check if model set has changed to avoid redundant processing - current_model_ids = set(unique_models.keys()) - if current_model_ids == self._last_processed_models: - logger.debug(f"[Model Management] Model set unchanged {list(current_model_ids)}, skipping checks") - return - - logger.info(f"[Model Management] Processing {len(unique_models)} unique models: {list(unique_models.keys())}") - self._last_processed_models = current_model_ids - - # Check and download models concurrently - download_tasks = [] - for model_id, model_info in unique_models.items(): - task = asyncio.create_task( - self._ensure_single_model(model_id, model_info['model_url'], model_info['model_name']) - ) - download_tasks.append(task) - - # Wait for all downloads to complete - if download_tasks: - results = await asyncio.gather(*download_tasks, return_exceptions=True) - - # Log results - success_count = 0 - for i, result in enumerate(results): - model_id = list(unique_models.keys())[i] - if isinstance(result, Exception): - logger.error(f"[Model Management] Failed to ensure model {model_id}: {result}") - elif result: - success_count += 1 - logger.info(f"[Model Management] Model {model_id} ready for use") - else: - logger.error(f"[Model Management] Failed to ensure model {model_id}") - - logger.info(f"[Model Management] Successfully ensured {success_count}/{len(unique_models)} models") - - async def _update_stream_subscriptions(self, subscriptions) -> None: - """Update streaming subscriptions with tracking integration.""" - try: - # Convert subscriptions to the format expected by StreamManager - subscription_payloads = [] - for subscription in subscriptions: - payload = { - 'subscriptionIdentifier': subscription.subscriptionIdentifier, - 'rtspUrl': subscription.rtspUrl, - 'snapshotUrl': subscription.snapshotUrl, - 'snapshotInterval': subscription.snapshotInterval, - 'modelId': subscription.modelId, - 'modelUrl': subscription.modelUrl, - 'modelName': subscription.modelName - } - # Add crop coordinates if present - if hasattr(subscription, 'cropX1'): - payload.update({ - 'cropX1': subscription.cropX1, - 'cropY1': subscription.cropY1, - 'cropX2': subscription.cropX2, - 'cropY2': subscription.cropY2 - }) - subscription_payloads.append(payload) - - # Reconcile subscriptions with StreamManager - logger.info("[Streaming] Reconciling stream subscriptions with tracking") - reconcile_result = await self._reconcile_subscriptions_with_tracking(subscription_payloads) - - logger.info(f"[Streaming] Subscription reconciliation complete: " - f"added={reconcile_result.get('added', 0)}, " - f"removed={reconcile_result.get('removed', 0)}, " - f"failed={reconcile_result.get('failed', 0)}") - - except Exception as e: - logger.error(f"Error updating stream subscriptions: {e}", exc_info=True) - - async def _reconcile_subscriptions_with_tracking(self, target_subscriptions) -> dict: - """Reconcile subscriptions with 
tracking integration.""" - try: - # Create separate tracking integrations for each subscription (camera isolation) - tracking_integrations = {} - - for subscription_payload in target_subscriptions: - subscription_id = subscription_payload['subscriptionIdentifier'] - model_id = subscription_payload['modelId'] - - # Create separate tracking integration per subscription for camera isolation - # Get pipeline configuration for this model - pipeline_parser = model_manager.get_pipeline_config(model_id) - if pipeline_parser: - # Create tracking integration with message sender (separate instance per camera) - tracking_integration = TrackingPipelineIntegration( - pipeline_parser, model_manager, model_id, self._send_message - ) - - # Initialize tracking model - success = await tracking_integration.initialize_tracking_model() - if success: - tracking_integrations[subscription_id] = tracking_integration - logger.info(f"[Tracking] Created isolated tracking integration for subscription {subscription_id} (model {model_id})") - else: - logger.warning(f"[Tracking] Failed to initialize tracking for subscription {subscription_id} (model {model_id})") - else: - logger.warning(f"[Tracking] No pipeline config found for model {model_id} in subscription {subscription_id}") - - # Now reconcile with StreamManager, adding tracking integrations - current_subscription_ids = set() - for subscription_info in shared_stream_manager.get_all_subscriptions(): - current_subscription_ids.add(subscription_info.subscription_id) - - target_subscription_ids = {sub['subscriptionIdentifier'] for sub in target_subscriptions} - - # Find subscriptions to remove and add - to_remove = current_subscription_ids - target_subscription_ids - to_add = target_subscription_ids - current_subscription_ids - - # Remove old subscriptions - removed_count = 0 - for subscription_id in to_remove: - if shared_stream_manager.remove_subscription(subscription_id): - removed_count += 1 - logger.info(f"[Streaming] Removed subscription {subscription_id}") - - # Add new subscriptions with tracking - added_count = 0 - failed_count = 0 - for subscription_payload in target_subscriptions: - subscription_id = subscription_payload['subscriptionIdentifier'] - if subscription_id in to_add: - success = await self._add_subscription_with_tracking( - subscription_payload, tracking_integrations - ) - if success: - added_count += 1 - logger.info(f"[Streaming] Added subscription {subscription_id} with tracking") - else: - failed_count += 1 - logger.error(f"[Streaming] Failed to add subscription {subscription_id}") - - return { - 'removed': removed_count, - 'added': added_count, - 'failed': failed_count, - 'total_active': len(shared_stream_manager.get_all_subscriptions()) - } - - except Exception as e: - logger.error(f"Error in subscription reconciliation with tracking: {e}", exc_info=True) - return {'removed': 0, 'added': 0, 'failed': 0, 'total_active': 0} - - async def _add_subscription_with_tracking(self, payload, tracking_integrations) -> bool: - """Add a subscription with tracking integration.""" - try: - from ..streaming.manager import StreamConfig - - subscription_id = payload['subscriptionIdentifier'] - camera_id = subscription_id.split(';')[-1] - model_id = payload['modelId'] - - logger.info(f"[SUBSCRIPTION_MAPPING] subscription_id='{subscription_id}' → camera_id='{camera_id}'") - - # Get tracking integration for this subscription (camera-isolated) - tracking_integration = tracking_integrations.get(subscription_id) - - # Extract crop coordinates if present - 
crop_coords = None - if all(key in payload for key in ['cropX1', 'cropY1', 'cropX2', 'cropY2']): - crop_coords = ( - payload['cropX1'], - payload['cropY1'], - payload['cropX2'], - payload['cropY2'] - ) - - # Create stream configuration - stream_config = StreamConfig( - camera_id=camera_id, - rtsp_url=payload.get('rtspUrl'), - snapshot_url=payload.get('snapshotUrl'), - snapshot_interval=payload.get('snapshotInterval', 5000), - max_retries=3, - ) - - # Add subscription to StreamManager with tracking - success = shared_stream_manager.add_subscription( - subscription_id=subscription_id, - stream_config=stream_config, - crop_coords=crop_coords, - model_id=model_id, - model_url=payload.get('modelUrl'), - tracking_integration=tracking_integration - ) - - if success and tracking_integration: - logger.info(f"[Tracking] Subscription {subscription_id} configured with isolated tracking for model {model_id}") - - return success - - except Exception as e: - logger.error(f"Error adding subscription with tracking: {e}", exc_info=True) - return False - - async def _ensure_single_model(self, model_id: int, model_url: str, model_name: str) -> bool: - """Ensure a single model is downloaded and available.""" - try: - # Check if model is already available - if model_manager.is_model_downloaded(model_id): - logger.info(f"[Model Management] Model {model_id} ({model_name}) already available") - return True - - # Download and extract model in a thread pool to avoid blocking the event loop - logger.info(f"[Model Management] Downloading model {model_id} ({model_name}) from {model_url}") - - # Use asyncio.to_thread for CPU-bound operations (Python 3.9+) - # For compatibility, we'll use run_in_executor - loop = asyncio.get_event_loop() - model_path = await loop.run_in_executor( - None, - model_manager.ensure_model, - model_id, - model_url, - model_name - ) - - if model_path: - logger.info(f"[Model Management] Successfully prepared model {model_id} at {model_path}") - return True - else: - logger.error(f"[Model Management] Failed to prepare model {model_id}") - return False - - except Exception as e: - logger.error(f"[Model Management] Exception ensuring model {model_id}: {str(e)}", exc_info=True) - return False - - async def _save_snapshot(self, display_identifier: str, session_id: int) -> None: - """ - Save snapshot image to images folder after receiving sessionId. 
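The reconciliation above reduces to a set difference between the running subscription IDs and the target list sent by the backend; stripped of tracking setup and logging, the core is the following (the two callables are stand-ins for the StreamManager methods):

def reconcile(current_ids: set, target: list, add_subscription, remove_subscription) -> dict:
    """current_ids: running subscription IDs; target: payload dicts from setSubscriptionList."""
    target_ids = {sub["subscriptionIdentifier"] for sub in target}
    removed = sum(1 for sid in current_ids - target_ids if remove_subscription(sid))
    added = failed = 0
    for sub in target:
        sid = sub["subscriptionIdentifier"]
        if sid not in current_ids:               # only brand-new subscriptions are added
            if add_subscription(sid, sub):
                added += 1
            else:
                failed += 1
    return {"added": added, "removed": removed, "failed": failed}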
- - Args: - display_identifier: Display identifier to match with subscriptionIdentifier - session_id: Session ID to include in filename - """ - try: - # Find subscription that matches the displayIdentifier - matching_subscription = None - for subscription in worker_state.get_all_subscriptions(): - # Extract display ID from subscriptionIdentifier (format: displayId;cameraId) - from .messages import extract_display_identifier - sub_display_id = extract_display_identifier(subscription.subscriptionIdentifier) - if sub_display_id == display_identifier: - matching_subscription = subscription - break - - if not matching_subscription: - logger.error(f"[Snapshot Save] No subscription found for display {display_identifier}") - return - - if not matching_subscription.snapshotUrl: - logger.error(f"[Snapshot Save] No snapshotUrl found for display {display_identifier}") - return - - # Ensure images directory exists (relative path for Docker bind mount) - images_dir = Path("images") - images_dir.mkdir(exist_ok=True) - - # Generate filename with timestamp and session ID - timestamp = datetime.now(tz=timezone(timedelta(hours=7))).strftime("%Y%m%d_%H%M%S") - filename = f"{session_id}_{display_identifier}_{timestamp}.jpg" - filepath = images_dir / filename - - # Use existing HTTPSnapshotReader to fetch snapshot - logger.info(f"[Snapshot Save] Fetching snapshot from {matching_subscription.snapshotUrl}") - - # Run snapshot fetch in thread pool to avoid blocking async loop - loop = asyncio.get_event_loop() - frame = await loop.run_in_executor(None, self._fetch_snapshot_sync, matching_subscription.snapshotUrl) - - if frame is not None: - # Save the image using OpenCV - success = cv2.imwrite(str(filepath), frame) - if success: - logger.info(f"[Snapshot Save] Successfully saved snapshot to {filepath}") - else: - logger.error(f"[Snapshot Save] Failed to save image file {filepath}") - else: - logger.error(f"[Snapshot Save] Failed to fetch snapshot from {matching_subscription.snapshotUrl}") - - except Exception as e: - logger.error(f"[Snapshot Save] Error saving snapshot for display {display_identifier}: {e}", exc_info=True) - - def _fetch_snapshot_sync(self, snapshot_url: str): - """ - Synchronous snapshot fetching using existing HTTPSnapshotReader infrastructure. 
- - Args: - snapshot_url: URL to fetch snapshot from - - Returns: - np.ndarray or None: Fetched frame or None on error - """ - try: - from ..streaming.readers import HTTPSnapshotReader - - # Create temporary snapshot reader for single fetch - snapshot_reader = HTTPSnapshotReader( - camera_id="temp_snapshot", - snapshot_url=snapshot_url, - interval_ms=5000 # Not used for single fetch - ) - - # Use existing fetch_single_snapshot method - return snapshot_reader.fetch_single_snapshot() - - except Exception as e: - logger.error(f"Error in sync snapshot fetch: {e}") - return None - - async def _handle_set_session_id(self, message: SetSessionIdMessage) -> None: - """Handle setSessionId message.""" - display_identifier = message.payload.displayIdentifier - session_id = str(message.payload.sessionId) if message.payload.sessionId is not None else None - - logger.info(f"[RX Processing] setSessionId for display {display_identifier}: {session_id}") - - # Update worker state - worker_state.set_session_id(display_identifier, session_id) - - # Update tracking integrations with session ID - shared_stream_manager.set_session_id(display_identifier, session_id) - - async def _handle_set_progression_stage(self, message: SetProgressionStageMessage) -> None: - """Handle setProgressionStage message.""" - display_identifier = message.payload.displayIdentifier - stage = message.payload.progressionStage - - logger.info(f"[RX Processing] setProgressionStage for display {display_identifier}: {stage}") - - # Update worker state - worker_state.set_progression_stage(display_identifier, stage) - - # Update tracking integration for car abandonment detection - session_id = worker_state.get_session_id(display_identifier) - if session_id: - shared_stream_manager.set_progression_stage(session_id, stage) - - # Save snapshot image when progression stage is car_fueling - if stage == 'car_fueling' and session_id: - await self._save_snapshot(display_identifier, session_id) - - # If stage indicates session is cleared/finished, clear from tracking - if stage in ['finished', 'cleared', 'idle']: - # Get session ID for this display and clear it - if session_id: - shared_stream_manager.clear_session_id(session_id) - logger.info(f"[Tracking] Cleared session {session_id} due to progression stage: {stage}") - - async def _handle_request_state(self, message: RequestStateMessage) -> None: - """Handle requestState message by sending immediate state report.""" - logger.debug("[RX Processing] requestState - sending immediate state report") - - # Collect metrics and send state report - cpu_usage = SystemMetrics.get_cpu_usage() - memory_usage = SystemMetrics.get_memory_usage() - gpu_usage = SystemMetrics.get_gpu_usage() - gpu_memory_usage = SystemMetrics.get_gpu_memory_usage() - camera_connections = worker_state.get_camera_connections() - - state_report = create_state_report( - cpu_usage=cpu_usage, - memory_usage=memory_usage, - gpu_usage=gpu_usage, - gpu_memory_usage=gpu_memory_usage, - camera_connections=camera_connections - ) - - await self._send_message(state_report) - - async def _handle_patch_session_result(self, message: PatchSessionResultMessage) -> None: - """Handle patchSessionResult message.""" - payload = message.payload - logger.info(f"[RX Processing] patchSessionResult for session {payload.sessionId}: " - f"success={payload.success}, message='{payload.message}'") - - # TODO: Handle patch session result if needed - # For now, just log the response - - async def _send_message(self, message) -> None: - """Send message to backend via 
WebSocket.""" - if not self.connected: - logger.warning("Cannot send message: WebSocket not connected") - return - - try: - json_message = serialize_outgoing_message(message) - await self.websocket.send_text(json_message) - # Log non-heartbeat messages only (heartbeats are logged in their respective functions) - if not (hasattr(message, 'type') and message.type == 'stateReport'): - logger.info(f"[TX → Backend] {json_message}") - except Exception as e: - logger.error(f"Failed to send WebSocket message: {e}") - raise - - async def _process_streams(self) -> None: - """ - Stream processing task that handles frame processing and detection. - This is a placeholder for Phase 2 - currently just logs that it's running. - """ - logger.info("Stream processing task started") - try: - while self.connected: - # Get current subscriptions - subscriptions = worker_state.get_all_subscriptions() - - # TODO: Phase 2 - Add actual frame processing logic here - # This will include: - # - Frame reading from RTSP/HTTP streams - # - Model inference using loaded pipelines - # - Detection result sending via WebSocket - - # Sleep to prevent excessive CPU usage (similar to old poll_interval) - await asyncio.sleep(0.1) # 100ms polling interval - - except asyncio.CancelledError: - logger.info("Stream processing task cancelled") - except Exception as e: - logger.error(f"Error in stream processing: {e}", exc_info=True) - - async def _cleanup(self) -> None: - """Clean up resources when connection closes.""" - logger.info("Cleaning up WebSocket connection") - self.connected = False - - # Cancel background tasks - if self._heartbeat_task and not self._heartbeat_task.done(): - self._heartbeat_task.cancel() - if self._message_task and not self._message_task.done(): - self._message_task.cancel() - - # Clear worker state - worker_state.set_subscriptions([]) - worker_state.session_ids.clear() - worker_state.progression_stages.clear() - - logger.info("WebSocket connection cleanup completed") - - -# Factory function for FastAPI integration -async def websocket_endpoint(websocket: WebSocket) -> None: - """ - FastAPI WebSocket endpoint handler. - - Args: - websocket: FastAPI WebSocket connection - """ - handler = WebSocketHandler(websocket) - await handler.handle_connection() \ No newline at end of file diff --git a/core/detection/__init__.py b/core/detection/__init__.py deleted file mode 100644 index 2bcb75c..0000000 --- a/core/detection/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -Detection module for the Python Detector Worker. - -This module provides the main detection pipeline orchestration and parallel branch processing -for advanced computer vision detection systems. -""" -from .pipeline import DetectionPipeline -from .branches import BranchProcessor - -__all__ = ['DetectionPipeline', 'BranchProcessor'] \ No newline at end of file diff --git a/core/detection/branches.py b/core/detection/branches.py deleted file mode 100644 index 89881b2..0000000 --- a/core/detection/branches.py +++ /dev/null @@ -1,936 +0,0 @@ -""" -Parallel Branch Processing Module. -Handles concurrent execution of classification branches and result synchronization. -""" -import logging -import asyncio -import time -from typing import Dict, List, Optional, Any, Tuple -from concurrent.futures import ThreadPoolExecutor, as_completed -import numpy as np -import cv2 - -from ..models.inference import YOLOWrapper - -logger = logging.getLogger(__name__) - - -class BranchProcessor: - """ - Handles parallel processing of classification branches. 
- Manages branch synchronization and result collection. - """ - - def __init__(self, model_manager: Any, model_id: int): - """ - Initialize branch processor. - - Args: - model_manager: Model manager for loading models - model_id: The model ID to use for loading models - """ - self.model_manager = model_manager - self.model_id = model_id - - # Branch models cache - self.branch_models: Dict[str, YOLOWrapper] = {} - - # Dynamic field mapping: branch_id → output_field_name (e.g., {"car_brand_cls_v3": "brand"}) - self.branch_output_fields: Dict[str, str] = {} - - # Thread pool for parallel execution - self.executor = ThreadPoolExecutor(max_workers=4) - - # Storage managers (set during initialization) - self.redis_manager = None - # self.db_manager = None # Disabled - PostgreSQL operations moved to microservices - - # Branch execution timeout (seconds) - self.branch_timeout = 30.0 - - # Statistics - self.stats = { - 'branches_processed': 0, - 'parallel_executions': 0, - 'total_processing_time': 0.0, - 'models_loaded': 0, - 'branches_timed_out': 0, - 'branches_failed': 0 - } - - logger.info("BranchProcessor initialized") - - async def initialize(self, pipeline_config: Any, redis_manager: Any, db_manager: Any = None) -> bool: - """ - Initialize branch processor with pipeline configuration. - - Args: - pipeline_config: Pipeline configuration object - redis_manager: Redis manager instance - db_manager: Database manager instance (deprecated, not used) - - Returns: - True if successful, False otherwise - """ - try: - self.redis_manager = redis_manager - # self.db_manager = db_manager # Disabled - PostgreSQL operations moved to microservices - - # Parse field mappings from parallelActions to enable dynamic field extraction - self._parse_branch_output_fields(pipeline_config) - - # Pre-load branch models if they exist - branches = getattr(pipeline_config, 'branches', []) - if branches: - await self._preload_branch_models(branches) - - logger.info(f"BranchProcessor initialized with {len(self.branch_models)} models") - return True - - except Exception as e: - logger.error(f"Error initializing branch processor: {e}", exc_info=True) - return False - - async def _preload_branch_models(self, branches: List[Any]) -> None: - """ - Pre-load all branch models for faster execution. - - Args: - branches: List of branch configurations - """ - for branch in branches: - try: - await self._load_branch_model(branch) - - # Recursively load nested branches - nested_branches = getattr(branch, 'branches', []) - if nested_branches: - await self._preload_branch_models(nested_branches) - - except Exception as e: - logger.error(f"Error preloading branch model {getattr(branch, 'model_id', 'unknown')}: {e}") - - async def _load_branch_model(self, branch_config: Any) -> Optional[YOLOWrapper]: - """ - Load a branch model if not already loaded. 
- - Args: - branch_config: Branch configuration object - - Returns: - Loaded YOLO model wrapper or None - """ - try: - model_id = getattr(branch_config, 'model_id', None) - model_file = getattr(branch_config, 'model_file', None) - - if not model_id or not model_file: - logger.warning(f"Invalid branch config: model_id={model_id}, model_file={model_file}") - return None - - # Check if model is already loaded - if model_id in self.branch_models: - logger.debug(f"Branch model {model_id} already loaded") - return self.branch_models[model_id] - - # Load model - logger.info(f"Loading branch model: {model_id} ({model_file})") - - # Load model using the proper model ID - model = self.model_manager.get_yolo_model(self.model_id, model_file) - - if model: - self.branch_models[model_id] = model - self.stats['models_loaded'] += 1 - logger.info(f"Branch model {model_id} loaded successfully") - return model - else: - logger.error(f"Failed to load branch model {model_id}") - return None - - except Exception as e: - logger.error(f"Error loading branch model {getattr(branch_config, 'model_id', 'unknown')}: {e}") - return None - - def _parse_branch_output_fields(self, pipeline_config: Any) -> None: - """ - Parse parallelActions.fields to determine what output field each branch produces. - Creates dynamic mapping from branch_id to output field name. - - Example: - Input: parallelActions.fields = {"car_brand": "{car_brand_cls_v3.brand}"} - Output: self.branch_output_fields = {"car_brand_cls_v3": "brand"} - - Args: - pipeline_config: Pipeline configuration object - """ - try: - if not pipeline_config or not hasattr(pipeline_config, 'parallel_actions'): - logger.debug("[FIELD MAPPING] No parallelActions found in pipeline config") - return - - for action in pipeline_config.parallel_actions: - # Skip PostgreSQL actions - they are disabled - if action.type.value == 'postgresql_update_combined': - logger.debug(f"[FIELD MAPPING] Skipping PostgreSQL action (disabled)") - continue # Skip field parsing for disabled PostgreSQL operations - # fields = action.params.get('fields', {}) - # - # # Parse each field template to extract branch_id and field_name - # for db_field_name, template in fields.items(): - # # Template format: "{branch_id.field_name}" - # if template.startswith('{') and template.endswith('}'): - # var_name = template[1:-1] # Remove { } - # - # if '.' in var_name: - # branch_id, field_name = var_name.split('.', 1) - # - # # Store the mapping - # self.branch_output_fields[branch_id] = field_name - # - # logger.info(f"[FIELD MAPPING] Branch '{branch_id}' → outputs field '{field_name}'") - - logger.info(f"[FIELD MAPPING] Parsed {len(self.branch_output_fields)} branch output field mappings") - - except Exception as e: - logger.error(f"[FIELD MAPPING] Error parsing branch output fields: {e}", exc_info=True) - - async def execute_branches(self, - frame: np.ndarray, - branches: List[Any], - detected_regions: Dict[str, Any], - detection_context: Dict[str, Any]) -> Dict[str, Any]: - """ - Execute all branches in parallel and collect results. 
- - Args: - frame: Input frame - branches: List of branch configurations - detected_regions: Dictionary of detected regions from main detection - detection_context: Detection context data - - Returns: - Dictionary with branch execution results - """ - start_time = time.time() - branch_results = {} - - try: - # Separate parallel and sequential branches - parallel_branches = [] - sequential_branches = [] - - for branch in branches: - if getattr(branch, 'parallel', False): - parallel_branches.append(branch) - else: - sequential_branches.append(branch) - - # Execute parallel branches concurrently - if parallel_branches: - logger.info(f"Executing {len(parallel_branches)} branches in parallel") - parallel_results = await self._execute_parallel_branches( - frame, parallel_branches, detected_regions, detection_context - ) - branch_results.update(parallel_results) - self.stats['parallel_executions'] += 1 - - # Execute sequential branches one by one - if sequential_branches: - logger.info(f"Executing {len(sequential_branches)} branches sequentially") - sequential_results = await self._execute_sequential_branches( - frame, sequential_branches, detected_regions, detection_context - ) - branch_results.update(sequential_results) - - # Update statistics - self.stats['branches_processed'] += len(branches) - processing_time = time.time() - start_time - self.stats['total_processing_time'] += processing_time - - logger.info(f"Branch execution completed in {processing_time:.3f}s with {len(branch_results)} results") - - except Exception as e: - logger.error(f"Error in branch execution: {e}", exc_info=True) - - return branch_results - - async def _execute_parallel_branches(self, - frame: np.ndarray, - branches: List[Any], - detected_regions: Dict[str, Any], - detection_context: Dict[str, Any]) -> Dict[str, Any]: - """ - Execute branches in parallel using ThreadPoolExecutor. 
- - Args: - frame: Input frame - branches: List of parallel branch configurations - detected_regions: Dictionary of detected regions - detection_context: Detection context data - - Returns: - Dictionary with parallel branch results - """ - results = {} - - # Submit all branches for parallel execution - future_to_branch = {} - - for branch in branches: - branch_id = getattr(branch, 'model_id', 'unknown') - logger.info(f"[PARALLEL SUBMIT] {branch_id}: Submitting branch to thread pool") - - future = self.executor.submit( - self._execute_single_branch_sync, - frame, branch, detected_regions, detection_context - ) - future_to_branch[future] = branch - - # Collect results as they complete with timeout - try: - for future in as_completed(future_to_branch, timeout=self.branch_timeout): - branch = future_to_branch[future] - branch_id = getattr(branch, 'model_id', 'unknown') - - try: - # Get result with timeout to prevent indefinite hanging - result = future.result(timeout=self.branch_timeout) - results[branch_id] = result - logger.info(f"[PARALLEL COMPLETE] {branch_id}: Branch completed successfully") - except TimeoutError: - logger.error(f"[TIMEOUT] Branch {branch_id} exceeded timeout of {self.branch_timeout}s") - self.stats['branches_timed_out'] += 1 - results[branch_id] = { - 'status': 'timeout', - 'message': f'Branch execution timeout after {self.branch_timeout}s', - 'processing_time': self.branch_timeout - } - except Exception as e: - logger.error(f"[ERROR] Error in parallel branch {branch_id}: {e}", exc_info=True) - self.stats['branches_failed'] += 1 - results[branch_id] = { - 'status': 'error', - 'message': str(e), - 'processing_time': 0.0 - } - except TimeoutError: - # as_completed iterator timed out - mark remaining futures as timed out - logger.error(f"[TIMEOUT] Branch execution timeout after {self.branch_timeout}s - some branches did not complete") - for future, branch in future_to_branch.items(): - branch_id = getattr(branch, 'model_id', 'unknown') - if branch_id not in results: - logger.error(f"[TIMEOUT] Branch {branch_id} did not complete within timeout") - self.stats['branches_timed_out'] += 1 - results[branch_id] = { - 'status': 'timeout', - 'message': f'Branch did not complete within {self.branch_timeout}s timeout', - 'processing_time': self.branch_timeout - } - - # Flatten nested branch results to top level for database access - flattened_results = {} - for branch_id, branch_result in results.items(): - # Add the branch result itself - flattened_results[branch_id] = branch_result - - # If this branch has nested branches, add them to the top level too - if isinstance(branch_result, dict) and 'nested_branches' in branch_result: - nested_branches = branch_result['nested_branches'] - for nested_branch_id, nested_result in nested_branches.items(): - flattened_results[nested_branch_id] = nested_result - logger.info(f"[FLATTEN] Added nested branch {nested_branch_id} to top-level results") - - # Log summary of branch execution results - succeeded = [bid for bid, res in results.items() if res.get('status') == 'success'] - failed = [bid for bid, res in results.items() if res.get('status') == 'error'] - timed_out = [bid for bid, res in results.items() if res.get('status') == 'timeout'] - skipped = [bid for bid, res in results.items() if res.get('status') == 'skipped'] - - summary_parts = [] - if succeeded: - summary_parts.append(f"{len(succeeded)} succeeded: {', '.join(succeeded)}") - if failed: - summary_parts.append(f"{len(failed)} FAILED: {', '.join(failed)}") - if timed_out: - 
summary_parts.append(f"{len(timed_out)} TIMED OUT: {', '.join(timed_out)}") - if skipped: - summary_parts.append(f"{len(skipped)} skipped: {', '.join(skipped)}") - - logger.info(f"[PARALLEL SUMMARY] Branch execution completed: {' | '.join(summary_parts) if summary_parts else 'no branches'}") - - return flattened_results - - async def _execute_sequential_branches(self, - frame: np.ndarray, - branches: List[Any], - detected_regions: Dict[str, Any], - detection_context: Dict[str, Any]) -> Dict[str, Any]: - """ - Execute branches sequentially. - - Args: - frame: Input frame - branches: List of sequential branch configurations - detected_regions: Dictionary of detected regions - detection_context: Detection context data - - Returns: - Dictionary with sequential branch results - """ - results = {} - - for branch in branches: - branch_id = getattr(branch, 'model_id', 'unknown') - - try: - result = await asyncio.get_event_loop().run_in_executor( - self.executor, - self._execute_single_branch_sync, - frame, branch, detected_regions, detection_context - ) - results[branch_id] = result - logger.debug(f"Sequential branch {branch_id} completed successfully") - except Exception as e: - logger.error(f"Error in sequential branch {branch_id}: {e}") - results[branch_id] = { - 'status': 'error', - 'message': str(e), - 'processing_time': 0.0 - } - - # Flatten nested branch results to top level for database access - flattened_results = {} - for branch_id, branch_result in results.items(): - # Add the branch result itself - flattened_results[branch_id] = branch_result - - # If this branch has nested branches, add them to the top level too - if isinstance(branch_result, dict) and 'nested_branches' in branch_result: - nested_branches = branch_result['nested_branches'] - for nested_branch_id, nested_result in nested_branches.items(): - flattened_results[nested_branch_id] = nested_result - logger.info(f"[FLATTEN] Added nested branch {nested_branch_id} to top-level results") - - return flattened_results - - def _execute_single_branch_sync(self, - frame: np.ndarray, - branch_config: Any, - detected_regions: Dict[str, Any], - detection_context: Dict[str, Any]) -> Dict[str, Any]: - """ - Synchronous execution of a single branch (for ThreadPoolExecutor). 
- - Args: - frame: Input frame - branch_config: Branch configuration object - detected_regions: Dictionary of detected regions - detection_context: Detection context data - - Returns: - Dictionary with branch execution result - """ - start_time = time.time() - branch_id = getattr(branch_config, 'model_id', 'unknown') - - logger.info(f"[BRANCH START] {branch_id}: Starting branch execution") - logger.debug(f"[BRANCH CONFIG] {branch_id}: crop={getattr(branch_config, 'crop', False)}, " - f"trigger_classes={getattr(branch_config, 'trigger_classes', [])}, " - f"min_confidence={getattr(branch_config, 'min_confidence', 0.6)}") - - # Check if branch should execute based on triggerClasses (execution conditions) - trigger_classes = getattr(branch_config, 'trigger_classes', []) - logger.info(f"[DETECTED REGIONS] {branch_id}: Available parent detections: {list(detected_regions.keys())}") - for region_name, region_data in detected_regions.items(): - # Handle both list (new) and single dict (backward compat) - if isinstance(region_data, list): - for i, region in enumerate(region_data): - logger.debug(f"[REGION DATA] {branch_id}: '{region_name}[{i}]' -> bbox={region.get('bbox')}, conf={region.get('confidence')}") - else: - logger.debug(f"[REGION DATA] {branch_id}: '{region_name}' -> bbox={region_data.get('bbox')}, conf={region_data.get('confidence')}") - - if trigger_classes: - # Check if any parent detection matches our trigger classes (case-insensitive) - should_execute = False - for trigger_class in trigger_classes: - # Case-insensitive comparison for robustness - if trigger_class.lower() in [k.lower() for k in detected_regions.keys()]: - should_execute = True - logger.info(f"[TRIGGER CHECK] {branch_id}: Found '{trigger_class}' in parent detections - branch will execute") - break - - if not should_execute: - logger.warning(f"[TRIGGER CHECK] {branch_id}: None of trigger classes {trigger_classes} found in parent detections {list(detected_regions.keys())} - skipping branch") - return { - 'status': 'skipped', - 'branch_id': branch_id, - 'message': f'No trigger classes {trigger_classes} found in parent detections', - 'processing_time': time.time() - start_time - } - - result = { - 'status': 'success', - 'branch_id': branch_id, - 'result': {}, - 'processing_time': 0.0, - 'timestamp': time.time() - } - - try: - # Get or load branch model - if branch_id not in self.branch_models: - logger.warning(f"Branch model {branch_id} not preloaded, loading now...") - # This should be rare since models are preloaded - return { - 'status': 'error', - 'message': f'Branch model {branch_id} not available', - 'processing_time': time.time() - start_time - } - - model = self.branch_models[branch_id] - - # Get configuration values first - min_confidence = getattr(branch_config, 'min_confidence', 0.6) - - # Prepare input frame for this branch - input_frame = frame - - # Handle cropping if required - use biggest bbox that passes min_confidence - if getattr(branch_config, 'crop', False): - crop_classes = getattr(branch_config, 'crop_class', []) - if isinstance(crop_classes, str): - crop_classes = [crop_classes] - - # Find the biggest bbox that passes min_confidence threshold - best_region = None - best_class = None - best_area = 0.0 - - for crop_class in crop_classes: - if crop_class in detected_regions: - regions = detected_regions[crop_class] - - # Handle both list (new) and single dict (backward compat) - if not isinstance(regions, list): - regions = [regions] - - # Find largest bbox from all detections of this class - for region 
in regions: - confidence = region.get('confidence', 0.0) - bbox = region['bbox'] - area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) # width * height - - # Choose biggest bbox among all available detections - if area > best_area: - best_region = region - best_class = crop_class - best_area = area - logger.debug(f"[CROP] Selected larger bbox for '{crop_class}': area={area:.0f}px², conf={confidence:.3f}") - - if best_region: - bbox = best_region['bbox'] - x1, y1, x2, y2 = [int(coord) for coord in bbox] - cropped = frame[y1:y2, x1:x2] - if cropped.size > 0: - input_frame = cropped - confidence = best_region.get('confidence', 0.0) - logger.info(f"[CROP SUCCESS] {branch_id}: cropped '{best_class}' region (conf={confidence:.3f}, area={int(best_area)}) -> shape={cropped.shape}") - else: - logger.warning(f"Branch {branch_id}: empty crop, using full frame") - else: - logger.warning(f"Branch {branch_id}: no valid crop regions found (min_conf={min_confidence})") - - logger.info(f"[INFERENCE START] {branch_id}: Running inference on {'cropped' if input_frame is not frame else 'full'} frame " - f"({input_frame.shape[1]}x{input_frame.shape[0]}) with confidence={min_confidence}") - - # Use .predict() method for both detection and classification models - inference_start = time.time() - try: - detection_results = model.model.predict(input_frame, conf=min_confidence, verbose=False) - inference_time = time.time() - inference_start - logger.info(f"[INFERENCE DONE] {branch_id}: Predict completed in {inference_time:.3f}s using .predict() method") - except Exception as inference_error: - inference_time = time.time() - inference_start - logger.error(f"[INFERENCE ERROR] {branch_id}: Model inference failed after {inference_time:.3f}s: {inference_error}", exc_info=True) - return { - 'status': 'error', - 'branch_id': branch_id, - 'message': f'Model inference failed: {str(inference_error)}', - 'processing_time': time.time() - start_time - } - - # Initialize branch_detections outside the conditional - branch_detections = [] - - # Process results using clean, unified logic - if detection_results and len(detection_results) > 0: - result_obj = detection_results[0] - - # Handle detection models (have .boxes attribute) - if hasattr(result_obj, 'boxes') and result_obj.boxes is not None: - logger.info(f"[RAW DETECTIONS] {branch_id}: Found {len(result_obj.boxes)} raw detections") - - for i, box in enumerate(result_obj.boxes): - class_id = int(box.cls[0]) - confidence = float(box.conf[0]) - bbox = box.xyxy[0].cpu().numpy().tolist() # [x1, y1, x2, y2] - class_name = model.model.names[class_id] - - logger.debug(f"[RAW DETECTION {i+1}] {branch_id}: '{class_name}', conf={confidence:.3f}") - - # All detections are included - no filtering by trigger_classes here - branch_detections.append({ - 'class_name': class_name, - 'confidence': confidence, - 'bbox': bbox - }) - - # Handle classification models (have .probs attribute) - elif hasattr(result_obj, 'probs') and result_obj.probs is not None: - logger.info(f"[RAW CLASSIFICATION] {branch_id}: Processing classification results") - - probs = result_obj.probs - top_indices = probs.top5 # Get top 5 predictions - top_conf = probs.top5conf.cpu().numpy() - - # For classification: take only TOP-1 prediction (not all top-5) - # This prevents empty results when all top-5 predictions are below threshold - if len(top_indices) > 0 and len(top_conf) > 0: - top_idx = top_indices[0] - top_confidence = float(top_conf[0]) - - # Apply minConfidence threshold to top-1 only - if top_confidence >= 
min_confidence: - class_name = model.model.names[int(top_idx)] - logger.info(f"[CLASSIFICATION TOP-1] {branch_id}: '{class_name}', conf={top_confidence:.3f}") - - # For classification, use full input frame dimensions as bbox - branch_detections.append({ - 'class_name': class_name, - 'confidence': top_confidence, - 'bbox': [0, 0, input_frame.shape[1], input_frame.shape[0]] - }) - else: - logger.warning(f"[CLASSIFICATION FILTERED] {branch_id}: Top prediction conf={top_confidence:.3f} < threshold={min_confidence}") - else: - logger.warning(f"[UNKNOWN MODEL] {branch_id}: Model results have no .boxes or .probs") - - result['result'] = { - 'detections': branch_detections, - 'detection_count': len(branch_detections) - } - - logger.info(f"[FINAL RESULTS] {branch_id}: {len(branch_detections)} detections processed") - - # Determine output field name from dynamic mapping (parsed from parallelActions.fields) - output_field = self.branch_output_fields.get(branch_id) - - # Always initialize the field (even if None) to ensure it exists for database update - if output_field: - result['result'][output_field] = None - logger.debug(f"[FIELD INIT] {branch_id}: Initialized field '{output_field}' = None") - - # Extract best detection if available - if branch_detections: - best_detection = max(branch_detections, key=lambda x: x['confidence']) - logger.info(f"[BEST DETECTION] {branch_id}: '{best_detection['class_name']}' with confidence {best_detection['confidence']:.3f}") - - # Set the output field value using dynamic mapping - if output_field: - result['result'][output_field] = best_detection['class_name'] - logger.info(f"[FIELD SET] {branch_id}: Set field '{output_field}' = '{best_detection['class_name']}'") - else: - logger.warning(f"[NO MAPPING] {branch_id}: No output field defined in parallelActions.fields") - else: - logger.warning(f"[NO RESULTS] {branch_id}: No detections found, field '{output_field}' remains None") - - # Execute branch actions if this branch found valid detections - actions_executed = [] - branch_actions = getattr(branch_config, 'actions', []) - if branch_actions and branch_detections: - logger.info(f"[BRANCH ACTIONS] {branch_id}: Executing {len(branch_actions)} actions") - - # Create detected_regions from THIS branch's detections for actions - branch_detected_regions = {} - for detection in branch_detections: - branch_detected_regions[detection['class_name']] = { - 'bbox': detection['bbox'], - 'confidence': detection['confidence'] - } - - for action in branch_actions: - try: - action_type = action.type.value # Access the enum value - logger.info(f"[ACTION EXECUTE] {branch_id}: Executing action '{action_type}'") - - if action_type == 'redis_save_image': - action_result = self._execute_redis_save_image_sync( - action, input_frame, branch_detected_regions, detection_context - ) - elif action_type == 'redis_publish': - action_result = self._execute_redis_publish_sync( - action, detection_context - ) - else: - logger.warning(f"[ACTION UNKNOWN] {branch_id}: Unknown action type '{action_type}'") - action_result = {'status': 'error', 'message': f'Unknown action type: {action_type}'} - - actions_executed.append({ - 'action_type': action_type, - 'result': action_result - }) - - logger.info(f"[ACTION COMPLETE] {branch_id}: Action '{action_type}' result: {action_result.get('status')}") - - except Exception as e: - action_type = getattr(action, 'type', None) - if action_type: - action_type = action_type.value if hasattr(action_type, 'value') else str(action_type) - logger.error(f"[ACTION ERROR] 
{branch_id}: Error executing action '{action_type}': {e}", exc_info=True) - actions_executed.append({ - 'action_type': action_type, - 'result': {'status': 'error', 'message': str(e)} - }) - - # Add actions executed to result - if actions_executed: - result['actions_executed'] = actions_executed - - # Handle nested branches ONLY if parent found valid detections - nested_branches = getattr(branch_config, 'branches', []) - if nested_branches: - # Check if parent branch found any valid detections - if not branch_detections: - logger.warning(f"[BRANCH SKIP] {branch_id}: Skipping {len(nested_branches)} nested branches - parent found no valid detections") - else: - logger.debug(f"Branch {branch_id}: executing {len(nested_branches)} nested branches") - - # Create detected_regions from THIS branch's detections for nested branches - # Nested branches should see their immediate parent's detections, not the root pipeline - nested_detected_regions = {} - for detection in branch_detections: - nested_detected_regions[detection['class_name']] = { - 'bbox': detection['bbox'], - 'confidence': detection['confidence'] - } - - logger.info(f"[NESTED REGIONS] {branch_id}: Passing {list(nested_detected_regions.keys())} to nested branches") - - # Note: For simplicity, nested branches are executed sequentially in this sync method - # In a full async implementation, these could also be parallelized - nested_results = {} - for nested_branch in nested_branches: - nested_result = self._execute_single_branch_sync( - input_frame, nested_branch, nested_detected_regions, detection_context - ) - nested_branch_id = getattr(nested_branch, 'model_id', 'unknown') - nested_results[nested_branch_id] = nested_result - - result['nested_branches'] = nested_results - - except Exception as e: - logger.error(f"[BRANCH ERROR] {branch_id}: Error in execution: {e}", exc_info=True) - result['status'] = 'error' - result['message'] = str(e) - - result['processing_time'] = time.time() - start_time - - # Summary log - logger.info(f"[BRANCH COMPLETE] {branch_id}: status={result['status']}, " - f"processing_time={result['processing_time']:.3f}s, " - f"result_keys={list(result['result'].keys()) if result['result'] else 'none'}") - - return result - - def _execute_redis_save_image_sync(self, - action: Dict, - frame: np.ndarray, - detected_regions: Dict[str, Any], - context: Dict[str, Any]) -> Dict[str, Any]: - """Execute redis_save_image action synchronously.""" - if not self.redis_manager: - return {'status': 'error', 'message': 'Redis not available'} - - try: - # Get image to save (cropped or full frame) - image_to_save = frame - region_name = action.params.get('region') - - bbox = None - if region_name and region_name in detected_regions: - # Crop the specified region - # Handle both list (new) and single dict (backward compat) - regions = detected_regions[region_name] - if isinstance(regions, list): - # Multiple detections - select largest bbox - if regions: - best_region = max(regions, key=lambda r: (r['bbox'][2] - r['bbox'][0]) * (r['bbox'][3] - r['bbox'][1])) - bbox = best_region['bbox'] - else: - bbox = regions['bbox'] - elif region_name and region_name.lower() == 'frontal' and 'front_rear' in detected_regions: - # Special case: "frontal" region maps to "front_rear" detection - # Handle both list (new) and single dict (backward compat) - regions = detected_regions['front_rear'] - if isinstance(regions, list): - # Multiple detections - select largest bbox - if regions: - best_region = max(regions, key=lambda r: (r['bbox'][2] - r['bbox'][0]) 
* (r['bbox'][3] - r['bbox'][1])) - bbox = best_region['bbox'] - else: - bbox = regions['bbox'] - - if bbox is not None: - x1, y1, x2, y2 = [int(coord) for coord in bbox] - cropped = frame[y1:y2, x1:x2] - if cropped.size > 0: - image_to_save = cropped - logger.debug(f"Cropped region '{region_name}' for redis_save_image") - else: - logger.warning(f"Empty crop for region '{region_name}', using full frame") - - # Format key with context - key = action.params['key'].format(**context) - - # Convert image to bytes - import cv2 - image_format = action.params.get('format', 'jpeg') - quality = action.params.get('quality', 90) - - if image_format.lower() == 'jpeg': - encode_param = [cv2.IMWRITE_JPEG_QUALITY, quality] - _, image_bytes = cv2.imencode('.jpg', image_to_save, encode_param) - else: - _, image_bytes = cv2.imencode('.png', image_to_save) - - # Save to Redis synchronously using a sync Redis client - try: - import redis - import cv2 - - # Create a synchronous Redis client with same connection details - sync_redis = redis.Redis( - host=self.redis_manager.host, - port=self.redis_manager.port, - password=self.redis_manager.password, - db=self.redis_manager.db, - decode_responses=False, # We're storing binary data - socket_timeout=self.redis_manager.socket_timeout, - socket_connect_timeout=self.redis_manager.socket_connect_timeout - ) - - # Encode the image - if image_format.lower() == 'jpeg': - encode_param = [cv2.IMWRITE_JPEG_QUALITY, quality] - success, encoded_image = cv2.imencode('.jpg', image_to_save, encode_param) - else: - success, encoded_image = cv2.imencode('.png', image_to_save) - - if not success: - return {'status': 'error', 'message': 'Failed to encode image'} - - # Save to Redis with expiration - expire_seconds = action.params.get('expire_seconds', 600) - result = sync_redis.setex(key, expire_seconds, encoded_image.tobytes()) - - sync_redis.close() # Clean up connection - - if result: - # Add image_key to context for subsequent actions - context['image_key'] = key - return {'status': 'success', 'key': key} - else: - return {'status': 'error', 'message': 'Failed to save image to Redis'} - - except Exception as redis_error: - logger.error(f"Error calling Redis from sync context: {redis_error}") - return {'status': 'error', 'message': f'Redis operation failed: {redis_error}'} - - except Exception as e: - logger.error(f"Error in redis_save_image action: {e}", exc_info=True) - return {'status': 'error', 'message': str(e)} - - def _execute_redis_publish_sync(self, action: Dict, context: Dict[str, Any]) -> Dict[str, Any]: - """Execute redis_publish action synchronously.""" - if not self.redis_manager: - return {'status': 'error', 'message': 'Redis not available'} - - try: - channel = action.params['channel'] - message_template = action.params['message'] - - # Debug the message template - logger.debug(f"Message template: {repr(message_template)}") - logger.debug(f"Context keys: {list(context.keys())}") - - # Format message with context - handle JSON string formatting carefully - # The message template contains JSON which causes issues with .format() - # Use string replacement instead of format to avoid JSON brace conflicts - try: - # Ensure image_key is available for message formatting - if 'image_key' not in context: - context['image_key'] = '' # Default empty value if redis_save_image failed - - # Use string replacement to avoid JSON formatting issues - message = message_template - for key, value in context.items(): - placeholder = '{' + key + '}' - message = message.replace(placeholder, 
str(value)) - - logger.debug(f"Formatted message using replacement: {message}") - except Exception as e: - logger.error(f"Message formatting failed: {e}") - logger.error(f"Template: {repr(message_template)}") - logger.error(f"Context: {context}") - return {'status': 'error', 'message': f'Message formatting failed: {e}'} - - # Publish message synchronously using a sync Redis client - try: - import redis - - # Create a synchronous Redis client with same connection details - sync_redis = redis.Redis( - host=self.redis_manager.host, - port=self.redis_manager.port, - password=self.redis_manager.password, - db=self.redis_manager.db, - decode_responses=True, # For publishing text messages - socket_timeout=self.redis_manager.socket_timeout, - socket_connect_timeout=self.redis_manager.socket_connect_timeout - ) - - # Publish message - result = sync_redis.publish(channel, message) - sync_redis.close() # Clean up connection - - if result >= 0: # Redis publish returns number of subscribers - return {'status': 'success', 'subscribers': result, 'channel': channel} - else: - return {'status': 'error', 'message': 'Failed to publish message to Redis'} - - except Exception as redis_error: - logger.error(f"Error calling Redis from sync context: {redis_error}") - return {'status': 'error', 'message': f'Redis operation failed: {redis_error}'} - - except Exception as e: - logger.error(f"Error in redis_publish action: {e}", exc_info=True) - return {'status': 'error', 'message': str(e)} - - def get_statistics(self) -> Dict[str, Any]: - """Get branch processor statistics.""" - return { - **self.stats, - 'loaded_models': list(self.branch_models.keys()), - 'model_count': len(self.branch_models) - } - - def cleanup(self): - """Cleanup resources.""" - if self.executor: - self.executor.shutdown(wait=False) - - # Clear model cache - self.branch_models.clear() - - logger.info("BranchProcessor cleaned up") \ No newline at end of file diff --git a/core/detection/pipeline.py b/core/detection/pipeline.py deleted file mode 100644 index 26654cc..0000000 --- a/core/detection/pipeline.py +++ /dev/null @@ -1,1310 +0,0 @@ -""" -Detection Pipeline Module. -Main detection pipeline orchestration that coordinates detection flow and execution. -""" -import asyncio -import logging -import time -import uuid -from datetime import datetime -from typing import Dict, List, Optional, Any -from concurrent.futures import ThreadPoolExecutor -import numpy as np - -from ..models.inference import YOLOWrapper -from ..models.pipeline import PipelineParser -from .branches import BranchProcessor -from ..storage.redis import RedisManager -# from ..storage.database import DatabaseManager # Disabled - PostgreSQL moved to microservices -from ..storage.license_plate import LicensePlateManager - -logger = logging.getLogger(__name__) - - -class DetectionPipeline: - """ - Main detection pipeline that orchestrates the complete detection flow. - Handles detection execution, branch coordination, and result aggregation. - """ - - def __init__(self, pipeline_parser: PipelineParser, model_manager: Any, model_id: int, message_sender=None): - """ - Initialize detection pipeline. 
- - Args: - pipeline_parser: Pipeline parser with loaded configuration - model_manager: Model manager for loading models - model_id: The model ID to use for loading models - message_sender: Optional callback function for sending WebSocket messages - """ - self.pipeline_parser = pipeline_parser - self.model_manager = model_manager - self.model_id = model_id - self.message_sender = message_sender - - # Initialize components - self.branch_processor = BranchProcessor(model_manager, model_id) - self.redis_manager = None - # self.db_manager = None # Disabled - PostgreSQL operations moved to microservices - self.license_plate_manager = None - - # Main detection model - self.detection_model: Optional[YOLOWrapper] = None - self.detection_model_id = None - - # Thread pool for parallel processing - self.executor = ThreadPoolExecutor(max_workers=4) - - # Pipeline configuration - self.pipeline_config = pipeline_parser.pipeline_config - - # SessionId to subscriptionIdentifier mapping - self.session_to_subscription = {} - - # SessionId to processing results mapping (for combining with license plate results) - self.session_processing_results = {} - - # Field mappings from parallelActions (e.g., {"car_brand": "{car_brand_cls_v3.brand}"}) - self.field_mappings = {} - self._parse_field_mappings() - - # Statistics - self.stats = { - 'detections_processed': 0, - 'branches_executed': 0, - 'actions_executed': 0, - 'total_processing_time': 0.0 - } - - logger.info("DetectionPipeline initialized") - - def _parse_field_mappings(self): - """ - Parse field mappings from parallelActions.postgresql_update_combined.fields. - Extracts mappings like {"car_brand": "{car_brand_cls_v3.brand}"} for dynamic field resolution. - """ - try: - if not self.pipeline_config or not hasattr(self.pipeline_config, 'parallel_actions'): - return - - for action in self.pipeline_config.parallel_actions: - if action.type.value == 'postgresql_update_combined': - fields = action.params.get('fields', {}) - self.field_mappings = fields - logger.info(f"[FIELD MAPPINGS] Parsed from pipeline config: {self.field_mappings}") - break - - except Exception as e: - logger.error(f"Error parsing field mappings: {e}", exc_info=True) - - async def initialize(self) -> bool: - """ - Initialize all pipeline components including models, Redis, and database. 
- - Returns: - True if successful, False otherwise - """ - try: - # Initialize Redis connection - if self.pipeline_parser.redis_config: - self.redis_manager = RedisManager(self.pipeline_parser.redis_config.__dict__) - if not await self.redis_manager.initialize(): - logger.error("Failed to initialize Redis connection") - return False - logger.info("Redis connection initialized") - - # PostgreSQL database connection DISABLED - operations moved to microservices - # Database operations are now handled by backend services via WebSocket - # if self.pipeline_parser.postgresql_config: - # self.db_manager = DatabaseManager(self.pipeline_parser.postgresql_config.__dict__) - # if not self.db_manager.connect(): - # logger.error("Failed to initialize database connection") - # return False - # # Create required tables - # if not self.db_manager.create_car_frontal_info_table(): - # logger.warning("Failed to create car_frontal_info table") - # logger.info("Database connection initialized") - logger.info("PostgreSQL operations disabled - using WebSocket for data communication") - - # Initialize license plate manager (using same Redis config as main Redis manager) - if self.pipeline_parser.redis_config: - self.license_plate_manager = LicensePlateManager(self.pipeline_parser.redis_config.__dict__) - if not await self.license_plate_manager.initialize(self._on_license_plate_result): - logger.error("Failed to initialize license plate manager") - return False - logger.info("License plate manager initialized") - - - # Initialize main detection model - if not await self._initialize_detection_model(): - logger.error("Failed to initialize detection model") - return False - - # Initialize branch processor (db_manager=None since PostgreSQL is disabled) - if not await self.branch_processor.initialize( - self.pipeline_config, - self.redis_manager, - db_manager=None # PostgreSQL disabled - ): - logger.error("Failed to initialize branch processor") - return False - - logger.info("Detection pipeline initialization completed successfully") - return True - - except Exception as e: - logger.error(f"Error initializing detection pipeline: {e}", exc_info=True) - return False - - async def _initialize_detection_model(self) -> bool: - """ - Load and initialize the main detection model. - - Returns: - True if successful, False otherwise - """ - try: - if not self.pipeline_config: - logger.warning("No pipeline configuration found") - return False - - model_file = getattr(self.pipeline_config, 'model_file', None) - model_id = getattr(self.pipeline_config, 'model_id', None) - - if not model_file: - logger.warning("No detection model file specified") - return False - - # Load detection model - logger.info(f"Loading detection model: {model_id} ({model_file})") - self.detection_model = self.model_manager.get_yolo_model(self.model_id, model_file) - if not self.detection_model: - logger.error(f"Failed to load detection model {model_file} from model {self.model_id}") - return False - - self.detection_model_id = model_id - logger.info(f"Detection model {model_id} loaded successfully") - return True - - except Exception as e: - logger.error(f"Error initializing detection model: {e}", exc_info=True) - return False - - def _extract_fields_from_branches(self, branch_results: Dict[str, Any]) -> Dict[str, Any]: - """ - Extract fields dynamically from branch results using field mappings. 
- - Args: - branch_results: Dictionary of branch execution results - - Returns: - Dictionary with extracted field values (e.g., {"car_brand": "Honda", "body_type": "Sedan"}) - """ - extracted = {} - missing_fields = [] - available_fields = [] - - try: - for db_field_name, template in self.field_mappings.items(): - # Parse template like "{car_brand_cls_v3.brand}" -> branch_id="car_brand_cls_v3", field="brand" - if template.startswith('{') and template.endswith('}'): - var_name = template[1:-1] - if '.' in var_name: - branch_id, field_name = var_name.split('.', 1) - - # Look up value in branch_results - if branch_id in branch_results: - branch_data = branch_results[branch_id] - if isinstance(branch_data, dict) and 'result' in branch_data: - result_data = branch_data['result'] - if isinstance(result_data, dict) and field_name in result_data: - extracted[field_name] = result_data[field_name] - available_fields.append(f"{field_name}={result_data[field_name]}") - logger.debug(f"[DYNAMIC EXTRACT] {field_name}={result_data[field_name]} from branch {branch_id}") - else: - missing_fields.append(f"{field_name} (field not in branch {branch_id})") - logger.debug(f"[DYNAMIC EXTRACT] Field '{field_name}' not found in branch {branch_id}") - else: - missing_fields.append(f"{field_name} (branch {branch_id} missing)") - logger.debug(f"[DYNAMIC EXTRACT] Branch '{branch_id}' not in results") - - # Log summary of extraction - if available_fields: - logger.info(f"[FIELD EXTRACTION] Available fields: {', '.join(available_fields)}") - if missing_fields: - logger.warning(f"[FIELD EXTRACTION] Missing fields (will be null): {', '.join(missing_fields)}") - - except Exception as e: - logger.error(f"Error extracting fields from branches: {e}", exc_info=True) - - return extracted - - async def _on_license_plate_result(self, session_id: str, license_data: Dict[str, Any]): - """ - Callback for handling license plate results from LPR service. 
- - Args: - session_id: Session identifier - license_data: License plate data including text and confidence - """ - try: - license_text = license_data.get('license_plate_text', '') - confidence = license_data.get('confidence', 0.0) - - logger.info(f"[LICENSE PLATE CALLBACK] Session {session_id}: " - f"text='{license_text}', confidence={confidence:.3f}") - - # Find matching subscriptionIdentifier for this sessionId - subscription_id = self.session_to_subscription.get(session_id) - - if not subscription_id: - logger.warning(f"[LICENSE PLATE] No subscription found for sessionId '{session_id}' (type: {type(session_id)}), cannot send imageDetection") - logger.warning(f"[LICENSE PLATE DEBUG] Current session mappings: {dict(self.session_to_subscription)}") - - # Try to find by type conversion in case of type mismatch - # Try as integer if session_id is string - if isinstance(session_id, str) and session_id.isdigit(): - session_id_int = int(session_id) - subscription_id = self.session_to_subscription.get(session_id_int) - if subscription_id: - logger.info(f"[LICENSE PLATE] Found subscription using int conversion: '{session_id}' -> {session_id_int} -> '{subscription_id}'") - else: - logger.error(f"[LICENSE PLATE] Failed to find subscription with int conversion") - return - # Try as string if session_id is integer - elif isinstance(session_id, int): - session_id_str = str(session_id) - subscription_id = self.session_to_subscription.get(session_id_str) - if subscription_id: - logger.info(f"[LICENSE PLATE] Found subscription using string conversion: {session_id} -> '{session_id_str}' -> '{subscription_id}'") - else: - logger.error(f"[LICENSE PLATE] Failed to find subscription with string conversion") - return - else: - logger.error(f"[LICENSE PLATE] Failed to find subscription with any type conversion") - return - - # Send imageDetection message with license plate data combined with processing results - # This is the PRIMARY data flow to backend - WebSocket is critical, keep this! - await self._send_license_plate_message(subscription_id, license_text, confidence, session_id) - - # PostgreSQL database update DISABLED - backend handles data via WebSocket messages - # if self.db_manager and license_text: - # success = self.db_manager.execute_update( - # table='car_frontal_info', - # key_field='session_id', - # key_value=session_id, - # fields={ - # 'license_character': license_text, - # 'license_type': 'LPR_detected' # Mark as detected by LPR service - # } - # ) - # if success: - # logger.info(f"[LICENSE PLATE] Updated database for session {session_id}") - # else: - # logger.warning(f"[LICENSE PLATE] Failed to update database for session {session_id}") - logger.debug(f"[LICENSE PLATE] Data sent via WebSocket for session {session_id}") - - except Exception as e: - logger.error(f"Error in license plate result callback: {e}", exc_info=True) - - - async def _send_license_plate_message(self, subscription_id: str, license_text: str, confidence: float, session_id: str = None): - """ - Send imageDetection message with license plate data plus any available processing results. 
- - Args: - subscription_id: Subscription identifier to send message to - license_text: License plate text - confidence: License plate confidence score - session_id: Session identifier for looking up processing results - """ - try: - if not self.message_sender: - logger.warning("No message sender configured, cannot send imageDetection") - return - - # Import here to avoid circular imports - from ..communication.models import ImageDetectionMessage, DetectionData - - # Get processing results for this session from stored results - car_brand = None - body_type = None - - # Find session_id from session mappings (we need session_id as key) - session_id_for_lookup = None - - # Try direct lookup first (if session_id is already the right type) - if session_id in self.session_processing_results: - session_id_for_lookup = session_id - else: - # Try to find by type conversion - for stored_session_id in self.session_processing_results.keys(): - if str(stored_session_id) == str(session_id): - session_id_for_lookup = stored_session_id - break - - if session_id_for_lookup and session_id_for_lookup in self.session_processing_results: - branch_results = self.session_processing_results[session_id_for_lookup] - logger.info(f"[LICENSE PLATE] Retrieved processing results for session {session_id_for_lookup}") - - # Extract fields dynamically using field mappings from pipeline config - extracted_fields = self._extract_fields_from_branches(branch_results) - car_brand = extracted_fields.get('brand') - body_type = extracted_fields.get('body_type') - - # Log extraction results - fields_status = [] - if car_brand is not None: - fields_status.append(f"brand={car_brand}") - else: - fields_status.append("brand=null") - if body_type is not None: - fields_status.append(f"bodyType={body_type}") - else: - fields_status.append("bodyType=null") - logger.info(f"[LICENSE PLATE] Extracted fields: {', '.join(fields_status)}") - - # Clean up stored results after use - del self.session_processing_results[session_id_for_lookup] - logger.debug(f"[LICENSE PLATE] Cleaned up stored results for session {session_id_for_lookup}") - else: - logger.warning(f"[LICENSE PLATE] No processing results found for session {session_id}") - - # Create detection data with combined information - detection_data_obj = DetectionData( - detection={ - "carBrand": car_brand, - "carModel": None, - "bodyType": body_type, - "licensePlateText": license_text, - "licensePlateConfidence": confidence - }, - modelId=self.model_id, - modelName=self.pipeline_parser.pipeline_config.model_id if self.pipeline_parser.pipeline_config else "detection_model" - ) - - # Create imageDetection message - detection_message = ImageDetectionMessage( - subscriptionIdentifier=subscription_id, - data=detection_data_obj - ) - - # Send message - await self.message_sender(detection_message) - - # Log with indication of partial results - null_fields = [] - if car_brand is None: - null_fields.append('brand') - if body_type is None: - null_fields.append('bodyType') - - if null_fields: - logger.info(f"[COMBINED MESSAGE] Sent imageDetection with PARTIAL results (null: {', '.join(null_fields)}) - brand='{car_brand}', bodyType='{body_type}', license='{license_text}' to '{subscription_id}'") - else: - logger.info(f"[COMBINED MESSAGE] Sent imageDetection with brand='{car_brand}', bodyType='{body_type}', license='{license_text}' to '{subscription_id}'") - - except Exception as e: - logger.error(f"Error sending license plate imageDetection message: {e}", exc_info=True) - - async def 
_send_initial_detection_message(self, subscription_id: str): - """ - Send initial imageDetection message when vehicle is first detected. - - Args: - subscription_id: Subscription identifier to send message to - """ - try: - if not self.message_sender: - logger.warning("No message sender configured, cannot send imageDetection") - return - - # Import here to avoid circular imports - from ..communication.models import ImageDetectionMessage, DetectionData - - # Create detection data with all fields as None (vehicle just detected, no classification yet) - detection_data_obj = DetectionData( - detection={ - "carBrand": None, - "carModel": None, - "bodyType": None, - "licensePlateText": None, - "licensePlateConfidence": None - }, - modelId=self.model_id, - modelName=self.pipeline_parser.pipeline_config.model_id if self.pipeline_parser.pipeline_config else "detection_model" - ) - - # Create imageDetection message - detection_message = ImageDetectionMessage( - subscriptionIdentifier=subscription_id, - data=detection_data_obj - ) - - # Send message - await self.message_sender(detection_message) - logger.info(f"[INITIAL DETECTION] Sent imageDetection for vehicle detection to '{subscription_id}'") - - except Exception as e: - logger.error(f"Error sending initial detection imageDetection message: {e}", exc_info=True) - - async def _send_classification_results(self, subscription_id: str, session_id: str, branch_results: Dict[str, Any]): - """ - Send imageDetection message with classification results (without license plate). - Called after processing phase completes to send partial results immediately. - - Args: - subscription_id: Subscription identifier to send message to - session_id: Session identifier - branch_results: Dictionary of branch execution results - """ - try: - if not self.message_sender: - logger.warning("No message sender configured, cannot send imageDetection") - return - - # Import here to avoid circular imports - from ..communication.models import ImageDetectionMessage, DetectionData - - # Extract classification fields from branch results - extracted_fields = self._extract_fields_from_branches(branch_results) - car_brand = extracted_fields.get('brand') - body_type = extracted_fields.get('body_type') - - # Log what we're sending - fields_status = [] - if car_brand is not None: - fields_status.append(f"brand={car_brand}") - else: - fields_status.append("brand=null") - if body_type is not None: - fields_status.append(f"bodyType={body_type}") - else: - fields_status.append("bodyType=null") - logger.info(f"[CLASSIFICATION] Sending partial results for session {session_id}: {', '.join(fields_status)}") - - # Create detection data with classification results (license plate still pending) - detection_data_obj = DetectionData( - detection={ - "carBrand": car_brand, - "carModel": None, # Not implemented yet - "bodyType": body_type, - "licensePlateText": None, # Will be sent later via license plate callback - "licensePlateConfidence": None - }, - modelId=self.model_id, - modelName=self.pipeline_parser.pipeline_config.model_id if self.pipeline_parser.pipeline_config else "detection_model" - ) - - # Create imageDetection message - detection_message = ImageDetectionMessage( - subscriptionIdentifier=subscription_id, - data=detection_data_obj - ) - - # Send message - await self.message_sender(detection_message) - - # Log with indication of partial results - null_fields = [] - if car_brand is None: - null_fields.append('brand') - if body_type is None: - null_fields.append('bodyType') - - if 
null_fields: - logger.info(f"[PARTIAL RESULTS] Sent imageDetection with PARTIAL results (null: {', '.join(null_fields)}) - brand='{car_brand}', bodyType='{body_type}' to '{subscription_id}'") - else: - logger.info(f"[CLASSIFICATION COMPLETE] Sent imageDetection with brand='{car_brand}', bodyType='{body_type}' to '{subscription_id}'") - - except Exception as e: - logger.error(f"Error sending classification results imageDetection message: {e}", exc_info=True) - - async def execute_detection_phase(self, - frame: np.ndarray, - display_id: str, - subscription_id: str) -> Dict[str, Any]: - """ - Execute only the detection phase - run main detection and send imageDetection message. - This is the first phase that runs when a vehicle is validated. - - Args: - frame: Input frame to process - display_id: Display identifier - subscription_id: Subscription identifier - - Returns: - Dictionary with detection phase results - """ - start_time = time.time() - result = { - 'status': 'success', - 'detections': [], - 'message_sent': False, - 'processing_time': 0.0, - 'timestamp': datetime.now().isoformat() - } - - try: - # Run main detection model - if not self.detection_model: - result['status'] = 'error' - result['message'] = 'Detection model not available' - return result - - # Create detection context - detection_context = { - 'display_id': display_id, - 'subscription_id': subscription_id, - 'timestamp': datetime.now().strftime("%Y-%m-%dT%H-%M-%S"), - 'timestamp_ms': int(time.time() * 1000) - } - - # Run inference on single snapshot using .predict() method - detection_results = self.detection_model.model.predict( - frame, - conf=getattr(self.pipeline_config, 'min_confidence', 0.6), - verbose=False - ) - - # Process detection results using clean logic - valid_detections = [] - detected_regions = {} - - if detection_results and len(detection_results) > 0: - result_obj = detection_results[0] - trigger_classes = getattr(self.pipeline_config, 'trigger_classes', []) - - # Handle .predict() results which have .boxes for detection models - if hasattr(result_obj, 'boxes') and result_obj.boxes is not None: - logger.info(f"[DETECTION PHASE] Found {len(result_obj.boxes)} raw detections from {getattr(self.pipeline_config, 'model_id', 'unknown')}") - - for i, box in enumerate(result_obj.boxes): - class_id = int(box.cls[0]) - confidence = float(box.conf[0]) - bbox = box.xyxy[0].cpu().numpy().tolist() # [x1, y1, x2, y2] - class_name = self.detection_model.model.names[class_id] - - logger.info(f"[DETECTION PHASE {i+1}] {class_name}: bbox={bbox}, conf={confidence:.3f}") - - # Check if detection matches trigger classes - if trigger_classes and class_name not in trigger_classes: - logger.debug(f"[DETECTION PHASE] Filtered '{class_name}' - not in trigger_classes {trigger_classes}") - continue - - logger.info(f"[DETECTION PHASE] Accepted '{class_name}' - matches trigger_classes") - - # Store detection info - detection_info = { - 'class_name': class_name, - 'confidence': confidence, - 'bbox': bbox - } - valid_detections.append(detection_info) - - # Store region for processing phase (support multiple detections per class) - if class_name not in detected_regions: - detected_regions[class_name] = [] - detected_regions[class_name].append({ - 'bbox': bbox, - 'confidence': confidence - }) - else: - logger.warning("[DETECTION PHASE] No boxes found in detection results") - - # Store detected_regions in result for processing phase - result['detected_regions'] = detected_regions - - result['detections'] = valid_detections - - # If we 
have valid detections, create session and send initial imageDetection - if valid_detections: - logger.info(f"Found {len(valid_detections)} valid detections, storing session mapping") - - # Store mapping from display_id to subscriptionIdentifier (for detection phase) - # Note: We'll store session_id mapping later in processing phase - self.session_to_subscription[display_id] = subscription_id - logger.info(f"[SESSION MAPPING] Stored mapping: displayId '{display_id}' -> subscriptionIdentifier '{subscription_id}'") - - # Send initial imageDetection message with empty detection data - await self._send_initial_detection_message(subscription_id) - - logger.info(f"Detection phase completed - {len(valid_detections)} detections found for {display_id}") - result['message_sent'] = True - else: - logger.debug("No valid detections found in detection phase") - - except Exception as e: - logger.error(f"Error in detection phase: {e}", exc_info=True) - result['status'] = 'error' - result['message'] = str(e) - - result['processing_time'] = time.time() - start_time - return result - - async def execute_processing_phase(self, - frame: np.ndarray, - display_id: str, - session_id: str, - subscription_id: str, - detected_regions: Dict[str, Any] = None) -> Dict[str, Any]: - """ - Execute the processing phase - run branches and database operations after receiving sessionId. - This is the second phase that runs after backend sends setSessionId. - - Args: - frame: Input frame to process - display_id: Display identifier - session_id: Session ID from backend - subscription_id: Subscription identifier - detected_regions: Pre-detected regions from detection phase - - Returns: - Dictionary with processing phase results - """ - start_time = time.time() - result = { - 'status': 'success', - 'branch_results': {}, - 'actions_executed': [], - 'session_id': session_id, - 'processing_time': 0.0, - 'timestamp': datetime.now().isoformat() - } - - try: - # Create enhanced detection context with session_id - detection_context = { - 'display_id': display_id, - 'session_id': session_id, - 'subscription_id': subscription_id, - 'timestamp': datetime.now().strftime("%Y-%m-%dT%H-%M-%S"), - 'timestamp_ms': int(time.time() * 1000), - 'uuid': str(uuid.uuid4()), - 'filename': f"{uuid.uuid4()}.jpg" - } - - # If no detected_regions provided, re-run detection to get them - if not detected_regions: - # Use .predict() method for detection - detection_results = self.detection_model.model.predict( - frame, - conf=getattr(self.pipeline_config, 'min_confidence', 0.6), - verbose=False - ) - - detected_regions = {} - if detection_results and len(detection_results) > 0: - result_obj = detection_results[0] - if hasattr(result_obj, 'boxes') and result_obj.boxes is not None: - for box in result_obj.boxes: - class_id = int(box.cls[0]) - confidence = float(box.conf[0]) - bbox = box.xyxy[0].cpu().numpy().tolist() # [x1, y1, x2, y2] - class_name = self.detection_model.model.names[class_id] - - detected_regions[class_name] = { - 'bbox': bbox, - 'confidence': confidence - } - - # Store session mapping for license plate callback - if session_id: - self.session_to_subscription[session_id] = subscription_id - logger.info(f"[SESSION MAPPING] Stored mapping: sessionId '{session_id}' -> subscriptionIdentifier '{subscription_id}'") - - # PostgreSQL database insert DISABLED - backend handles data via WebSocket - # if session_id and self.db_manager: - # success = self.db_manager.insert_initial_detection( - # display_id=display_id, - # 
captured_timestamp=detection_context['timestamp'], - # session_id=session_id - # ) - # if success: - # logger.info(f"Created initial database record with session {session_id}") - # else: - # logger.warning(f"Failed to create initial database record for session {session_id}") - logger.debug(f"Session {session_id} will be communicated via WebSocket") - - # Execute branches in parallel - if hasattr(self.pipeline_config, 'branches') and self.pipeline_config.branches: - branch_results = await self.branch_processor.execute_branches( - frame=frame, - branches=self.pipeline_config.branches, - detected_regions=detected_regions, - detection_context=detection_context - ) - result['branch_results'] = branch_results - logger.info(f"Executed {len(branch_results)} branches for session {session_id}") - - # Execute immediate actions (non-parallel) - immediate_actions = getattr(self.pipeline_config, 'actions', []) - if immediate_actions: - executed_actions = await self._execute_immediate_actions( - actions=immediate_actions, - frame=frame, - detected_regions=detected_regions, - detection_context=detection_context - ) - result['actions_executed'].extend(executed_actions) - - # Execute parallel actions (after all branches complete) - parallel_actions = getattr(self.pipeline_config, 'parallel_actions', []) - if parallel_actions: - # Add branch results to context - enhanced_context = {**detection_context} - if result['branch_results']: - enhanced_context['branch_results'] = result['branch_results'] - - executed_parallel_actions = await self._execute_parallel_actions( - actions=parallel_actions, - frame=frame, - detected_regions=detected_regions, - context=enhanced_context - ) - result['actions_executed'].extend(executed_parallel_actions) - - # Store processing results for later combination with license plate data - if result['branch_results'] and session_id: - self.session_processing_results[session_id] = result['branch_results'] - logger.info(f"[PROCESSING RESULTS] Stored results for session {session_id} for later combination") - - # Send classification results immediately (license plate will come later via callback) - await self._send_classification_results( - subscription_id=subscription_id, - session_id=session_id, - branch_results=result['branch_results'] - ) - - logger.info(f"Processing phase completed for session {session_id}: " - f"{len(result['branch_results'])} branches, {len(result['actions_executed'])} actions") - - except Exception as e: - logger.error(f"Error in processing phase: {e}", exc_info=True) - result['status'] = 'error' - result['message'] = str(e) - - result['processing_time'] = time.time() - start_time - return result - - - async def execute_detection(self, - frame: np.ndarray, - display_id: str, - session_id: Optional[str] = None, - subscription_id: Optional[str] = None) -> Dict[str, Any]: - """ - Execute the main detection pipeline on a frame. 
- - Args: - frame: Input frame to process - display_id: Display identifier - session_id: Optional session ID - subscription_id: Optional subscription identifier - - Returns: - Dictionary with detection results - """ - start_time = time.time() - result = { - 'status': 'success', - 'detections': [], - 'branch_results': {}, - 'actions_executed': [], - 'session_id': session_id, - 'processing_time': 0.0, - 'timestamp': datetime.now().isoformat() - } - - try: - # Update stats - self.stats['detections_processed'] += 1 - - # Run main detection model - if not self.detection_model: - result['status'] = 'error' - result['message'] = 'Detection model not available' - return result - - # Create detection context - detection_context = { - 'display_id': display_id, - 'session_id': session_id, - 'subscription_id': subscription_id, - 'timestamp': datetime.now().strftime("%Y-%m-%dT%H-%M-%S"), - 'timestamp_ms': int(time.time() * 1000), - 'uuid': str(uuid.uuid4()), - 'filename': f"{uuid.uuid4()}.jpg" - } - - - # Run inference on single snapshot using .predict() method - detection_results = self.detection_model.model.predict( - frame, - conf=getattr(self.pipeline_config, 'min_confidence', 0.6), - verbose=False - ) - - # Process detection results - detected_regions = {} - valid_detections = [] - - if detection_results and len(detection_results) > 0: - result_obj = detection_results[0] - trigger_classes = getattr(self.pipeline_config, 'trigger_classes', []) - - # Handle .predict() results which have .boxes for detection models - if hasattr(result_obj, 'boxes') and result_obj.boxes is not None: - logger.info(f"[PIPELINE RAW] Found {len(result_obj.boxes)} raw detections from {getattr(self.pipeline_config, 'model_id', 'unknown')}") - - for i, box in enumerate(result_obj.boxes): - class_id = int(box.cls[0]) - confidence = float(box.conf[0]) - bbox = box.xyxy[0].cpu().numpy().tolist() # [x1, y1, x2, y2] - class_name = self.detection_model.model.names[class_id] - - logger.info(f"[PIPELINE RAW {i+1}] {class_name}: bbox={bbox}, conf={confidence:.3f}") - - # Check if detection matches trigger classes - if trigger_classes and class_name not in trigger_classes: - continue - - # Store detection info - detection_info = { - 'class_name': class_name, - 'confidence': confidence, - 'bbox': bbox - } - valid_detections.append(detection_info) - - # Store region for cropping - detected_regions[class_name] = { - 'bbox': bbox, - 'confidence': confidence - } - logger.info(f"[PIPELINE DETECTION] {class_name}: bbox={bbox}, conf={confidence:.3f}") - - result['detections'] = valid_detections - - # If we have valid detections, proceed with branches and actions - if valid_detections: - logger.info(f"Found {len(valid_detections)} valid detections for pipeline processing") - - # PostgreSQL database insert DISABLED - backend handles data via WebSocket - # if session_id and self.db_manager: - # success = self.db_manager.insert_initial_detection( - # display_id=display_id, - # captured_timestamp=detection_context['timestamp'], - # session_id=session_id - # ) - # if not success: - # logger.warning(f"Failed to create initial database record for session {session_id}") - logger.debug(f"Detection results for session {session_id} will be sent via WebSocket") - - # Execute branches in parallel - if hasattr(self.pipeline_config, 'branches') and self.pipeline_config.branches: - branch_results = await self.branch_processor.execute_branches( - frame=frame, - branches=self.pipeline_config.branches, - detected_regions=detected_regions, - 
detection_context=detection_context - ) - result['branch_results'] = branch_results - self.stats['branches_executed'] += len(branch_results) - - # Execute immediate actions (non-parallel) - immediate_actions = getattr(self.pipeline_config, 'actions', []) - if immediate_actions: - executed_actions = await self._execute_immediate_actions( - actions=immediate_actions, - frame=frame, - detected_regions=detected_regions, - detection_context=detection_context - ) - result['actions_executed'].extend(executed_actions) - - # Execute parallel actions (after all branches complete) - parallel_actions = getattr(self.pipeline_config, 'parallel_actions', []) - if parallel_actions: - # Add branch results to context - enhanced_context = {**detection_context} - if result['branch_results']: - enhanced_context['branch_results'] = result['branch_results'] - - executed_parallel_actions = await self._execute_parallel_actions( - actions=parallel_actions, - frame=frame, - detected_regions=detected_regions, - context=enhanced_context - ) - result['actions_executed'].extend(executed_parallel_actions) - - self.stats['actions_executed'] += len(result['actions_executed']) - else: - logger.debug("No valid detections found for pipeline processing") - - except Exception as e: - logger.error(f"Error in detection pipeline execution: {e}", exc_info=True) - result['status'] = 'error' - result['message'] = str(e) - - # Update timing - processing_time = time.time() - start_time - result['processing_time'] = processing_time - self.stats['total_processing_time'] += processing_time - - return result - - async def _execute_immediate_actions(self, - actions: List[Dict], - frame: np.ndarray, - detected_regions: Dict[str, Any], - detection_context: Dict[str, Any]) -> List[Dict]: - """ - Execute immediate actions (non-parallel). - - Args: - actions: List of action configurations - frame: Input frame - detected_regions: Dictionary of detected regions - detection_context: Detection context data - - Returns: - List of executed action results - """ - executed_actions = [] - - for action in actions: - try: - action_type = action.type.value - logger.debug(f"Executing immediate action: {action_type}") - - if action_type == 'redis_save_image': - result = await self._execute_redis_save_image( - action, frame, detected_regions, detection_context - ) - elif action_type == 'redis_publish': - result = await self._execute_redis_publish( - action, detection_context - ) - else: - logger.warning(f"Unknown immediate action type: {action_type}") - result = {'status': 'error', 'message': f'Unknown action type: {action_type}'} - - executed_actions.append({ - 'action_type': action_type, - 'result': result - }) - - except Exception as e: - logger.error(f"Error executing immediate action {action_type}: {e}", exc_info=True) - executed_actions.append({ - 'action_type': action.type.value, - 'result': {'status': 'error', 'message': str(e)} - }) - - return executed_actions - - async def _execute_parallel_actions(self, - actions: List[Dict], - frame: np.ndarray, - detected_regions: Dict[str, Any], - context: Dict[str, Any]) -> List[Dict]: - """ - Execute parallel actions (after branches complete). 
- - Args: - actions: List of parallel action configurations - frame: Input frame - detected_regions: Dictionary of detected regions - context: Enhanced context with branch results - - Returns: - List of executed action results - """ - executed_actions = [] - - for action in actions: - try: - action_type = action.type.value - logger.debug(f"Executing parallel action: {action_type}") - - if action_type == 'postgresql_update_combined': - # PostgreSQL action SKIPPED - database operations disabled - logger.info(f"Skipping PostgreSQL action '{action_type}' (disabled)") - result = {'status': 'skipped', 'message': 'PostgreSQL operations disabled'} - - # Still update session state for WebSocket messaging - await self._update_session_with_processing_results(context) - - # result = await self._execute_postgresql_update_combined(action, context) - # if result.get('status') == 'success': - # await self._update_session_with_processing_results(context) - else: - logger.warning(f"Unknown parallel action type: {action_type}") - result = {'status': 'error', 'message': f'Unknown action type: {action_type}'} - - executed_actions.append({ - 'action_type': action_type, - 'result': result - }) - - except Exception as e: - logger.error(f"Error executing parallel action {action_type}: {e}", exc_info=True) - executed_actions.append({ - 'action_type': action.type.value, - 'result': {'status': 'error', 'message': str(e)} - }) - - return executed_actions - - async def _execute_redis_save_image(self, - action: Dict, - frame: np.ndarray, - detected_regions: Dict[str, Any], - context: Dict[str, Any]) -> Dict[str, Any]: - """Execute redis_save_image action.""" - if not self.redis_manager: - return {'status': 'error', 'message': 'Redis not available'} - - try: - # Get image to save (cropped or full frame) - image_to_save = frame - region_name = action.params.get('region') - - if region_name and region_name in detected_regions: - # Crop the specified region - # Handle both list (new) and single dict (backward compat) - regions = detected_regions[region_name] - if isinstance(regions, list): - # Multiple detections - select largest bbox - if regions: - best_region = max(regions, key=lambda r: (r['bbox'][2] - r['bbox'][0]) * (r['bbox'][3] - r['bbox'][1])) - bbox = best_region['bbox'] - else: - bbox = None - else: - bbox = regions['bbox'] - - if bbox: - x1, y1, x2, y2 = [int(coord) for coord in bbox] - cropped = frame[y1:y2, x1:x2] - if cropped.size > 0: - image_to_save = cropped - logger.debug(f"Cropped region '{region_name}' for redis_save_image") - else: - logger.warning(f"Empty crop for region '{region_name}', using full frame") - - # Format key with context - key = action.params['key'].format(**context) - - # Save image to Redis - result = await self.redis_manager.save_image( - key=key, - image=image_to_save, - expire_seconds=action.params.get('expire_seconds'), - image_format=action.params.get('format', 'jpeg'), - quality=action.params.get('quality', 90) - ) - - if result: - # Add image_key to context for subsequent actions - context['image_key'] = key - return {'status': 'success', 'key': key} - else: - return {'status': 'error', 'message': 'Failed to save image to Redis'} - - except Exception as e: - logger.error(f"Error in redis_save_image action: {e}", exc_info=True) - return {'status': 'error', 'message': str(e)} - - async def _execute_redis_publish(self, action: Dict, context: Dict[str, Any]) -> Dict[str, Any]: - """Execute redis_publish action.""" - if not self.redis_manager: - return {'status': 'error', 'message': 
'Redis not available'} - - try: - channel = action.params['channel'] - message_template = action.params['message'] - - # Format message with context - message = message_template.format(**context) - - # Publish message - result = await self.redis_manager.publish_message(channel, message) - - if result >= 0: # Redis publish returns number of subscribers - return {'status': 'success', 'subscribers': result, 'channel': channel} - else: - return {'status': 'error', 'message': 'Failed to publish message to Redis'} - - except Exception as e: - logger.error(f"Error in redis_publish action: {e}", exc_info=True) - return {'status': 'error', 'message': str(e)} - - # PostgreSQL update method DISABLED - database operations moved to microservices - # This method is no longer used as data flows via WebSocket messages to backend - # async def _execute_postgresql_update_combined(self, - # action: Dict, - # context: Dict[str, Any]) -> Dict[str, Any]: - # """Execute postgresql_update_combined action.""" - # if not self.db_manager: - # return {'status': 'error', 'message': 'Database not available'} - # - # try: - # # Wait for required branches if specified - # wait_for_branches = action.params.get('waitForBranches', []) - # branch_results = context.get('branch_results', {}) - # - # # Log missing branches but don't block the update (allow partial results) - # missing_branches = [b for b in wait_for_branches if b not in branch_results] - # if missing_branches: - # logger.warning(f"Some branches missing from results (will use null): {missing_branches}") - # available_branches = [b for b in wait_for_branches if b in branch_results] - # if available_branches: - # logger.info(f"Available branches for database update: {available_branches}") - # - # # Prepare fields for database update - # table = action.params.get('table', 'car_frontal_info') - # key_field = action.params.get('key_field', 'session_id') - # key_value = action.params.get('key_value', '{session_id}').format(**context) - # field_mappings = action.params.get('fields', {}) - # - # # Resolve field values using branch results - # resolved_fields = {} - # for field_name, field_template in field_mappings.items(): - # try: - # # Replace template variables with actual values from branch results - # resolved_value = self._resolve_field_template(field_template, branch_results, context) - # resolved_fields[field_name] = resolved_value - # except Exception as e: - # logger.warning(f"Failed to resolve field {field_name}: {e}") - # resolved_fields[field_name] = None - # - # # Execute database update - # success = self.db_manager.execute_update( - # table=table, - # key_field=key_field, - # key_value=key_value, - # fields=resolved_fields - # ) - # - # if success: - # return {'status': 'success', 'table': table, 'key': f'{key_field}={key_value}', 'fields': resolved_fields} - # else: - # return {'status': 'error', 'message': 'Database update failed'} - # - # except Exception as e: - # logger.error(f"Error in postgresql_update_combined action: {e}", exc_info=True) - # return {'status': 'error', 'message': str(e)} - - def _resolve_field_template(self, template: str, branch_results: Dict, context: Dict) -> str: - """ - Resolve field template using branch results and context. 
- - Args: - template: Template string like "{car_brand_cls_v3.brand}" - branch_results: Dictionary of branch execution results - context: Detection context - - Returns: - Resolved field value - """ - try: - # Handle simple context variables first - if template.startswith('{') and template.endswith('}'): - var_name = template[1:-1] - - # Check for branch result reference (e.g., "car_brand_cls_v3.brand") - if '.' in var_name: - branch_id, field_name = var_name.split('.', 1) - if branch_id in branch_results: - branch_data = branch_results[branch_id] - # Look for the field in branch results - if isinstance(branch_data, dict) and 'result' in branch_data: - result_data = branch_data['result'] - if isinstance(result_data, dict) and field_name in result_data: - return str(result_data[field_name]) - logger.warning(f"Field {field_name} not found in branch {branch_id} results") - return None - else: - logger.warning(f"Branch {branch_id} not found in results") - return None - - # Simple context variable - elif var_name in context: - return str(context[var_name]) - - logger.warning(f"Template variable {var_name} not found in context or branch results") - return None - - # Return template as-is if not a template variable - return template - - except Exception as e: - logger.error(f"Error resolving field template {template}: {e}") - return None - - async def _update_session_with_processing_results(self, context: Dict[str, Any]): - """ - Update session state with processing results from branch execution. - - Args: - context: Detection context containing branch results and session info - """ - try: - branch_results = context.get('branch_results', {}) - session_id = context.get('session_id', '') - subscription_id = context.get('subscription_id', '') - - if not session_id: - logger.warning("No session_id in context for processing results") - return - - # Extract fields dynamically using field mappings from pipeline config - extracted_fields = self._extract_fields_from_branches(branch_results) - car_brand = extracted_fields.get('brand') - body_type = extracted_fields.get('body_type') - - logger.info(f"[PROCESSING RESULTS] Completed for session {session_id}: " - f"brand={car_brand}, bodyType={body_type}") - - except Exception as e: - logger.error(f"Error updating session with processing results: {e}", exc_info=True) - - def get_statistics(self) -> Dict[str, Any]: - """Get detection pipeline statistics.""" - branch_stats = self.branch_processor.get_statistics() if self.branch_processor else {} - license_stats = self.license_plate_manager.get_statistics() if self.license_plate_manager else {} - - return { - 'pipeline': self.stats, - 'branches': branch_stats, - 'license_plate': license_stats, - 'redis_available': self.redis_manager is not None, - # 'database_available': self.db_manager is not None, # PostgreSQL disabled - 'detection_model_loaded': self.detection_model is not None - } - - def cleanup(self): - """Cleanup resources.""" - if self.executor: - self.executor.shutdown(wait=False) - - if self.redis_manager: - self.redis_manager.cleanup() - - # PostgreSQL disconnect DISABLED - database operations moved to microservices - # if self.db_manager: - # self.db_manager.disconnect() - - if self.branch_processor: - self.branch_processor.cleanup() - - if self.license_plate_manager: - # Schedule cleanup task and track it to prevent warnings - cleanup_task = asyncio.create_task(self.license_plate_manager.close()) - cleanup_task.add_done_callback(lambda _: None) # Suppress "Task exception was never retrieved" - - 
logger.info("Detection pipeline cleaned up") \ No newline at end of file diff --git a/core/models/__init__.py b/core/models/__init__.py deleted file mode 100644 index fa2c71a..0000000 --- a/core/models/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -""" -Models Module - MPTA management, pipeline configuration, and YOLO inference -""" - -from .manager import ModelManager -from .pipeline import ( - PipelineParser, - PipelineConfig, - TrackingConfig, - ModelBranch, - Action, - ActionType, - RedisConfig, - # PostgreSQLConfig # Disabled - moved to microservices -) -from .inference import ( - YOLOWrapper, - ModelInferenceManager, - Detection, - InferenceResult -) - -__all__ = [ - # Manager - 'ModelManager', - - # Pipeline - 'PipelineParser', - 'PipelineConfig', - 'TrackingConfig', - 'ModelBranch', - 'Action', - 'ActionType', - 'RedisConfig', - # 'PostgreSQLConfig', # Disabled - moved to microservices - - # Inference - 'YOLOWrapper', - 'ModelInferenceManager', - 'Detection', - 'InferenceResult', -] \ No newline at end of file diff --git a/core/models/inference.py b/core/models/inference.py deleted file mode 100644 index f96c0e8..0000000 --- a/core/models/inference.py +++ /dev/null @@ -1,447 +0,0 @@ -""" -YOLO Model Inference Wrapper - Handles model loading and inference optimization -""" - -import logging -import torch -import numpy as np -from pathlib import Path -from typing import Dict, List, Optional, Any, Tuple, Union -from threading import Lock -from dataclasses import dataclass -import cv2 - -logger = logging.getLogger(__name__) - - -@dataclass -class Detection: - """Represents a single detection result""" - bbox: List[float] # [x1, y1, x2, y2] - confidence: float - class_id: int - class_name: str - track_id: Optional[int] = None - - -@dataclass -class InferenceResult: - """Result from model inference""" - detections: List[Detection] - image_shape: Tuple[int, int] # (height, width) - inference_time: float - model_id: str - - -class YOLOWrapper: - """Wrapper for YOLO models with caching and optimization""" - - # Class-level model cache shared across all instances - _model_cache: Dict[str, Any] = {} - _cache_lock = Lock() - - def __init__(self, model_path: Path, model_id: str, device: Optional[str] = None): - """ - Initialize YOLO wrapper - - Args: - model_path: Path to the .pt model file - model_id: Unique identifier for the model - device: Device to run inference on ('cuda', 'cpu', or None for auto) - """ - self.model_path = model_path - self.model_id = model_id - - # Auto-detect device if not specified - if device is None: - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - else: - self.device = device - - self.model = None - self._class_names = [] - - - self._load_model() - - logger.info(f"Initialized YOLO wrapper for {model_id} on {self.device}") - - def _load_model(self) -> None: - """Load the YOLO model with caching""" - cache_key = str(self.model_path) - - with self._cache_lock: - # Check if model is already cached - if cache_key in self._model_cache: - logger.info(f"Loading model {self.model_id} from cache") - self.model = self._model_cache[cache_key] - self._extract_class_names() - return - - # Load model - try: - from ultralytics import YOLO - - logger.info(f"Loading YOLO model from {self.model_path}") - self.model = YOLO(str(self.model_path)) - - # Move model to device - if self.device == 'cuda' and torch.cuda.is_available(): - self.model.to('cuda') - logger.info(f"Model {self.model_id} moved to GPU") - - # Cache the model - self._model_cache[cache_key] = self.model - 
self._extract_class_names() - - logger.info(f"Successfully loaded model {self.model_id}") - - except ImportError: - logger.error("Ultralytics YOLO not installed. Install with: pip install ultralytics") - raise - except Exception as e: - logger.error(f"Failed to load YOLO model {self.model_id}: {str(e)}", exc_info=True) - raise - - def _extract_class_names(self) -> None: - """Extract class names from the model""" - try: - if hasattr(self.model, 'names'): - self._class_names = self.model.names - elif hasattr(self.model, 'model') and hasattr(self.model.model, 'names'): - self._class_names = self.model.model.names - else: - logger.warning(f"Could not extract class names from model {self.model_id}") - self._class_names = {} - except Exception as e: - logger.error(f"Failed to extract class names: {str(e)}") - self._class_names = {} - - - def infer( - self, - image: np.ndarray, - confidence_threshold: float = 0.5, - trigger_classes: Optional[List[str]] = None, - iou_threshold: float = 0.45 - ) -> InferenceResult: - """ - Run inference on an image - - Args: - image: Input image as numpy array (BGR format) - confidence_threshold: Minimum confidence for detections - trigger_classes: List of class names to filter (None = all classes) - iou_threshold: IoU threshold for NMS - - Returns: - InferenceResult containing detections - """ - if self.model is None: - raise RuntimeError(f"Model {self.model_id} not loaded") - - try: - import time - start_time = time.time() - - # Run inference - results = self.model( - image, - conf=confidence_threshold, - iou=iou_threshold, - verbose=False - ) - - inference_time = time.time() - start_time - - # Parse results - detections = self._parse_results(results[0], trigger_classes) - - return InferenceResult( - detections=detections, - image_shape=(image.shape[0], image.shape[1]), - inference_time=inference_time, - model_id=self.model_id - ) - - except Exception as e: - logger.error(f"Inference failed for model {self.model_id}: {str(e)}", exc_info=True) - raise - - def _parse_results( - self, - result: Any, - trigger_classes: Optional[List[str]] = None - ) -> List[Detection]: - """ - Parse YOLO results into Detection objects - - Args: - result: YOLO result object - trigger_classes: Optional list of class names to filter - - Returns: - List of Detection objects - """ - detections = [] - - try: - if result.boxes is None: - return detections - - boxes = result.boxes - for i in range(len(boxes)): - # Get box coordinates - box = boxes.xyxy[i].cpu().numpy() - x1, y1, x2, y2 = box - - # Get confidence and class - conf = float(boxes.conf[i]) - cls_id = int(boxes.cls[i]) - - # Get class name - class_name = self._class_names.get(cls_id, f"class_{cls_id}") - - # Filter by trigger classes if specified - if trigger_classes and class_name not in trigger_classes: - continue - - # Get track ID if available - track_id = None - if hasattr(boxes, 'id') and boxes.id is not None: - track_id = int(boxes.id[i]) - - detection = Detection( - bbox=[float(x1), float(y1), float(x2), float(y2)], - confidence=conf, - class_id=cls_id, - class_name=class_name, - track_id=track_id - ) - detections.append(detection) - - except Exception as e: - logger.error(f"Failed to parse results: {str(e)}", exc_info=True) - - return detections - - - def track( - self, - image: np.ndarray, - confidence_threshold: float = 0.5, - trigger_classes: Optional[List[str]] = None, - persist: bool = True, - camera_id: Optional[str] = None - ) -> InferenceResult: - """ - Run detection (tracking will be handled by external tracker) 
- - Args: - image: Input image as numpy array (BGR format) - confidence_threshold: Minimum confidence for detections - trigger_classes: List of class names to filter - persist: Ignored - tracking handled externally - camera_id: Ignored - tracking handled externally - - Returns: - InferenceResult containing detections (no track IDs from YOLO) - """ - # Just do detection - no YOLO tracking - return self.infer(image, confidence_threshold, trigger_classes) - - def predict_classification( - self, - image: np.ndarray, - top_k: int = 1 - ) -> Dict[str, float]: - """ - Run classification on an image - - Args: - image: Input image as numpy array (BGR format) - top_k: Number of top predictions to return - - Returns: - Dictionary of class_name -> confidence scores - """ - if self.model is None: - raise RuntimeError(f"Model {self.model_id} not loaded") - - try: - # Run inference - results = self.model(image, verbose=False) - - # For classification models, extract probabilities - if hasattr(results[0], 'probs'): - probs = results[0].probs - top_indices = probs.top5[:top_k] - top_conf = probs.top5conf[:top_k].cpu().numpy() - - predictions = {} - for idx, conf in zip(top_indices, top_conf): - class_name = self._class_names.get(int(idx), f"class_{idx}") - predictions[class_name] = float(conf) - - return predictions - else: - logger.warning(f"Model {self.model_id} does not support classification") - return {} - - except Exception as e: - logger.error(f"Classification failed for model {self.model_id}: {str(e)}", exc_info=True) - raise - - def crop_detection( - self, - image: np.ndarray, - detection: Detection, - padding: int = 0 - ) -> np.ndarray: - """ - Crop image to detection bounding box - - Args: - image: Original image - detection: Detection to crop - padding: Additional padding around the box - - Returns: - Cropped image region - """ - h, w = image.shape[:2] - x1, y1, x2, y2 = detection.bbox - - # Add padding and clip to image boundaries - x1 = max(0, int(x1) - padding) - y1 = max(0, int(y1) - padding) - x2 = min(w, int(x2) + padding) - y2 = min(h, int(y2) + padding) - - return image[y1:y2, x1:x2] - - def get_class_names(self) -> Dict[int, str]: - """Get the class names dictionary""" - return self._class_names.copy() - - def get_num_classes(self) -> int: - """Get the number of classes the model can detect""" - return len(self._class_names) - - - def clear_cache(self) -> None: - """Clear the model cache""" - with self._cache_lock: - cache_key = str(self.model_path) - if cache_key in self._model_cache: - del self._model_cache[cache_key] - logger.info(f"Cleared cache for model {self.model_id}") - - @classmethod - def clear_all_cache(cls) -> None: - """Clear all cached models""" - with cls._cache_lock: - cls._model_cache.clear() - logger.info("Cleared all model cache") - - def warmup(self, image_size: Tuple[int, int] = (640, 640)) -> None: - """ - Warmup the model with a dummy inference - - Args: - image_size: Size of dummy image (height, width) - """ - try: - dummy_image = np.zeros((image_size[0], image_size[1], 3), dtype=np.uint8) - self.infer(dummy_image, confidence_threshold=0.5) - logger.info(f"Model {self.model_id} warmed up") - except Exception as e: - logger.warning(f"Failed to warmup model {self.model_id}: {str(e)}") - - -class ModelInferenceManager: - """Manages multiple YOLO models for a pipeline""" - - def __init__(self, model_dir: Path): - """ - Initialize the inference manager - - Args: - model_dir: Directory containing model files - """ - self.model_dir = model_dir - self.models: Dict[str, 
YOLOWrapper] = {} - self._lock = Lock() - - logger.info(f"Initialized ModelInferenceManager with model directory: {model_dir}") - - def load_model( - self, - model_id: str, - model_file: str, - device: Optional[str] = None - ) -> YOLOWrapper: - """ - Load a model for inference - - Args: - model_id: Unique identifier for the model - model_file: Filename of the model - device: Device to run on - - Returns: - YOLOWrapper instance - """ - with self._lock: - # Check if already loaded - if model_id in self.models: - logger.debug(f"Model {model_id} already loaded") - return self.models[model_id] - - # Load the model - model_path = self.model_dir / model_file - if not model_path.exists(): - raise FileNotFoundError(f"Model file not found: {model_path}") - - wrapper = YOLOWrapper(model_path, model_id, device) - self.models[model_id] = wrapper - - return wrapper - - def get_model(self, model_id: str) -> Optional[YOLOWrapper]: - """ - Get a loaded model - - Args: - model_id: Model identifier - - Returns: - YOLOWrapper instance or None if not loaded - """ - return self.models.get(model_id) - - def unload_model(self, model_id: str) -> bool: - """ - Unload a model to free memory - - Args: - model_id: Model identifier - - Returns: - True if unloaded, False if not found - """ - with self._lock: - if model_id in self.models: - self.models[model_id].clear_cache() - del self.models[model_id] - logger.info(f"Unloaded model {model_id}") - return True - return False - - def unload_all(self) -> None: - """Unload all models""" - with self._lock: - for model_id in list(self.models.keys()): - self.models[model_id].clear_cache() - self.models.clear() - logger.info("Unloaded all models") \ No newline at end of file diff --git a/core/models/manager.py b/core/models/manager.py deleted file mode 100644 index d40c48f..0000000 --- a/core/models/manager.py +++ /dev/null @@ -1,439 +0,0 @@ -""" -Model Manager Module - Handles MPTA download, extraction, and model loading -""" - -import os -import logging -import zipfile -import json -import hashlib -import requests -from pathlib import Path -from typing import Dict, Optional, Any, Set -from threading import Lock -from urllib.parse import urlparse, parse_qs - -logger = logging.getLogger(__name__) - - -class ModelManager: - """Manages MPTA model downloads, extraction, and caching""" - - def __init__(self, models_dir: str = "models"): - """ - Initialize the Model Manager - - Args: - models_dir: Base directory for storing models - """ - self.models_dir = Path(models_dir) - self.models_dir.mkdir(parents=True, exist_ok=True) - - # Track downloaded models to avoid duplicates - self._downloaded_models: Set[int] = set() - self._model_paths: Dict[int, Path] = {} - self._download_lock = Lock() - - # Scan existing models - self._scan_existing_models() - - logger.info(f"ModelManager initialized with models directory: {self.models_dir}") - logger.info(f"Found existing models: {list(self._downloaded_models)}") - - def _scan_existing_models(self) -> None: - """Scan the models directory for existing downloaded models""" - if not self.models_dir.exists(): - return - - for model_dir in self.models_dir.iterdir(): - if model_dir.is_dir() and model_dir.name.isdigit(): - model_id = int(model_dir.name) - # Check if extraction was successful by looking for pipeline.json - extracted_dirs = list(model_dir.glob("*/pipeline.json")) - if extracted_dirs: - self._downloaded_models.add(model_id) - # Store path to the extracted model directory - self._model_paths[model_id] = extracted_dirs[0].parent - 
logger.debug(f"Found existing model {model_id} at {extracted_dirs[0].parent}") - - def get_model_path(self, model_id: int) -> Optional[Path]: - """ - Get the path to an extracted model directory - - Args: - model_id: The model ID - - Returns: - Path to the extracted model directory or None if not found - """ - return self._model_paths.get(model_id) - - def is_model_downloaded(self, model_id: int) -> bool: - """ - Check if a model has already been downloaded and extracted - - Args: - model_id: The model ID to check - - Returns: - True if the model is already available - """ - return model_id in self._downloaded_models - - def ensure_model(self, model_id: int, model_url: str, model_name: str = None) -> Optional[Path]: - """ - Ensure a model is downloaded and extracted, downloading if necessary - - Args: - model_id: The model ID - model_url: URL to download the MPTA file from - model_name: Optional model name for logging - - Returns: - Path to the extracted model directory or None if failed - """ - # Check if already downloaded - if self.is_model_downloaded(model_id): - logger.info(f"Model {model_id} already available at {self._model_paths[model_id]}") - return self._model_paths[model_id] - - # Download and extract with lock to prevent concurrent downloads of same model - with self._download_lock: - # Double-check after acquiring lock - if self.is_model_downloaded(model_id): - return self._model_paths[model_id] - - logger.info(f"Model {model_id} not found locally, downloading from {model_url}") - - # Create model directory - model_dir = self.models_dir / str(model_id) - model_dir.mkdir(parents=True, exist_ok=True) - - # Extract filename from URL - mpta_filename = self._extract_filename_from_url(model_url, model_name, model_id) - mpta_path = model_dir / mpta_filename - - # Download MPTA file - if not self._download_mpta(model_url, mpta_path): - logger.error(f"Failed to download model {model_id}") - return None - - # Extract MPTA file - extracted_path = self._extract_mpta(mpta_path, model_dir) - if not extracted_path: - logger.error(f"Failed to extract model {model_id}") - return None - - # Mark as downloaded and store path - self._downloaded_models.add(model_id) - self._model_paths[model_id] = extracted_path - - logger.info(f"Successfully prepared model {model_id} at {extracted_path}") - return extracted_path - - def _extract_filename_from_url(self, url: str, model_name: str = None, model_id: int = None) -> str: - """ - Extract a suitable filename from the URL - - Args: - url: The URL to extract filename from - model_name: Optional model name - model_id: Optional model ID - - Returns: - A suitable filename for the MPTA file - """ - parsed = urlparse(url) - path = parsed.path - - # Try to get filename from path - if path: - filename = os.path.basename(path) - if filename and filename.endswith('.mpta'): - return filename - - # Fallback to constructed name - if model_name: - return f"{model_name}-{model_id}.mpta" - else: - return f"model-{model_id}.mpta" - - def _download_mpta(self, url: str, dest_path: Path) -> bool: - """ - Download an MPTA file from a URL - - Args: - url: URL to download from - dest_path: Destination path for the file - - Returns: - True if successful, False otherwise - """ - try: - logger.info(f"Starting download of model from {url}") - logger.debug(f"Download destination: {dest_path}") - - response = requests.get(url, stream=True, timeout=300) - if response.status_code != 200: - logger.error(f"Failed to download MPTA file (status {response.status_code})") - return False - - 
file_size = int(response.headers.get('content-length', 0)) - logger.info(f"Model file size: {file_size/1024/1024:.2f} MB") - - downloaded = 0 - last_log_percent = 0 - - with open(dest_path, 'wb') as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - downloaded += len(chunk) - - # Log progress every 10% - if file_size > 0: - percent = int(downloaded * 100 / file_size) - if percent >= last_log_percent + 10: - logger.debug(f"Download progress: {percent}%") - last_log_percent = percent - - logger.info(f"Successfully downloaded MPTA file to {dest_path}") - return True - - except requests.RequestException as e: - logger.error(f"Network error downloading MPTA: {str(e)}", exc_info=True) - # Clean up partial download - if dest_path.exists(): - dest_path.unlink() - return False - except Exception as e: - logger.error(f"Unexpected error downloading MPTA: {str(e)}", exc_info=True) - # Clean up partial download - if dest_path.exists(): - dest_path.unlink() - return False - - def _extract_mpta(self, mpta_path: Path, target_dir: Path) -> Optional[Path]: - """ - Extract an MPTA (ZIP) file to the target directory - - Args: - mpta_path: Path to the MPTA file - target_dir: Directory to extract to - - Returns: - Path to the extracted model directory containing pipeline.json, or None if failed - """ - try: - if not mpta_path.exists(): - logger.error(f"MPTA file not found: {mpta_path}") - return None - - logger.info(f"Extracting MPTA file from {mpta_path} to {target_dir}") - - with zipfile.ZipFile(mpta_path, 'r') as zip_ref: - # Get list of files - file_list = zip_ref.namelist() - logger.debug(f"Files in MPTA archive: {len(file_list)} files") - - # Extract all files - zip_ref.extractall(target_dir) - - logger.info(f"Successfully extracted MPTA file to {target_dir}") - - # Find the directory containing pipeline.json - pipeline_files = list(target_dir.glob("*/pipeline.json")) - if not pipeline_files: - # Check if pipeline.json is in root - if (target_dir / "pipeline.json").exists(): - logger.info(f"Found pipeline.json in root of {target_dir}") - return target_dir - logger.error(f"No pipeline.json found after extraction in {target_dir}") - return None - - # Return the directory containing pipeline.json - extracted_dir = pipeline_files[0].parent - logger.info(f"Extracted model to {extracted_dir}") - - # Keep the MPTA file for reference but could delete if space is a concern - # mpta_path.unlink() - # logger.debug(f"Removed MPTA file after extraction: {mpta_path}") - - return extracted_dir - - except zipfile.BadZipFile as e: - logger.error(f"Invalid ZIP/MPTA file {mpta_path}: {str(e)}", exc_info=True) - return None - except Exception as e: - logger.error(f"Failed to extract MPTA file {mpta_path}: {str(e)}", exc_info=True) - return None - - def load_pipeline_config(self, model_id: int) -> Optional[Dict[str, Any]]: - """ - Load the pipeline.json configuration for a model - - Args: - model_id: The model ID - - Returns: - The pipeline configuration dictionary or None if not found - """ - model_path = self.get_model_path(model_id) - if not model_path: - logger.error(f"Model {model_id} not found") - return None - - pipeline_path = model_path / "pipeline.json" - if not pipeline_path.exists(): - logger.error(f"pipeline.json not found for model {model_id}") - return None - - try: - with open(pipeline_path, 'r') as f: - config = json.load(f) - logger.debug(f"Loaded pipeline config for model {model_id}") - return config - except json.JSONDecodeError as e: - logger.error(f"Invalid JSON in 
pipeline.json for model {model_id}: {str(e)}") - return None - except Exception as e: - logger.error(f"Failed to load pipeline.json for model {model_id}: {str(e)}") - return None - - def get_model_file_path(self, model_id: int, filename: str) -> Optional[Path]: - """ - Get the full path to a model file (e.g., .pt file) - - Args: - model_id: The model ID - filename: The filename within the model directory - - Returns: - Full path to the model file or None if not found - """ - model_path = self.get_model_path(model_id) - if not model_path: - return None - - file_path = model_path / filename - if not file_path.exists(): - logger.error(f"Model file {filename} not found in model {model_id}") - return None - - return file_path - - def cleanup_model(self, model_id: int) -> bool: - """ - Remove a downloaded model to free up space - - Args: - model_id: The model ID to remove - - Returns: - True if successful, False otherwise - """ - if model_id not in self._downloaded_models: - logger.warning(f"Model {model_id} not in downloaded models") - return False - - try: - model_dir = self.models_dir / str(model_id) - if model_dir.exists(): - import shutil - shutil.rmtree(model_dir) - logger.info(f"Removed model directory: {model_dir}") - - self._downloaded_models.discard(model_id) - self._model_paths.pop(model_id, None) - return True - - except Exception as e: - logger.error(f"Failed to cleanup model {model_id}: {str(e)}") - return False - - def get_all_downloaded_models(self) -> Set[int]: - """ - Get a set of all downloaded model IDs - - Returns: - Set of model IDs that are currently downloaded - """ - return self._downloaded_models.copy() - - def get_pipeline_config(self, model_id: int) -> Optional[Any]: - """ - Get the pipeline configuration for a model. - - Args: - model_id: The model ID - - Returns: - PipelineConfig object if found, None otherwise - """ - try: - if model_id not in self._downloaded_models: - logger.warning(f"Model {model_id} not downloaded") - return None - - model_path = self._model_paths.get(model_id) - if not model_path: - logger.warning(f"Model path not found for model {model_id}") - return None - - # Import here to avoid circular imports - from .pipeline import PipelineParser - - # Load pipeline.json - pipeline_file = model_path / "pipeline.json" - if not pipeline_file.exists(): - logger.warning(f"No pipeline.json found for model {model_id}") - return None - - # Create PipelineParser object and parse the configuration - pipeline_parser = PipelineParser() - success = pipeline_parser.parse(pipeline_file) - - if success: - return pipeline_parser - else: - logger.error(f"Failed to parse pipeline.json for model {model_id}") - return None - - except Exception as e: - logger.error(f"Error getting pipeline config for model {model_id}: {e}", exc_info=True) - return None - - def get_yolo_model(self, model_id: int, model_filename: str) -> Optional[Any]: - """ - Create a YOLOWrapper instance for a specific model file. 
- - Args: - model_id: The model ID - model_filename: The .pt model filename - - Returns: - YOLOWrapper instance if successful, None otherwise - """ - try: - # Get the model file path - model_file_path = self.get_model_file_path(model_id, model_filename) - if not model_file_path or not model_file_path.exists(): - logger.error(f"Model file {model_filename} not found for model {model_id}") - return None - - # Import here to avoid circular imports - from .inference import YOLOWrapper - - # Create YOLOWrapper instance - yolo_model = YOLOWrapper( - model_path=model_file_path, - model_id=f"{model_id}_{model_filename}", - device=None # Auto-detect device - ) - - logger.info(f"Created YOLOWrapper for model {model_id}: {model_filename}") - return yolo_model - - except Exception as e: - logger.error(f"Error creating YOLO model for {model_id}:{model_filename}: {e}", exc_info=True) - return None \ No newline at end of file diff --git a/core/models/pipeline.py b/core/models/pipeline.py deleted file mode 100644 index 3ae7463..0000000 --- a/core/models/pipeline.py +++ /dev/null @@ -1,373 +0,0 @@ -""" -Pipeline Configuration Parser - Handles pipeline.json parsing and validation -""" - -import json -import logging -from pathlib import Path -from typing import Dict, List, Any, Optional, Set -from dataclasses import dataclass, field -from enum import Enum - -logger = logging.getLogger(__name__) - - -class ActionType(Enum): - """Supported action types in pipeline""" - REDIS_SAVE_IMAGE = "redis_save_image" - REDIS_PUBLISH = "redis_publish" - # PostgreSQL actions below are DEPRECATED - kept for backward compatibility only - # These actions will be silently skipped during pipeline execution - POSTGRESQL_UPDATE = "postgresql_update" - POSTGRESQL_UPDATE_COMBINED = "postgresql_update_combined" - POSTGRESQL_INSERT = "postgresql_insert" - - -@dataclass -class RedisConfig: - """Redis connection configuration""" - host: str - port: int = 6379 - password: Optional[str] = None - db: int = 0 - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'RedisConfig': - return cls( - host=data['host'], - port=data.get('port', 6379), - password=data.get('password'), - db=data.get('db', 0) - ) - - -@dataclass -class PostgreSQLConfig: - """ - PostgreSQL connection configuration - DISABLED - - NOTE: This configuration is kept for backward compatibility with existing - pipeline.json files, but PostgreSQL operations are disabled. All database - operations have been moved to microservices architecture. - - This config will be parsed but not used for any database connections. 
- """ - host: str - port: int - database: str - username: str - password: str - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'PostgreSQLConfig': - """Parse PostgreSQL config from dict (kept for backward compatibility)""" - return cls( - host=data['host'], - port=data.get('port', 5432), - database=data['database'], - username=data['username'], - password=data['password'] - ) - - -@dataclass -class Action: - """Represents an action in the pipeline""" - type: ActionType - params: Dict[str, Any] = field(default_factory=dict) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'Action': - action_type = ActionType(data['type']) - params = {k: v for k, v in data.items() if k != 'type'} - return cls(type=action_type, params=params) - - -@dataclass -class ModelBranch: - """Represents a branch in the pipeline with its own model""" - model_id: str - model_file: str - trigger_classes: List[str] - min_confidence: float = 0.5 - crop: bool = False - crop_class: Optional[Any] = None # Can be string or list - parallel: bool = False - actions: List[Action] = field(default_factory=list) - branches: List['ModelBranch'] = field(default_factory=list) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'ModelBranch': - actions = [Action.from_dict(a) for a in data.get('actions', [])] - branches = [cls.from_dict(b) for b in data.get('branches', [])] - - return cls( - model_id=data['modelId'], - model_file=data['modelFile'], - trigger_classes=data.get('triggerClasses', []), - min_confidence=data.get('minConfidence', 0.5), - crop=data.get('crop', False), - crop_class=data.get('cropClass'), - parallel=data.get('parallel', False), - actions=actions, - branches=branches - ) - - -@dataclass -class TrackingConfig: - """Configuration for the tracking phase""" - model_id: str - model_file: str - trigger_classes: List[str] - min_confidence: float = 0.6 - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'TrackingConfig': - return cls( - model_id=data['modelId'], - model_file=data['modelFile'], - trigger_classes=data.get('triggerClasses', []), - min_confidence=data.get('minConfidence', 0.6) - ) - - -@dataclass -class PipelineConfig: - """Main pipeline configuration""" - model_id: str - model_file: str - trigger_classes: List[str] - min_confidence: float = 0.5 - crop: bool = False - branches: List[ModelBranch] = field(default_factory=list) - parallel_actions: List[Action] = field(default_factory=list) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'PipelineConfig': - branches = [ModelBranch.from_dict(b) for b in data.get('branches', [])] - parallel_actions = [Action.from_dict(a) for a in data.get('parallelActions', [])] - - return cls( - model_id=data['modelId'], - model_file=data['modelFile'], - trigger_classes=data.get('triggerClasses', []), - min_confidence=data.get('minConfidence', 0.5), - crop=data.get('crop', False), - branches=branches, - parallel_actions=parallel_actions - ) - - -class PipelineParser: - """Parser for pipeline.json configuration files""" - - def __init__(self): - self.redis_config: Optional[RedisConfig] = None - self.postgresql_config: Optional[PostgreSQLConfig] = None - self.tracking_config: Optional[TrackingConfig] = None - self.pipeline_config: Optional[PipelineConfig] = None - self._model_dependencies: Set[str] = set() - - def parse(self, config_path: Path) -> bool: - """ - Parse a pipeline.json configuration file - - Args: - config_path: Path to the pipeline.json file - - Returns: - True if parsing was successful, False otherwise - 
""" - try: - if not config_path.exists(): - logger.error(f"Pipeline config not found: {config_path}") - return False - - with open(config_path, 'r') as f: - data = json.load(f) - - return self.parse_dict(data) - - except json.JSONDecodeError as e: - logger.error(f"Invalid JSON in pipeline config: {str(e)}") - return False - except Exception as e: - logger.error(f"Failed to parse pipeline config: {str(e)}", exc_info=True) - return False - - def parse_dict(self, data: Dict[str, Any]) -> bool: - """ - Parse a pipeline configuration from a dictionary - - Args: - data: The configuration dictionary - - Returns: - True if parsing was successful, False otherwise - """ - try: - # Parse Redis configuration - if 'redis' in data: - self.redis_config = RedisConfig.from_dict(data['redis']) - logger.debug(f"Parsed Redis config: {self.redis_config.host}:{self.redis_config.port}") - - # Parse PostgreSQL configuration - if 'postgresql' in data: - self.postgresql_config = PostgreSQLConfig.from_dict(data['postgresql']) - logger.debug(f"Parsed PostgreSQL config: {self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}") - - # Parse tracking configuration - if 'tracking' in data: - self.tracking_config = TrackingConfig.from_dict(data['tracking']) - self._model_dependencies.add(self.tracking_config.model_file) - logger.debug(f"Parsed tracking config: {self.tracking_config.model_id}") - - # Parse main pipeline configuration - if 'pipeline' in data: - self.pipeline_config = PipelineConfig.from_dict(data['pipeline']) - self._collect_model_dependencies(self.pipeline_config) - logger.debug(f"Parsed pipeline config: {self.pipeline_config.model_id}") - - logger.info(f"Successfully parsed pipeline configuration") - logger.debug(f"Model dependencies: {self._model_dependencies}") - return True - - except KeyError as e: - logger.error(f"Missing required field in pipeline config: {str(e)}") - return False - except Exception as e: - logger.error(f"Failed to parse pipeline config: {str(e)}", exc_info=True) - return False - - def _collect_model_dependencies(self, config: Any) -> None: - """ - Recursively collect all model file dependencies - - Args: - config: Pipeline or branch configuration - """ - if hasattr(config, 'model_file'): - self._model_dependencies.add(config.model_file) - - if hasattr(config, 'branches'): - for branch in config.branches: - self._collect_model_dependencies(branch) - - def get_model_dependencies(self) -> Set[str]: - """ - Get all model file dependencies from the pipeline - - Returns: - Set of model filenames required by the pipeline - """ - return self._model_dependencies.copy() - - def validate(self) -> bool: - """ - Validate the parsed configuration - - Returns: - True if configuration is valid, False otherwise - """ - if not self.pipeline_config: - logger.error("No pipeline configuration found") - return False - - # Check that all required model files are specified - if not self.pipeline_config.model_file: - logger.error("Main pipeline model file not specified") - return False - - # Validate action configurations - if not self._validate_actions(self.pipeline_config): - return False - - # Validate parallel actions (PostgreSQL actions are skipped) - for action in self.pipeline_config.parallel_actions: - if action.type == ActionType.POSTGRESQL_UPDATE_COMBINED: - logger.warning(f"PostgreSQL parallel action {action.type.value} found but will be SKIPPED (PostgreSQL disabled)") - # Skip validation for PostgreSQL actions since they won't be executed - # wait_for = 
action.params.get('waitForBranches', []) - # if wait_for: - # # Check that referenced branches exist - # branch_ids = self._get_all_branch_ids(self.pipeline_config) - # for branch_id in wait_for: - # if branch_id not in branch_ids: - # logger.error(f"Referenced branch '{branch_id}' in waitForBranches not found") - # return False - - logger.info("Pipeline configuration validated successfully") - return True - - def _validate_actions(self, config: Any) -> bool: - """ - Validate actions in a pipeline or branch configuration - - Args: - config: Pipeline or branch configuration - - Returns: - True if valid, False otherwise - """ - if hasattr(config, 'actions'): - for action in config.actions: - # Validate Redis actions need Redis config - if action.type in [ActionType.REDIS_SAVE_IMAGE, ActionType.REDIS_PUBLISH]: - if not self.redis_config: - logger.error(f"Action {action.type} requires Redis configuration") - return False - - # PostgreSQL actions are disabled - log warning instead of failing - # Kept for backward compatibility with existing pipeline.json files - if action.type in [ActionType.POSTGRESQL_UPDATE, ActionType.POSTGRESQL_UPDATE_COMBINED, ActionType.POSTGRESQL_INSERT]: - logger.warning(f"PostgreSQL action {action.type.value} found but will be SKIPPED (PostgreSQL disabled)") - # Do not fail validation - just skip these actions during execution - # if not self.postgresql_config: - # logger.error(f"Action {action.type} requires PostgreSQL configuration") - # return False - - # Recursively validate branches - if hasattr(config, 'branches'): - for branch in config.branches: - if not self._validate_actions(branch): - return False - - return True - - def _get_all_branch_ids(self, config: Any, branch_ids: Set[str] = None) -> Set[str]: - """ - Recursively collect all branch model IDs - - Args: - config: Pipeline or branch configuration - branch_ids: Set to collect IDs into - - Returns: - Set of all branch model IDs - """ - if branch_ids is None: - branch_ids = set() - - if hasattr(config, 'branches'): - for branch in config.branches: - branch_ids.add(branch.model_id) - self._get_all_branch_ids(branch, branch_ids) - - return branch_ids - - def get_redis_config(self) -> Optional[RedisConfig]: - """Get the Redis configuration""" - return self.redis_config - - def get_postgresql_config(self) -> Optional[PostgreSQLConfig]: - """Get the PostgreSQL configuration""" - return self.postgresql_config - - def get_tracking_config(self) -> Optional[TrackingConfig]: - """Get the tracking configuration""" - return self.tracking_config - - def get_pipeline_config(self) -> Optional[PipelineConfig]: - """Get the main pipeline configuration""" - return self.pipeline_config \ No newline at end of file diff --git a/core/monitoring/__init__.py b/core/monitoring/__init__.py deleted file mode 100644 index 2ad32ed..0000000 --- a/core/monitoring/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -Comprehensive health monitoring system for detector worker. -Tracks stream health, thread responsiveness, and system performance. 
-""" - -from .health import HealthMonitor, HealthStatus, HealthCheck -from .stream_health import StreamHealthTracker -from .thread_health import ThreadHealthMonitor -from .recovery import RecoveryManager - -__all__ = [ - 'HealthMonitor', - 'HealthStatus', - 'HealthCheck', - 'StreamHealthTracker', - 'ThreadHealthMonitor', - 'RecoveryManager' -] \ No newline at end of file diff --git a/core/monitoring/health.py b/core/monitoring/health.py deleted file mode 100644 index be094f3..0000000 --- a/core/monitoring/health.py +++ /dev/null @@ -1,456 +0,0 @@ -""" -Core health monitoring system for comprehensive stream and system health tracking. -Provides centralized health status, alerting, and recovery coordination. -""" -import time -import threading -import logging -import psutil -from typing import Dict, List, Optional, Any, Callable -from dataclasses import dataclass, field -from enum import Enum -from collections import defaultdict, deque - - -logger = logging.getLogger(__name__) - - -class HealthStatus(Enum): - """Health status levels.""" - HEALTHY = "healthy" - WARNING = "warning" - CRITICAL = "critical" - UNKNOWN = "unknown" - - -@dataclass -class HealthCheck: - """Individual health check result.""" - name: str - status: HealthStatus - message: str - timestamp: float = field(default_factory=time.time) - details: Dict[str, Any] = field(default_factory=dict) - recovery_action: Optional[str] = None - - -@dataclass -class HealthMetrics: - """Health metrics for a component.""" - component_id: str - last_update: float - frame_count: int = 0 - error_count: int = 0 - warning_count: int = 0 - restart_count: int = 0 - avg_frame_interval: float = 0.0 - last_frame_time: Optional[float] = None - thread_alive: bool = True - connection_healthy: bool = True - memory_usage_mb: float = 0.0 - cpu_usage_percent: float = 0.0 - - -class HealthMonitor: - """Comprehensive health monitoring system.""" - - def __init__(self, check_interval: float = 30.0): - """ - Initialize health monitor. 
- - Args: - check_interval: Interval between health checks in seconds - """ - self.check_interval = check_interval - self.running = False - self.monitor_thread = None - self._lock = threading.RLock() - - # Health data storage - self.health_checks: Dict[str, HealthCheck] = {} - self.metrics: Dict[str, HealthMetrics] = {} - self.alert_history: deque = deque(maxlen=1000) - self.recovery_actions: deque = deque(maxlen=500) - - # Thresholds (configurable) - self.thresholds = { - 'frame_stale_warning_seconds': 120, # 2 minutes - 'frame_stale_critical_seconds': 300, # 5 minutes - 'thread_unresponsive_seconds': 60, # 1 minute - 'memory_warning_mb': 500, # 500MB per stream - 'memory_critical_mb': 1000, # 1GB per stream - 'cpu_warning_percent': 80, # 80% CPU - 'cpu_critical_percent': 95, # 95% CPU - 'error_rate_warning': 0.1, # 10% error rate - 'error_rate_critical': 0.3, # 30% error rate - 'restart_threshold': 3 # Max restarts per hour - } - - # Health check functions - self.health_checkers: List[Callable[[], List[HealthCheck]]] = [] - self.recovery_callbacks: Dict[str, Callable[[str, HealthCheck], bool]] = {} - - # System monitoring - self.process = psutil.Process() - self.system_start_time = time.time() - - def start(self): - """Start health monitoring.""" - if self.running: - logger.warning("Health monitor already running") - return - - self.running = True - self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True) - self.monitor_thread.start() - logger.info(f"Health monitor started (check interval: {self.check_interval}s)") - - def stop(self): - """Stop health monitoring.""" - self.running = False - if self.monitor_thread: - self.monitor_thread.join(timeout=5.0) - logger.info("Health monitor stopped") - - def register_health_checker(self, checker: Callable[[], List[HealthCheck]]): - """Register a health check function.""" - self.health_checkers.append(checker) - logger.debug(f"Registered health checker: {checker.__name__}") - - def register_recovery_callback(self, component: str, callback: Callable[[str, HealthCheck], bool]): - """Register a recovery callback for a component.""" - self.recovery_callbacks[component] = callback - logger.debug(f"Registered recovery callback for {component}") - - def update_metrics(self, component_id: str, **kwargs): - """Update metrics for a component.""" - with self._lock: - if component_id not in self.metrics: - self.metrics[component_id] = HealthMetrics( - component_id=component_id, - last_update=time.time() - ) - - metrics = self.metrics[component_id] - metrics.last_update = time.time() - - # Update provided metrics - for key, value in kwargs.items(): - if hasattr(metrics, key): - setattr(metrics, key, value) - - def report_frame_received(self, component_id: str): - """Report that a frame was received for a component.""" - current_time = time.time() - with self._lock: - if component_id not in self.metrics: - self.metrics[component_id] = HealthMetrics( - component_id=component_id, - last_update=current_time - ) - - metrics = self.metrics[component_id] - - # Update frame metrics - if metrics.last_frame_time: - interval = current_time - metrics.last_frame_time - # Moving average of frame intervals - if metrics.avg_frame_interval == 0: - metrics.avg_frame_interval = interval - else: - metrics.avg_frame_interval = (metrics.avg_frame_interval * 0.9) + (interval * 0.1) - - metrics.last_frame_time = current_time - metrics.frame_count += 1 - metrics.last_update = current_time - - def report_error(self, component_id: str, error_type: str = 
"general"): - """Report an error for a component.""" - with self._lock: - if component_id not in self.metrics: - self.metrics[component_id] = HealthMetrics( - component_id=component_id, - last_update=time.time() - ) - - self.metrics[component_id].error_count += 1 - self.metrics[component_id].last_update = time.time() - - logger.debug(f"Error reported for {component_id}: {error_type}") - - def report_warning(self, component_id: str, warning_type: str = "general"): - """Report a warning for a component.""" - with self._lock: - if component_id not in self.metrics: - self.metrics[component_id] = HealthMetrics( - component_id=component_id, - last_update=time.time() - ) - - self.metrics[component_id].warning_count += 1 - self.metrics[component_id].last_update = time.time() - - logger.debug(f"Warning reported for {component_id}: {warning_type}") - - def report_restart(self, component_id: str): - """Report that a component was restarted.""" - with self._lock: - if component_id not in self.metrics: - self.metrics[component_id] = HealthMetrics( - component_id=component_id, - last_update=time.time() - ) - - self.metrics[component_id].restart_count += 1 - self.metrics[component_id].last_update = time.time() - - # Log recovery action - recovery_action = { - 'timestamp': time.time(), - 'component': component_id, - 'action': 'restart', - 'reason': 'manual_restart' - } - - with self._lock: - self.recovery_actions.append(recovery_action) - - logger.info(f"Restart reported for {component_id}") - - def get_health_status(self, component_id: Optional[str] = None) -> Dict[str, Any]: - """Get comprehensive health status.""" - with self._lock: - if component_id: - # Get health for specific component - return self._get_component_health(component_id) - else: - # Get overall health status - return self._get_overall_health() - - def _get_component_health(self, component_id: str) -> Dict[str, Any]: - """Get health status for a specific component.""" - if component_id not in self.metrics: - return { - 'component_id': component_id, - 'status': HealthStatus.UNKNOWN.value, - 'message': 'No metrics available', - 'metrics': {} - } - - metrics = self.metrics[component_id] - current_time = time.time() - - # Determine health status - status = HealthStatus.HEALTHY - issues = [] - - # Check frame freshness - if metrics.last_frame_time: - frame_age = current_time - metrics.last_frame_time - if frame_age > self.thresholds['frame_stale_critical_seconds']: - status = HealthStatus.CRITICAL - issues.append(f"Frames stale for {frame_age:.1f}s") - elif frame_age > self.thresholds['frame_stale_warning_seconds']: - if status == HealthStatus.HEALTHY: - status = HealthStatus.WARNING - issues.append(f"Frames aging ({frame_age:.1f}s)") - - # Check error rates - if metrics.frame_count > 0: - error_rate = metrics.error_count / metrics.frame_count - if error_rate > self.thresholds['error_rate_critical']: - status = HealthStatus.CRITICAL - issues.append(f"High error rate ({error_rate:.1%})") - elif error_rate > self.thresholds['error_rate_warning']: - if status == HealthStatus.HEALTHY: - status = HealthStatus.WARNING - issues.append(f"Elevated error rate ({error_rate:.1%})") - - # Check restart frequency - restart_rate = metrics.restart_count / max(1, (current_time - self.system_start_time) / 3600) - if restart_rate > self.thresholds['restart_threshold']: - status = HealthStatus.CRITICAL - issues.append(f"Frequent restarts ({restart_rate:.1f}/hour)") - - # Check thread health - if not metrics.thread_alive: - status = HealthStatus.CRITICAL - 
issues.append("Thread not alive") - - # Check connection health - if not metrics.connection_healthy: - if status == HealthStatus.HEALTHY: - status = HealthStatus.WARNING - issues.append("Connection unhealthy") - - return { - 'component_id': component_id, - 'status': status.value, - 'message': '; '.join(issues) if issues else 'All checks passing', - 'metrics': { - 'frame_count': metrics.frame_count, - 'error_count': metrics.error_count, - 'warning_count': metrics.warning_count, - 'restart_count': metrics.restart_count, - 'avg_frame_interval': metrics.avg_frame_interval, - 'last_frame_age': current_time - metrics.last_frame_time if metrics.last_frame_time else None, - 'thread_alive': metrics.thread_alive, - 'connection_healthy': metrics.connection_healthy, - 'memory_usage_mb': metrics.memory_usage_mb, - 'cpu_usage_percent': metrics.cpu_usage_percent, - 'uptime_seconds': current_time - self.system_start_time - }, - 'last_update': metrics.last_update - } - - def _get_overall_health(self) -> Dict[str, Any]: - """Get overall system health status.""" - current_time = time.time() - components = {} - overall_status = HealthStatus.HEALTHY - - # Get health for all components - for component_id in self.metrics.keys(): - component_health = self._get_component_health(component_id) - components[component_id] = component_health - - # Determine overall status - component_status = HealthStatus(component_health['status']) - if component_status == HealthStatus.CRITICAL: - overall_status = HealthStatus.CRITICAL - elif component_status == HealthStatus.WARNING and overall_status == HealthStatus.HEALTHY: - overall_status = HealthStatus.WARNING - - # System metrics - try: - system_memory = self.process.memory_info() - system_cpu = self.process.cpu_percent() - except Exception: - system_memory = None - system_cpu = 0.0 - - return { - 'overall_status': overall_status.value, - 'timestamp': current_time, - 'uptime_seconds': current_time - self.system_start_time, - 'total_components': len(self.metrics), - 'components': components, - 'system_metrics': { - 'memory_mb': system_memory.rss / (1024 * 1024) if system_memory else 0, - 'cpu_percent': system_cpu, - 'process_id': self.process.pid - }, - 'recent_alerts': list(self.alert_history)[-10:], # Last 10 alerts - 'recent_recoveries': list(self.recovery_actions)[-10:] # Last 10 recovery actions - } - - def _monitor_loop(self): - """Main health monitoring loop.""" - logger.info("Health monitor loop started") - - while self.running: - try: - start_time = time.time() - - # Run all registered health checks - all_checks = [] - for checker in self.health_checkers: - try: - checks = checker() - all_checks.extend(checks) - except Exception as e: - logger.error(f"Error in health checker {checker.__name__}: {e}") - - # Process health checks and trigger recovery if needed - for check in all_checks: - self._process_health_check(check) - - # Update system metrics - self._update_system_metrics() - - # Sleep until next check - elapsed = time.time() - start_time - sleep_time = max(0, self.check_interval - elapsed) - if sleep_time > 0: - time.sleep(sleep_time) - - except Exception as e: - logger.error(f"Error in health monitor loop: {e}") - time.sleep(5.0) # Fallback sleep - - logger.info("Health monitor loop ended") - - def _process_health_check(self, check: HealthCheck): - """Process a health check result and trigger recovery if needed.""" - with self._lock: - # Store health check - self.health_checks[check.name] = check - - # Log alerts for non-healthy status - if check.status != 
HealthStatus.HEALTHY: - alert = { - 'timestamp': check.timestamp, - 'component': check.name, - 'status': check.status.value, - 'message': check.message, - 'details': check.details - } - self.alert_history.append(alert) - - logger.warning(f"Health alert [{check.status.value.upper()}] {check.name}: {check.message}") - - # Trigger recovery if critical and recovery action available - if check.status == HealthStatus.CRITICAL and check.recovery_action: - self._trigger_recovery(check.name, check) - - def _trigger_recovery(self, component: str, check: HealthCheck): - """Trigger recovery action for a component.""" - if component in self.recovery_callbacks: - try: - logger.info(f"Triggering recovery for {component}: {check.recovery_action}") - - success = self.recovery_callbacks[component](component, check) - - recovery_action = { - 'timestamp': time.time(), - 'component': component, - 'action': check.recovery_action, - 'reason': check.message, - 'success': success - } - - with self._lock: - self.recovery_actions.append(recovery_action) - - if success: - logger.info(f"Recovery successful for {component}") - else: - logger.error(f"Recovery failed for {component}") - - except Exception as e: - logger.error(f"Error in recovery callback for {component}: {e}") - - def _update_system_metrics(self): - """Update system-level metrics.""" - try: - # Update process metrics for all components - current_time = time.time() - - with self._lock: - for component_id, metrics in self.metrics.items(): - # Update CPU and memory if available - try: - # This is a simplified approach - in practice you'd want - # per-thread or per-component resource tracking - metrics.cpu_usage_percent = self.process.cpu_percent() / len(self.metrics) - memory_info = self.process.memory_info() - metrics.memory_usage_mb = memory_info.rss / (1024 * 1024) / len(self.metrics) - except Exception: - pass - - except Exception as e: - logger.error(f"Error updating system metrics: {e}") - - -# Global health monitor instance -health_monitor = HealthMonitor() \ No newline at end of file diff --git a/core/monitoring/recovery.py b/core/monitoring/recovery.py deleted file mode 100644 index 4ea16dc..0000000 --- a/core/monitoring/recovery.py +++ /dev/null @@ -1,385 +0,0 @@ -""" -Recovery manager for automatic handling of health issues. -Provides circuit breaker patterns, automatic restarts, and graceful degradation. 
-""" -import time -import logging -import threading -from typing import Dict, List, Optional, Any, Callable -from dataclasses import dataclass -from enum import Enum -from collections import defaultdict, deque - -from .health import HealthCheck, HealthStatus, health_monitor - - -logger = logging.getLogger(__name__) - - -class RecoveryAction(Enum): - """Types of recovery actions.""" - RESTART_STREAM = "restart_stream" - RESTART_THREAD = "restart_thread" - CLEAR_BUFFER = "clear_buffer" - RECONNECT = "reconnect" - THROTTLE = "throttle" - DISABLE = "disable" - - -@dataclass -class RecoveryAttempt: - """Record of a recovery attempt.""" - timestamp: float - component: str - action: RecoveryAction - reason: str - success: bool - details: Dict[str, Any] = None - - -@dataclass -class RecoveryState: - """Recovery state for a component - simplified without circuit breaker.""" - failure_count: int = 0 - success_count: int = 0 - last_failure_time: Optional[float] = None - last_success_time: Optional[float] = None - - -class RecoveryManager: - """Manages automatic recovery actions for health issues.""" - - def __init__(self): - self.recovery_handlers: Dict[str, Callable[[str, HealthCheck], bool]] = {} - self.recovery_states: Dict[str, RecoveryState] = {} - self.recovery_history: deque = deque(maxlen=1000) - self._lock = threading.RLock() - - # Configuration - simplified without circuit breaker - self.recovery_cooldown = 30 # 30 seconds between recovery attempts - self.max_attempts_per_hour = 20 # Still limit to prevent spam, but much higher - - # Track recovery attempts per component - self.recovery_attempts: Dict[str, deque] = defaultdict(lambda: deque(maxlen=50)) - - # Register with health monitor - health_monitor.register_recovery_callback("stream", self._handle_stream_recovery) - health_monitor.register_recovery_callback("thread", self._handle_thread_recovery) - health_monitor.register_recovery_callback("buffer", self._handle_buffer_recovery) - - def register_recovery_handler(self, action: RecoveryAction, handler: Callable[[str, Dict[str, Any]], bool]): - """ - Register a recovery handler for a specific action. - - Args: - action: Type of recovery action - handler: Function that performs the recovery - """ - self.recovery_handlers[action.value] = handler - logger.info(f"Registered recovery handler for {action.value}") - - def can_attempt_recovery(self, component: str) -> bool: - """ - Check if recovery can be attempted for a component. 
- - Args: - component: Component identifier - - Returns: - True if recovery can be attempted (always allow with minimal throttling) - """ - with self._lock: - current_time = time.time() - - # Check recovery attempt rate limiting (much more permissive) - recent_attempts = [ - attempt for attempt in self.recovery_attempts[component] - if current_time - attempt <= 3600 # Last hour - ] - - # Only block if truly excessive attempts - if len(recent_attempts) >= self.max_attempts_per_hour: - logger.warning(f"Recovery rate limit exceeded for {component} " - f"({len(recent_attempts)} attempts in last hour)") - return False - - # Check cooldown period (shorter cooldown) - if recent_attempts: - last_attempt = max(recent_attempts) - if current_time - last_attempt < self.recovery_cooldown: - logger.debug(f"Recovery cooldown active for {component} " - f"(last attempt {current_time - last_attempt:.1f}s ago)") - return False - - return True - - def attempt_recovery(self, component: str, action: RecoveryAction, reason: str, - details: Optional[Dict[str, Any]] = None) -> bool: - """ - Attempt recovery for a component. - - Args: - component: Component identifier - action: Recovery action to perform - reason: Reason for recovery - details: Additional details - - Returns: - True if recovery was successful - """ - if not self.can_attempt_recovery(component): - return False - - current_time = time.time() - - logger.info(f"Attempting recovery for {component}: {action.value} ({reason})") - - try: - # Record recovery attempt - with self._lock: - self.recovery_attempts[component].append(current_time) - - # Perform recovery action - success = self._execute_recovery_action(component, action, details or {}) - - # Record recovery result - attempt = RecoveryAttempt( - timestamp=current_time, - component=component, - action=action, - reason=reason, - success=success, - details=details - ) - - with self._lock: - self.recovery_history.append(attempt) - - # Update recovery state - self._update_recovery_state(component, success) - - if success: - logger.info(f"Recovery successful for {component}: {action.value}") - else: - logger.error(f"Recovery failed for {component}: {action.value}") - - return success - - except Exception as e: - logger.error(f"Error during recovery for {component}: {e}") - self._update_recovery_state(component, False) - return False - - def _execute_recovery_action(self, component: str, action: RecoveryAction, - details: Dict[str, Any]) -> bool: - """Execute a specific recovery action.""" - handler_key = action.value - - if handler_key not in self.recovery_handlers: - logger.error(f"No recovery handler registered for action: {handler_key}") - return False - - try: - handler = self.recovery_handlers[handler_key] - return handler(component, details) - - except Exception as e: - logger.error(f"Error executing recovery action {handler_key} for {component}: {e}") - return False - - def _update_recovery_state(self, component: str, success: bool): - """Update recovery state based on recovery result.""" - current_time = time.time() - - with self._lock: - if component not in self.recovery_states: - self.recovery_states[component] = RecoveryState() - - state = self.recovery_states[component] - - if success: - state.success_count += 1 - state.last_success_time = current_time - # Reset failure count on success - state.failure_count = max(0, state.failure_count - 1) - logger.debug(f"Recovery success for {component} (total successes: {state.success_count})") - else: - state.failure_count += 1 - state.last_failure_time 
= current_time - logger.debug(f"Recovery failure for {component} (total failures: {state.failure_count})") - - def _handle_stream_recovery(self, component: str, health_check: HealthCheck) -> bool: - """Handle recovery for stream-related issues.""" - if "frames" in health_check.name: - # Frame-related issue - restart stream - return self.attempt_recovery( - component, - RecoveryAction.RESTART_STREAM, - health_check.message, - health_check.details - ) - elif "connection" in health_check.name: - # Connection issue - reconnect - return self.attempt_recovery( - component, - RecoveryAction.RECONNECT, - health_check.message, - health_check.details - ) - elif "errors" in health_check.name: - # High error rate - throttle or restart - return self.attempt_recovery( - component, - RecoveryAction.THROTTLE, - health_check.message, - health_check.details - ) - else: - # Generic stream issue - restart - return self.attempt_recovery( - component, - RecoveryAction.RESTART_STREAM, - health_check.message, - health_check.details - ) - - def _handle_thread_recovery(self, component: str, health_check: HealthCheck) -> bool: - """Handle recovery for thread-related issues.""" - if "deadlock" in health_check.name: - # Deadlock detected - restart thread - return self.attempt_recovery( - component, - RecoveryAction.RESTART_THREAD, - health_check.message, - health_check.details - ) - elif "responsive" in health_check.name: - # Thread unresponsive - restart - return self.attempt_recovery( - component, - RecoveryAction.RESTART_THREAD, - health_check.message, - health_check.details - ) - else: - # Generic thread issue - restart - return self.attempt_recovery( - component, - RecoveryAction.RESTART_THREAD, - health_check.message, - health_check.details - ) - - def _handle_buffer_recovery(self, component: str, health_check: HealthCheck) -> bool: - """Handle recovery for buffer-related issues.""" - # Buffer issues - clear buffer - return self.attempt_recovery( - component, - RecoveryAction.CLEAR_BUFFER, - health_check.message, - health_check.details - ) - - def get_recovery_stats(self) -> Dict[str, Any]: - """Get recovery statistics.""" - current_time = time.time() - - with self._lock: - # Calculate stats from history - recent_recoveries = [ - attempt for attempt in self.recovery_history - if current_time - attempt.timestamp <= 3600 # Last hour - ] - - stats_by_component = defaultdict(lambda: { - 'attempts': 0, - 'successes': 0, - 'failures': 0, - 'last_attempt': None, - 'last_success': None - }) - - for attempt in recent_recoveries: - stats = stats_by_component[attempt.component] - stats['attempts'] += 1 - - if attempt.success: - stats['successes'] += 1 - if not stats['last_success'] or attempt.timestamp > stats['last_success']: - stats['last_success'] = attempt.timestamp - else: - stats['failures'] += 1 - - if not stats['last_attempt'] or attempt.timestamp > stats['last_attempt']: - stats['last_attempt'] = attempt.timestamp - - return { - 'total_recoveries_last_hour': len(recent_recoveries), - 'recovery_by_component': dict(stats_by_component), - 'recovery_states': { - component: { - 'failure_count': state.failure_count, - 'success_count': state.success_count, - 'last_failure_time': state.last_failure_time, - 'last_success_time': state.last_success_time - } - for component, state in self.recovery_states.items() - }, - 'recent_history': [ - { - 'timestamp': attempt.timestamp, - 'component': attempt.component, - 'action': attempt.action.value, - 'reason': attempt.reason, - 'success': attempt.success - } - for attempt in 
list(self.recovery_history)[-10:] # Last 10 attempts - ] - } - - def force_recovery(self, component: str, action: RecoveryAction, reason: str = "manual") -> bool: - """ - Force recovery for a component, bypassing rate limiting. - - Args: - component: Component identifier - action: Recovery action to perform - reason: Reason for forced recovery - - Returns: - True if recovery was successful - """ - logger.info(f"Forcing recovery for {component}: {action.value} ({reason})") - - current_time = time.time() - - try: - # Execute recovery action directly - success = self._execute_recovery_action(component, action, {}) - - # Record forced recovery - attempt = RecoveryAttempt( - timestamp=current_time, - component=component, - action=action, - reason=f"forced: {reason}", - success=success, - details={'forced': True} - ) - - with self._lock: - self.recovery_history.append(attempt) - self.recovery_attempts[component].append(current_time) - - # Update recovery state - self._update_recovery_state(component, success) - - return success - - except Exception as e: - logger.error(f"Error during forced recovery for {component}: {e}") - return False - - -# Global recovery manager instance -recovery_manager = RecoveryManager() \ No newline at end of file diff --git a/core/monitoring/stream_health.py b/core/monitoring/stream_health.py deleted file mode 100644 index 770dfe4..0000000 --- a/core/monitoring/stream_health.py +++ /dev/null @@ -1,351 +0,0 @@ -""" -Stream-specific health monitoring for video streams. -Tracks frame production, connection health, and stream-specific metrics. -""" -import time -import logging -import threading -import requests -from typing import Dict, Optional, List, Any -from collections import deque -from dataclasses import dataclass - -from .health import HealthCheck, HealthStatus, health_monitor - - -logger = logging.getLogger(__name__) - - -@dataclass -class StreamMetrics: - """Metrics for an individual stream.""" - camera_id: str - stream_type: str # 'rtsp', 'http_snapshot' - start_time: float - last_frame_time: Optional[float] = None - frame_count: int = 0 - error_count: int = 0 - reconnect_count: int = 0 - bytes_received: int = 0 - frames_per_second: float = 0.0 - connection_attempts: int = 0 - last_connection_test: Optional[float] = None - connection_healthy: bool = True - last_error: Optional[str] = None - last_error_time: Optional[float] = None - - -class StreamHealthTracker: - """Tracks health for individual video streams.""" - - def __init__(self): - self.streams: Dict[str, StreamMetrics] = {} - self._lock = threading.RLock() - - # Configuration - self.connection_test_interval = 300 # Test connection every 5 minutes - self.frame_timeout_warning = 120 # Warn if no frames for 2 minutes - self.frame_timeout_critical = 300 # Critical if no frames for 5 minutes - self.error_rate_threshold = 0.1 # 10% error rate threshold - - # Register with health monitor - health_monitor.register_health_checker(self._perform_health_checks) - - def register_stream(self, camera_id: str, stream_type: str, source_url: Optional[str] = None): - """Register a new stream for monitoring.""" - with self._lock: - if camera_id not in self.streams: - self.streams[camera_id] = StreamMetrics( - camera_id=camera_id, - stream_type=stream_type, - start_time=time.time() - ) - logger.info(f"Registered stream for monitoring: {camera_id} ({stream_type})") - - # Update health monitor metrics - health_monitor.update_metrics( - camera_id, - thread_alive=True, - connection_healthy=True - ) - - def 
unregister_stream(self, camera_id: str): - """Unregister a stream from monitoring.""" - with self._lock: - if camera_id in self.streams: - del self.streams[camera_id] - logger.info(f"Unregistered stream from monitoring: {camera_id}") - - def report_frame_received(self, camera_id: str, frame_size_bytes: int = 0): - """Report that a frame was received.""" - current_time = time.time() - - with self._lock: - if camera_id not in self.streams: - logger.warning(f"Frame received for unregistered stream: {camera_id}") - return - - stream = self.streams[camera_id] - - # Update frame metrics - if stream.last_frame_time: - interval = current_time - stream.last_frame_time - # Calculate FPS as moving average - if stream.frames_per_second == 0: - stream.frames_per_second = 1.0 / interval if interval > 0 else 0 - else: - new_fps = 1.0 / interval if interval > 0 else 0 - stream.frames_per_second = (stream.frames_per_second * 0.9) + (new_fps * 0.1) - - stream.last_frame_time = current_time - stream.frame_count += 1 - stream.bytes_received += frame_size_bytes - - # Report to health monitor - health_monitor.report_frame_received(camera_id) - health_monitor.update_metrics( - camera_id, - frame_count=stream.frame_count, - avg_frame_interval=1.0 / stream.frames_per_second if stream.frames_per_second > 0 else 0, - last_frame_time=current_time - ) - - def report_error(self, camera_id: str, error_message: str): - """Report an error for a stream.""" - current_time = time.time() - - with self._lock: - if camera_id not in self.streams: - logger.warning(f"Error reported for unregistered stream: {camera_id}") - return - - stream = self.streams[camera_id] - stream.error_count += 1 - stream.last_error = error_message - stream.last_error_time = current_time - - # Report to health monitor - health_monitor.report_error(camera_id, "stream_error") - health_monitor.update_metrics( - camera_id, - error_count=stream.error_count - ) - - logger.debug(f"Error reported for stream {camera_id}: {error_message}") - - def report_reconnect(self, camera_id: str, reason: str = "unknown"): - """Report that a stream reconnected.""" - current_time = time.time() - - with self._lock: - if camera_id not in self.streams: - logger.warning(f"Reconnect reported for unregistered stream: {camera_id}") - return - - stream = self.streams[camera_id] - stream.reconnect_count += 1 - - # Report to health monitor - health_monitor.report_restart(camera_id) - health_monitor.update_metrics( - camera_id, - restart_count=stream.reconnect_count - ) - - logger.info(f"Reconnect reported for stream {camera_id}: {reason}") - - def report_connection_attempt(self, camera_id: str, success: bool): - """Report a connection attempt.""" - with self._lock: - if camera_id not in self.streams: - return - - stream = self.streams[camera_id] - stream.connection_attempts += 1 - stream.connection_healthy = success - - # Report to health monitor - health_monitor.update_metrics( - camera_id, - connection_healthy=success - ) - - def test_http_connection(self, camera_id: str, url: str) -> bool: - """Test HTTP connection health for snapshot streams.""" - try: - # Quick HEAD request to test connectivity - response = requests.head(url, timeout=5, verify=False) - success = response.status_code in [200, 404] # 404 might be normal for some cameras - - self.report_connection_attempt(camera_id, success) - - if success: - logger.debug(f"Connection test passed for {camera_id}") - else: - logger.warning(f"Connection test failed for {camera_id}: HTTP {response.status_code}") - - return success - - 
except Exception as e: - logger.warning(f"Connection test failed for {camera_id}: {e}") - self.report_connection_attempt(camera_id, False) - return False - - def get_stream_metrics(self, camera_id: str) -> Optional[Dict[str, Any]]: - """Get metrics for a specific stream.""" - with self._lock: - if camera_id not in self.streams: - return None - - stream = self.streams[camera_id] - current_time = time.time() - - # Calculate derived metrics - uptime = current_time - stream.start_time - frame_age = current_time - stream.last_frame_time if stream.last_frame_time else None - error_rate = stream.error_count / max(1, stream.frame_count) - - return { - 'camera_id': camera_id, - 'stream_type': stream.stream_type, - 'uptime_seconds': uptime, - 'frame_count': stream.frame_count, - 'frames_per_second': stream.frames_per_second, - 'bytes_received': stream.bytes_received, - 'error_count': stream.error_count, - 'error_rate': error_rate, - 'reconnect_count': stream.reconnect_count, - 'connection_attempts': stream.connection_attempts, - 'connection_healthy': stream.connection_healthy, - 'last_frame_age_seconds': frame_age, - 'last_error': stream.last_error, - 'last_error_time': stream.last_error_time - } - - def get_all_metrics(self) -> Dict[str, Dict[str, Any]]: - """Get metrics for all streams.""" - with self._lock: - return { - camera_id: self.get_stream_metrics(camera_id) - for camera_id in self.streams.keys() - } - - def _perform_health_checks(self) -> List[HealthCheck]: - """Perform health checks for all streams.""" - checks = [] - current_time = time.time() - - with self._lock: - for camera_id, stream in self.streams.items(): - checks.extend(self._check_stream_health(camera_id, stream, current_time)) - - return checks - - def _check_stream_health(self, camera_id: str, stream: StreamMetrics, current_time: float) -> List[HealthCheck]: - """Perform health checks for a single stream.""" - checks = [] - - # Check frame freshness - if stream.last_frame_time: - frame_age = current_time - stream.last_frame_time - - if frame_age > self.frame_timeout_critical: - checks.append(HealthCheck( - name=f"stream_{camera_id}_frames", - status=HealthStatus.CRITICAL, - message=f"No frames for {frame_age:.1f}s (critical threshold: {self.frame_timeout_critical}s)", - details={ - 'frame_age': frame_age, - 'threshold': self.frame_timeout_critical, - 'last_frame_time': stream.last_frame_time - }, - recovery_action="restart_stream" - )) - elif frame_age > self.frame_timeout_warning: - checks.append(HealthCheck( - name=f"stream_{camera_id}_frames", - status=HealthStatus.WARNING, - message=f"Frames aging: {frame_age:.1f}s (warning threshold: {self.frame_timeout_warning}s)", - details={ - 'frame_age': frame_age, - 'threshold': self.frame_timeout_warning, - 'last_frame_time': stream.last_frame_time - } - )) - else: - # No frames received yet - startup_time = current_time - stream.start_time - if startup_time > 60: # Allow 1 minute for initial connection - checks.append(HealthCheck( - name=f"stream_{camera_id}_startup", - status=HealthStatus.CRITICAL, - message=f"No frames received since startup {startup_time:.1f}s ago", - details={ - 'startup_time': startup_time, - 'start_time': stream.start_time - }, - recovery_action="restart_stream" - )) - - # Check error rate - if stream.frame_count > 10: # Need sufficient samples - error_rate = stream.error_count / stream.frame_count - if error_rate > self.error_rate_threshold: - checks.append(HealthCheck( - name=f"stream_{camera_id}_errors", - status=HealthStatus.WARNING, - message=f"High 
error rate: {error_rate:.1%} ({stream.error_count}/{stream.frame_count})", - details={ - 'error_rate': error_rate, - 'error_count': stream.error_count, - 'frame_count': stream.frame_count, - 'last_error': stream.last_error - } - )) - - # Check connection health - if not stream.connection_healthy: - checks.append(HealthCheck( - name=f"stream_{camera_id}_connection", - status=HealthStatus.WARNING, - message="Connection unhealthy (last test failed)", - details={ - 'connection_attempts': stream.connection_attempts, - 'last_connection_test': stream.last_connection_test - } - )) - - # Check excessive reconnects - uptime_hours = (current_time - stream.start_time) / 3600 - if uptime_hours > 1 and stream.reconnect_count > 5: # More than 5 reconnects per hour - reconnect_rate = stream.reconnect_count / uptime_hours - checks.append(HealthCheck( - name=f"stream_{camera_id}_stability", - status=HealthStatus.WARNING, - message=f"Frequent reconnects: {reconnect_rate:.1f}/hour ({stream.reconnect_count} total)", - details={ - 'reconnect_rate': reconnect_rate, - 'reconnect_count': stream.reconnect_count, - 'uptime_hours': uptime_hours - } - )) - - # Check frame rate health - if stream.last_frame_time and stream.frames_per_second > 0: - expected_fps = 6.0 # Expected FPS for streams - if stream.frames_per_second < expected_fps * 0.5: # Less than 50% of expected - checks.append(HealthCheck( - name=f"stream_{camera_id}_framerate", - status=HealthStatus.WARNING, - message=f"Low frame rate: {stream.frames_per_second:.1f} fps (expected: ~{expected_fps} fps)", - details={ - 'current_fps': stream.frames_per_second, - 'expected_fps': expected_fps - } - )) - - return checks - - -# Global stream health tracker instance -stream_health_tracker = StreamHealthTracker() \ No newline at end of file diff --git a/core/monitoring/thread_health.py b/core/monitoring/thread_health.py deleted file mode 100644 index a29625b..0000000 --- a/core/monitoring/thread_health.py +++ /dev/null @@ -1,381 +0,0 @@ -""" -Thread health monitoring for detecting unresponsive and deadlocked threads. -Provides thread liveness detection and responsiveness testing. 
-""" -import time -import threading -import logging -import signal -import traceback -from typing import Dict, List, Optional, Any, Callable -from dataclasses import dataclass -from collections import defaultdict - -from .health import HealthCheck, HealthStatus, health_monitor - - -logger = logging.getLogger(__name__) - - -@dataclass -class ThreadInfo: - """Information about a monitored thread.""" - thread_id: int - thread_name: str - start_time: float - last_heartbeat: float - heartbeat_count: int = 0 - is_responsive: bool = True - last_activity: Optional[str] = None - stack_traces: List[str] = None - - -class ThreadHealthMonitor: - """Monitors thread health and responsiveness.""" - - def __init__(self): - self.monitored_threads: Dict[int, ThreadInfo] = {} - self.heartbeat_callbacks: Dict[int, Callable[[], bool]] = {} - self._lock = threading.RLock() - - # Configuration - self.heartbeat_timeout = 60.0 # 1 minute without heartbeat = unresponsive - self.responsiveness_test_interval = 30.0 # Test responsiveness every 30 seconds - self.stack_trace_count = 5 # Keep last 5 stack traces for analysis - - # Register with health monitor - health_monitor.register_health_checker(self._perform_health_checks) - - # Enable periodic responsiveness testing - self.test_thread = threading.Thread(target=self._responsiveness_test_loop, daemon=True) - self.test_thread.start() - - def register_thread(self, thread: threading.Thread, heartbeat_callback: Optional[Callable[[], bool]] = None): - """ - Register a thread for monitoring. - - Args: - thread: Thread to monitor - heartbeat_callback: Optional callback to test thread responsiveness - """ - with self._lock: - thread_info = ThreadInfo( - thread_id=thread.ident, - thread_name=thread.name, - start_time=time.time(), - last_heartbeat=time.time() - ) - - self.monitored_threads[thread.ident] = thread_info - - if heartbeat_callback: - self.heartbeat_callbacks[thread.ident] = heartbeat_callback - - logger.info(f"Registered thread for monitoring: {thread.name} (ID: {thread.ident})") - - def unregister_thread(self, thread_id: int): - """Unregister a thread from monitoring.""" - with self._lock: - if thread_id in self.monitored_threads: - thread_name = self.monitored_threads[thread_id].thread_name - del self.monitored_threads[thread_id] - - if thread_id in self.heartbeat_callbacks: - del self.heartbeat_callbacks[thread_id] - - logger.info(f"Unregistered thread from monitoring: {thread_name} (ID: {thread_id})") - - def heartbeat(self, thread_id: Optional[int] = None, activity: Optional[str] = None): - """ - Report thread heartbeat. 
- - Args: - thread_id: Thread ID (uses current thread if None) - activity: Description of current activity - """ - if thread_id is None: - thread_id = threading.current_thread().ident - - current_time = time.time() - - with self._lock: - if thread_id in self.monitored_threads: - thread_info = self.monitored_threads[thread_id] - thread_info.last_heartbeat = current_time - thread_info.heartbeat_count += 1 - thread_info.is_responsive = True - - if activity: - thread_info.last_activity = activity - - # Report to health monitor - health_monitor.update_metrics( - f"thread_{thread_info.thread_name}", - thread_alive=True, - last_frame_time=current_time - ) - - def get_thread_info(self, thread_id: int) -> Optional[Dict[str, Any]]: - """Get information about a monitored thread.""" - with self._lock: - if thread_id not in self.monitored_threads: - return None - - thread_info = self.monitored_threads[thread_id] - current_time = time.time() - - return { - 'thread_id': thread_id, - 'thread_name': thread_info.thread_name, - 'uptime_seconds': current_time - thread_info.start_time, - 'last_heartbeat_age': current_time - thread_info.last_heartbeat, - 'heartbeat_count': thread_info.heartbeat_count, - 'is_responsive': thread_info.is_responsive, - 'last_activity': thread_info.last_activity, - 'stack_traces': thread_info.stack_traces or [] - } - - def get_all_thread_info(self) -> Dict[int, Dict[str, Any]]: - """Get information about all monitored threads.""" - with self._lock: - return { - thread_id: self.get_thread_info(thread_id) - for thread_id in self.monitored_threads.keys() - } - - def test_thread_responsiveness(self, thread_id: int) -> bool: - """ - Test if a thread is responsive by calling its heartbeat callback. - - Args: - thread_id: ID of thread to test - - Returns: - True if thread responds within timeout - """ - if thread_id not in self.heartbeat_callbacks: - return True # Can't test if no callback provided - - try: - # Call the heartbeat callback with a timeout - callback = self.heartbeat_callbacks[thread_id] - - # This is a simple approach - in practice you might want to use - # threading.Timer or asyncio for more sophisticated timeout handling - start_time = time.time() - result = callback() - response_time = time.time() - start_time - - with self._lock: - if thread_id in self.monitored_threads: - self.monitored_threads[thread_id].is_responsive = result - - if response_time > 5.0: # Slow response - logger.warning(f"Thread {thread_id} slow response: {response_time:.1f}s") - - return result - - except Exception as e: - logger.error(f"Error testing thread {thread_id} responsiveness: {e}") - with self._lock: - if thread_id in self.monitored_threads: - self.monitored_threads[thread_id].is_responsive = False - return False - - def capture_stack_trace(self, thread_id: int) -> Optional[str]: - """ - Capture stack trace for a thread. 
- - Args: - thread_id: ID of thread to capture - - Returns: - Stack trace string or None if not available - """ - try: - # Get all frames for all threads - frames = dict(threading._current_frames()) - - if thread_id not in frames: - return None - - # Format stack trace - frame = frames[thread_id] - stack_trace = ''.join(traceback.format_stack(frame)) - - # Store in thread info - with self._lock: - if thread_id in self.monitored_threads: - thread_info = self.monitored_threads[thread_id] - if thread_info.stack_traces is None: - thread_info.stack_traces = [] - - thread_info.stack_traces.append(f"{time.time()}: {stack_trace}") - - # Keep only last N stack traces - if len(thread_info.stack_traces) > self.stack_trace_count: - thread_info.stack_traces = thread_info.stack_traces[-self.stack_trace_count:] - - return stack_trace - - except Exception as e: - logger.error(f"Error capturing stack trace for thread {thread_id}: {e}") - return None - - def detect_deadlocks(self) -> List[Dict[str, Any]]: - """ - Attempt to detect potential deadlocks by analyzing thread states. - - Returns: - List of potential deadlock scenarios - """ - deadlocks = [] - current_time = time.time() - - with self._lock: - # Look for threads that haven't had heartbeats for a long time - # and are supposedly alive - for thread_id, thread_info in self.monitored_threads.items(): - heartbeat_age = current_time - thread_info.last_heartbeat - - if heartbeat_age > self.heartbeat_timeout * 2: # Double the timeout - # Check if thread still exists - thread_exists = any( - t.ident == thread_id and t.is_alive() - for t in threading.enumerate() - ) - - if thread_exists: - # Thread exists but not responding - potential deadlock - stack_trace = self.capture_stack_trace(thread_id) - - deadlock_info = { - 'thread_id': thread_id, - 'thread_name': thread_info.thread_name, - 'heartbeat_age': heartbeat_age, - 'last_activity': thread_info.last_activity, - 'stack_trace': stack_trace, - 'detection_time': current_time - } - - deadlocks.append(deadlock_info) - logger.warning(f"Potential deadlock detected in thread {thread_info.thread_name}") - - return deadlocks - - def _responsiveness_test_loop(self): - """Background loop to test thread responsiveness.""" - logger.info("Thread responsiveness testing started") - - while True: - try: - time.sleep(self.responsiveness_test_interval) - - with self._lock: - thread_ids = list(self.monitored_threads.keys()) - - for thread_id in thread_ids: - try: - self.test_thread_responsiveness(thread_id) - except Exception as e: - logger.error(f"Error testing thread {thread_id}: {e}") - - except Exception as e: - logger.error(f"Error in responsiveness test loop: {e}") - time.sleep(10.0) # Fallback sleep - - def _perform_health_checks(self) -> List[HealthCheck]: - """Perform health checks for all monitored threads.""" - checks = [] - current_time = time.time() - - with self._lock: - for thread_id, thread_info in self.monitored_threads.items(): - checks.extend(self._check_thread_health(thread_id, thread_info, current_time)) - - # Check for deadlocks - deadlocks = self.detect_deadlocks() - for deadlock in deadlocks: - checks.append(HealthCheck( - name=f"deadlock_detection_{deadlock['thread_id']}", - status=HealthStatus.CRITICAL, - message=f"Potential deadlock in thread {deadlock['thread_name']} " - f"(unresponsive for {deadlock['heartbeat_age']:.1f}s)", - details=deadlock, - recovery_action="restart_thread" - )) - - return checks - - def _check_thread_health(self, thread_id: int, thread_info: ThreadInfo, current_time: float) 
-> List[HealthCheck]: - """Perform health checks for a single thread.""" - checks = [] - - # Check if thread still exists - thread_exists = any( - t.ident == thread_id and t.is_alive() - for t in threading.enumerate() - ) - - if not thread_exists: - checks.append(HealthCheck( - name=f"thread_{thread_info.thread_name}_alive", - status=HealthStatus.CRITICAL, - message=f"Thread {thread_info.thread_name} is no longer alive", - details={ - 'thread_id': thread_id, - 'uptime': current_time - thread_info.start_time, - 'last_heartbeat': thread_info.last_heartbeat - }, - recovery_action="restart_thread" - )) - return checks - - # Check heartbeat freshness - heartbeat_age = current_time - thread_info.last_heartbeat - - if heartbeat_age > self.heartbeat_timeout: - checks.append(HealthCheck( - name=f"thread_{thread_info.thread_name}_responsive", - status=HealthStatus.CRITICAL, - message=f"Thread {thread_info.thread_name} unresponsive for {heartbeat_age:.1f}s", - details={ - 'thread_id': thread_id, - 'heartbeat_age': heartbeat_age, - 'heartbeat_count': thread_info.heartbeat_count, - 'last_activity': thread_info.last_activity, - 'is_responsive': thread_info.is_responsive - }, - recovery_action="restart_thread" - )) - elif heartbeat_age > self.heartbeat_timeout * 0.5: # Warning at 50% of timeout - checks.append(HealthCheck( - name=f"thread_{thread_info.thread_name}_responsive", - status=HealthStatus.WARNING, - message=f"Thread {thread_info.thread_name} slow heartbeat: {heartbeat_age:.1f}s", - details={ - 'thread_id': thread_id, - 'heartbeat_age': heartbeat_age, - 'heartbeat_count': thread_info.heartbeat_count, - 'last_activity': thread_info.last_activity, - 'is_responsive': thread_info.is_responsive - } - )) - - # Check responsiveness test results - if not thread_info.is_responsive: - checks.append(HealthCheck( - name=f"thread_{thread_info.thread_name}_callback", - status=HealthStatus.WARNING, - message=f"Thread {thread_info.thread_name} failed responsiveness test", - details={ - 'thread_id': thread_id, - 'last_activity': thread_info.last_activity - } - )) - - return checks - - -# Global thread health monitor instance -thread_health_monitor = ThreadHealthMonitor() \ No newline at end of file diff --git a/core/storage/__init__.py b/core/storage/__init__.py deleted file mode 100644 index b2ff324..0000000 --- a/core/storage/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -Storage module for the Python Detector Worker. - -This module provides Redis operations for data persistence -and caching in the detection pipeline. - -Note: PostgreSQL operations have been disabled as database functionality -has been moved to microservices architecture. -""" -from .redis import RedisManager -# from .database import DatabaseManager # Disabled - moved to microservices - -__all__ = ['RedisManager'] # Removed 'DatabaseManager' \ No newline at end of file diff --git a/core/storage/database.py b/core/storage/database.py deleted file mode 100644 index 4715768..0000000 --- a/core/storage/database.py +++ /dev/null @@ -1,369 +0,0 @@ -""" -Database Operations Module - DISABLED - -NOTE: This module has been disabled as PostgreSQL database operations have been -moved to microservices architecture. All database connections, reads, and writes -are now handled by separate backend services. - -The detection worker now communicates results via: -- WebSocket imageDetection messages (primary data flow to backend) -- Redis image storage and pub/sub (temporary storage) - -Original functionality: PostgreSQL operations for the detection pipeline. 
-Status: Commented out - DO NOT ENABLE without updating architecture -""" - -# All PostgreSQL functionality below has been commented out -# import psycopg2 -# import psycopg2.extras -from typing import Optional, Dict, Any -import logging -# import uuid - -logger = logging.getLogger(__name__) - -# DatabaseManager class is disabled - all methods commented out -# class DatabaseManager: -# """ -# Manages PostgreSQL connections and operations for the detection pipeline. -# Handles database operations and schema management. -# """ -# -# def __init__(self, config: Dict[str, Any]): -# """ -# Initialize database manager with configuration. -# -# Args: -# config: Database configuration dictionary -# """ -# self.config = config -# self.connection: Optional[psycopg2.extensions.connection] = None -# -# def connect(self) -> bool: -# """ -# Connect to PostgreSQL database. -# -# Returns: -# True if successful, False otherwise -# """ -# try: -# self.connection = psycopg2.connect( -# host=self.config['host'], -# port=self.config['port'], -# database=self.config['database'], -# user=self.config['username'], -# password=self.config['password'] -# ) -# logger.info("PostgreSQL connection established successfully") -# return True -# except Exception as e: -# logger.error(f"Failed to connect to PostgreSQL: {e}") -# return False -# -# def disconnect(self): -# """Disconnect from PostgreSQL database.""" -# if self.connection: -# self.connection.close() -# self.connection = None -# logger.info("PostgreSQL connection closed") -# -# def is_connected(self) -> bool: -# """ -# Check if database connection is active. -# -# Returns: -# True if connected, False otherwise -# """ -# try: -# if self.connection and not self.connection.closed: -# cur = self.connection.cursor() -# cur.execute("SELECT 1") -# cur.fetchone() -# cur.close() -# return True -# except: -# pass -# return False -# -# def update_car_info(self, session_id: str, brand: str, model: str, body_type: str) -> bool: -# """ -# Update car information in the database. -# -# Args: -# session_id: Session identifier -# brand: Car brand -# model: Car model -# body_type: Car body type -# -# Returns: -# True if successful, False otherwise -# """ -# if not self.is_connected(): -# if not self.connect(): -# return False -# -# try: -# cur = self.connection.cursor() -# query = """ -# INSERT INTO car_frontal_info (session_id, car_brand, car_model, car_body_type, updated_at) -# VALUES (%s, %s, %s, %s, NOW()) -# ON CONFLICT (session_id) -# DO UPDATE SET -# car_brand = EXCLUDED.car_brand, -# car_model = EXCLUDED.car_model, -# car_body_type = EXCLUDED.car_body_type, -# updated_at = NOW() -# """ -# cur.execute(query, (session_id, brand, model, body_type)) -# self.connection.commit() -# cur.close() -# logger.info(f"Updated car info for session {session_id}: {brand} {model} ({body_type})") -# return True -# except Exception as e: -# logger.error(f"Failed to update car info: {e}") -# if self.connection: -# self.connection.rollback() -# return False -# -# def execute_update(self, table: str, key_field: str, key_value: str, fields: Dict[str, str]) -> bool: -# """ -# Execute a dynamic update query on the database. 
-# -# Args: -# table: Table name -# key_field: Primary key field name -# key_value: Primary key value -# fields: Dictionary of fields to update -# -# Returns: -# True if successful, False otherwise -# """ -# if not self.is_connected(): -# if not self.connect(): -# return False -# -# try: -# cur = self.connection.cursor() -# -# # Build the UPDATE query dynamically -# set_clauses = [] -# values = [] -# -# for field, value in fields.items(): -# if value == "NOW()": -# set_clauses.append(f"{field} = NOW()") -# else: -# set_clauses.append(f"{field} = %s") -# values.append(value) -# -# # Add schema prefix if table doesn't already have it -# full_table_name = table if '.' in table else f"gas_station_1.{table}" -# -# query = f""" -# INSERT INTO {full_table_name} ({key_field}, {', '.join(fields.keys())}) -# VALUES (%s, {', '.join(['%s'] * len(fields))}) -# ON CONFLICT ({key_field}) -# DO UPDATE SET {', '.join(set_clauses)} -# """ -# -# # Add key_value to the beginning of values list -# all_values = [key_value] + list(fields.values()) + values -# -# cur.execute(query, all_values) -# self.connection.commit() -# cur.close() -# logger.info(f"Updated {table} for {key_field}={key_value}") -# return True -# except Exception as e: -# logger.error(f"Failed to execute update on {table}: {e}") -# if self.connection: -# self.connection.rollback() -# return False -# -# def create_car_frontal_info_table(self) -> bool: -# """ -# Create the car_frontal_info table in gas_station_1 schema if it doesn't exist. -# -# Returns: -# True if successful, False otherwise -# """ -# if not self.is_connected(): -# if not self.connect(): -# return False -# -# try: -# # Since the database already exists, just verify connection -# cur = self.connection.cursor() -# -# # Simple verification that the table exists -# cur.execute(""" -# SELECT EXISTS ( -# SELECT FROM information_schema.tables -# WHERE table_schema = 'gas_station_1' -# AND table_name = 'car_frontal_info' -# ) -# """) -# -# table_exists = cur.fetchone()[0] -# cur.close() -# -# if table_exists: -# logger.info("Verified car_frontal_info table exists") -# return True -# else: -# logger.error("car_frontal_info table does not exist in the database") -# return False -# -# except Exception as e: -# logger.error(f"Failed to create car_frontal_info table: {e}") -# if self.connection: -# self.connection.rollback() -# return False -# -# def insert_initial_detection(self, display_id: str, captured_timestamp: str, session_id: str = None) -> str: -# """ -# Insert initial detection record and return the session_id. 
-# -# Args: -# display_id: Display identifier -# captured_timestamp: Timestamp of the detection -# session_id: Optional session ID, generates one if not provided -# -# Returns: -# Session ID string or None on error -# """ -# if not self.is_connected(): -# if not self.connect(): -# return None -# -# # Generate session_id if not provided -# if not session_id: -# session_id = str(uuid.uuid4()) -# -# try: -# # Ensure table exists -# if not self.create_car_frontal_info_table(): -# logger.error("Failed to create/verify table before insertion") -# return None -# -# cur = self.connection.cursor() -# insert_query = """ -# INSERT INTO gas_station_1.car_frontal_info -# (display_id, captured_timestamp, session_id, license_character, license_type, car_brand, car_model, car_body_type) -# VALUES (%s, %s, %s, NULL, 'No model available', NULL, NULL, NULL) -# ON CONFLICT (session_id) DO NOTHING -# """ -# -# cur.execute(insert_query, (display_id, captured_timestamp, session_id)) -# self.connection.commit() -# cur.close() -# logger.info(f"Inserted initial detection record with session_id: {session_id}") -# return session_id -# -# except Exception as e: -# logger.error(f"Failed to insert initial detection record: {e}") -# if self.connection: -# self.connection.rollback() -# return None -# -# def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]: -# """ -# Get session information from the database. -# -# Args: -# session_id: Session identifier -# -# Returns: -# Dictionary with session data or None if not found -# """ -# if not self.is_connected(): -# if not self.connect(): -# return None -# -# try: -# cur = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor) -# query = "SELECT * FROM gas_station_1.car_frontal_info WHERE session_id = %s" -# cur.execute(query, (session_id,)) -# result = cur.fetchone() -# cur.close() -# -# if result: -# return dict(result) -# else: -# logger.debug(f"No session info found for session_id: {session_id}") -# return None -# -# except Exception as e: -# logger.error(f"Failed to get session info: {e}") -# return None -# -# def delete_session(self, session_id: str) -> bool: -# """ -# Delete session record from the database. -# -# Args: -# session_id: Session identifier -# -# Returns: -# True if successful, False otherwise -# """ -# if not self.is_connected(): -# if not self.connect(): -# return False -# -# try: -# cur = self.connection.cursor() -# query = "DELETE FROM gas_station_1.car_frontal_info WHERE session_id = %s" -# cur.execute(query, (session_id,)) -# rows_affected = cur.rowcount -# self.connection.commit() -# cur.close() -# -# if rows_affected > 0: -# logger.info(f"Deleted session record: {session_id}") -# return True -# else: -# logger.warning(f"No session record found to delete: {session_id}") -# return False -# -# except Exception as e: -# logger.error(f"Failed to delete session: {e}") -# if self.connection: -# self.connection.rollback() -# return False -# -# def get_statistics(self) -> Dict[str, Any]: -# """ -# Get database statistics. 
-# -# Returns: -# Dictionary with database statistics -# """ -# stats = { -# 'connected': self.is_connected(), -# 'host': self.config.get('host', 'unknown'), -# 'port': self.config.get('port', 'unknown'), -# 'database': self.config.get('database', 'unknown') -# } -# -# if self.is_connected(): -# try: -# cur = self.connection.cursor() -# -# # Get table record count -# cur.execute("SELECT COUNT(*) FROM gas_station_1.car_frontal_info") -# stats['total_records'] = cur.fetchone()[0] -# -# # Get recent records count (last hour) -# cur.execute(""" -# SELECT COUNT(*) FROM gas_station_1.car_frontal_info -# WHERE created_at > NOW() - INTERVAL '1 hour' -# """) -# stats['recent_records'] = cur.fetchone()[0] -# -# cur.close() -# except Exception as e: -# logger.warning(f"Failed to get database statistics: {e}") -# stats['error'] = str(e) -# -# return stats diff --git a/core/storage/license_plate.py b/core/storage/license_plate.py deleted file mode 100644 index 19cbf73..0000000 --- a/core/storage/license_plate.py +++ /dev/null @@ -1,300 +0,0 @@ -""" -License Plate Manager Module. -Handles Redis subscription to license plate results from LPR service. -""" -import logging -import json -import asyncio -from typing import Dict, Optional, Any, Callable -import redis.asyncio as redis - -logger = logging.getLogger(__name__) - - -class LicensePlateManager: - """ - Manages license plate result subscription from Redis channel. - Subscribes to 'license_results' channel for license plate data from LPR service. - """ - - def __init__(self, redis_config: Dict[str, Any]): - """ - Initialize license plate manager with Redis configuration. - - Args: - redis_config: Redis configuration dictionary - """ - self.config = redis_config - self.redis_client: Optional[redis.Redis] = None - self.pubsub = None - self.subscription_task = None - self.callback = None - - # Connection parameters - self.host = redis_config.get('host', 'localhost') - self.port = redis_config.get('port', 6379) - self.password = redis_config.get('password') - self.db = redis_config.get('db', 0) - - # License plate data cache - store recent results by session_id - self.license_plate_cache: Dict[str, Dict[str, Any]] = {} - self.cache_ttl = 300 # 5 minutes TTL for cached results - - logger.info(f"LicensePlateManager initialized for {self.host}:{self.port}") - - async def initialize(self, callback: Optional[Callable] = None) -> bool: - """ - Initialize Redis connection and start subscription to license_results channel. 
- - Args: - callback: Optional callback function for processing license plate results - - Returns: - True if successful, False otherwise - """ - try: - # Create Redis connection - self.redis_client = redis.Redis( - host=self.host, - port=self.port, - password=self.password, - db=self.db, - decode_responses=True - ) - - # Test connection - await self.redis_client.ping() - logger.info(f"Connected to Redis for license plate subscription") - - # Set callback - self.callback = callback - - # Start subscription - await self._start_subscription() - - return True - - except Exception as e: - logger.error(f"Failed to initialize license plate manager: {e}", exc_info=True) - return False - - async def _start_subscription(self): - """Start Redis subscription to license_results channel.""" - try: - if not self.redis_client: - logger.error("Redis client not initialized") - return - - # Create pubsub and subscribe - self.pubsub = self.redis_client.pubsub() - await self.pubsub.subscribe('license_results') - - logger.info("Subscribed to Redis channel: license_results") - - # Start listening task - self.subscription_task = asyncio.create_task(self._listen_for_messages()) - - except Exception as e: - logger.error(f"Error starting license plate subscription: {e}", exc_info=True) - - async def _listen_for_messages(self): - """Listen for messages on the license_results channel.""" - listen_generator = None - try: - if not self.pubsub: - return - - listen_generator = self.pubsub.listen() - async for message in listen_generator: - if message['type'] == 'message': - try: - # Log the raw message from Redis channel - logger.info(f"[LICENSE PLATE RAW] Received from 'license_results' channel: {message['data']}") - - # Parse the license plate result message - data = json.loads(message['data']) - logger.info(f"[LICENSE PLATE PARSED] Parsed JSON data: {data}") - await self._process_license_plate_result(data) - except json.JSONDecodeError as e: - logger.error(f"[LICENSE PLATE ERROR] Invalid JSON in license plate message: {e}") - logger.error(f"[LICENSE PLATE ERROR] Raw message was: {message['data']}") - except Exception as e: - logger.error(f"Error processing license plate message: {e}", exc_info=True) - - except asyncio.CancelledError: - logger.info("License plate subscription task cancelled") - # Don't try to close generator here - let it be handled by the context - # The async generator will be properly closed by the cancellation mechanism - raise # Re-raise to maintain proper cancellation semantics - except Exception as e: - logger.error(f"Error in license plate message listener: {e}", exc_info=True) - # Only attempt cleanup if it's not a cancellation - finally: - # Safe cleanup of async generator - if listen_generator is not None: - try: - # Check if we can safely close without conflicting with ongoing operations - if hasattr(listen_generator, 'aclose') and not asyncio.current_task().cancelled(): - await listen_generator.aclose() - except (RuntimeError, AttributeError) as e: - # Generator is already closing or in invalid state - safe to ignore - logger.debug(f"Generator cleanup skipped (safe): {e}") - except Exception as e: - logger.debug(f"Generator cleanup error (non-critical): {e}") - - async def _process_license_plate_result(self, data: Dict[str, Any]): - """ - Process incoming license plate result from LPR service. 
- - Expected message format (from actual LPR service): - { - "session_id": "511", - "license_character": "ข3184" - } - or - { - "session_id": "508", - "display_id": "test3", - "license_plate_text": "ABC-123", - "confidence": 0.95, - "timestamp": "2025-09-24T21:10:00Z" - } - - Args: - data: License plate result data - """ - try: - session_id = data.get('session_id') - if not session_id: - logger.warning("License plate result missing session_id") - return - - # Handle different message formats - # Format 1: {"session_id": "511", "license_character": "ข3184"} - # Format 2: {"session_id": "508", "license_plate_text": "ABC-123", "confidence": 0.95, ...} - license_plate_text = data.get('license_plate_text') or data.get('license_character') - confidence = data.get('confidence', 1.0) # Default confidence for LPR service results - display_id = data.get('display_id', '') - timestamp = data.get('timestamp', '') - - logger.info(f"[LICENSE PLATE] Received result for session {session_id}: " - f"text='{license_plate_text}', confidence={confidence:.3f}") - - # Store in cache - self.license_plate_cache[session_id] = { - 'license_plate_text': license_plate_text, - 'confidence': confidence, - 'display_id': display_id, - 'timestamp': timestamp, - 'received_at': asyncio.get_event_loop().time() - } - - # Call callback if provided - if self.callback: - await self.callback(session_id, { - 'license_plate_text': license_plate_text, - 'confidence': confidence, - 'display_id': display_id, - 'timestamp': timestamp - }) - - except Exception as e: - logger.error(f"Error processing license plate result: {e}", exc_info=True) - - def get_license_plate_result(self, session_id: str) -> Optional[Dict[str, Any]]: - """ - Get cached license plate result for a session. - - Args: - session_id: Session identifier - - Returns: - License plate result dictionary or None if not found - """ - if session_id not in self.license_plate_cache: - return None - - result = self.license_plate_cache[session_id] - - # Check TTL - current_time = asyncio.get_event_loop().time() - if current_time - result.get('received_at', 0) > self.cache_ttl: - # Expired, remove from cache - del self.license_plate_cache[session_id] - return None - - return { - 'license_plate_text': result.get('license_plate_text'), - 'confidence': result.get('confidence'), - 'display_id': result.get('display_id'), - 'timestamp': result.get('timestamp') - } - - def cleanup_expired_results(self): - """Remove expired license plate results from cache.""" - try: - current_time = asyncio.get_event_loop().time() - expired_sessions = [] - - for session_id, result in self.license_plate_cache.items(): - if current_time - result.get('received_at', 0) > self.cache_ttl: - expired_sessions.append(session_id) - - for session_id in expired_sessions: - del self.license_plate_cache[session_id] - logger.debug(f"Removed expired license plate result for session {session_id}") - - except Exception as e: - logger.error(f"Error cleaning up expired license plate results: {e}", exc_info=True) - - async def close(self): - """Close Redis connection and cleanup resources.""" - try: - # Cancel subscription task first - if self.subscription_task and not self.subscription_task.done(): - self.subscription_task.cancel() - try: - await self.subscription_task - except asyncio.CancelledError: - logger.debug("License plate subscription task cancelled successfully") - except Exception as e: - logger.warning(f"Error waiting for subscription task cancellation: {e}") - - # Close pubsub connection properly - if 
self.pubsub: - try: - # First unsubscribe from channels - await self.pubsub.unsubscribe('license_results') - # Then close the pubsub connection - await self.pubsub.aclose() - except Exception as e: - logger.warning(f"Error closing pubsub connection: {e}") - finally: - self.pubsub = None - - # Close Redis connection - if self.redis_client: - try: - await self.redis_client.aclose() - except Exception as e: - logger.warning(f"Error closing Redis connection: {e}") - finally: - self.redis_client = None - - # Clear cache - self.license_plate_cache.clear() - - logger.info("License plate manager closed successfully") - - except Exception as e: - logger.error(f"Error closing license plate manager: {e}", exc_info=True) - - def get_statistics(self) -> Dict[str, Any]: - """Get license plate manager statistics.""" - return { - 'cached_results': len(self.license_plate_cache), - 'connected': self.redis_client is not None, - 'subscribed': self.pubsub is not None, - 'host': self.host, - 'port': self.port - } \ No newline at end of file diff --git a/core/storage/redis.py b/core/storage/redis.py deleted file mode 100644 index 6672a1b..0000000 --- a/core/storage/redis.py +++ /dev/null @@ -1,478 +0,0 @@ -""" -Redis Operations Module. -Handles Redis connections, image storage, and pub/sub messaging. -""" -import logging -import json -import time -from typing import Optional, Dict, Any, Union -import asyncio -import cv2 -import numpy as np -import redis.asyncio as redis -from redis.exceptions import ConnectionError, TimeoutError - -logger = logging.getLogger(__name__) - - -class RedisManager: - """ - Manages Redis connections and operations for the detection pipeline. - Handles image storage with region cropping and pub/sub messaging. - """ - - def __init__(self, redis_config: Dict[str, Any]): - """ - Initialize Redis manager with configuration. - - Args: - redis_config: Redis configuration dictionary - """ - self.config = redis_config - self.redis_client: Optional[redis.Redis] = None - - # Connection parameters - self.host = redis_config.get('host', 'localhost') - self.port = redis_config.get('port', 6379) - self.password = redis_config.get('password') - self.db = redis_config.get('db', 0) - self.decode_responses = redis_config.get('decode_responses', True) - - # Connection pool settings - self.max_connections = redis_config.get('max_connections', 10) - self.socket_timeout = redis_config.get('socket_timeout', 5) - self.socket_connect_timeout = redis_config.get('socket_connect_timeout', 5) - self.health_check_interval = redis_config.get('health_check_interval', 30) - - # Statistics - self.stats = { - 'images_stored': 0, - 'messages_published': 0, - 'connection_errors': 0, - 'operations_successful': 0, - 'operations_failed': 0 - } - - logger.info(f"RedisManager initialized for {self.host}:{self.port}") - - async def initialize(self) -> bool: - """ - Initialize Redis connection and test connectivity. 
- - Returns: - True if successful, False otherwise - """ - try: - # Validate configuration - if not self._validate_config(): - return False - - # Create Redis connection - self.redis_client = redis.Redis( - host=self.host, - port=self.port, - password=self.password, - db=self.db, - decode_responses=self.decode_responses, - max_connections=self.max_connections, - socket_timeout=self.socket_timeout, - socket_connect_timeout=self.socket_connect_timeout, - health_check_interval=self.health_check_interval - ) - - # Test connection - await self.redis_client.ping() - logger.info(f"Successfully connected to Redis at {self.host}:{self.port}") - return True - - except ConnectionError as e: - logger.error(f"Failed to connect to Redis: {e}") - self.stats['connection_errors'] += 1 - return False - except Exception as e: - logger.error(f"Error initializing Redis connection: {e}", exc_info=True) - self.stats['connection_errors'] += 1 - return False - - def _validate_config(self) -> bool: - """ - Validate Redis configuration parameters. - - Returns: - True if valid, False otherwise - """ - required_fields = ['host', 'port'] - for field in required_fields: - if field not in self.config: - logger.error(f"Missing required Redis config field: {field}") - return False - - if not isinstance(self.port, int) or self.port <= 0: - logger.error(f"Invalid Redis port: {self.port}") - return False - - return True - - async def is_connected(self) -> bool: - """ - Check if Redis connection is active. - - Returns: - True if connected, False otherwise - """ - try: - if self.redis_client: - await self.redis_client.ping() - return True - except Exception: - pass - return False - - async def save_image(self, - key: str, - image: np.ndarray, - expire_seconds: Optional[int] = None, - image_format: str = 'jpeg', - quality: int = 90) -> bool: - """ - Save image to Redis with optional expiration. - - Args: - key: Redis key for the image - image: Image array to save - expire_seconds: Optional expiration time in seconds - image_format: Image format ('jpeg' or 'png') - quality: JPEG quality (1-100) - - Returns: - True if successful, False otherwise - """ - try: - if not self.redis_client: - logger.error("Redis client not initialized") - self.stats['operations_failed'] += 1 - return False - - # Encode image - encoded_image = self._encode_image(image, image_format, quality) - if encoded_image is None: - logger.error("Failed to encode image") - self.stats['operations_failed'] += 1 - return False - - # Save to Redis - if expire_seconds: - await self.redis_client.setex(key, expire_seconds, encoded_image) - logger.debug(f"Saved image to Redis with key: {key} (expires in {expire_seconds}s)") - else: - await self.redis_client.set(key, encoded_image) - logger.debug(f"Saved image to Redis with key: {key}") - - self.stats['images_stored'] += 1 - self.stats['operations_successful'] += 1 - return True - - except Exception as e: - logger.error(f"Error saving image to Redis: {e}", exc_info=True) - self.stats['operations_failed'] += 1 - return False - - async def get_image(self, key: str) -> Optional[np.ndarray]: - """ - Retrieve image from Redis. 
- - Args: - key: Redis key for the image - - Returns: - Image array or None if not found - """ - try: - if not self.redis_client: - logger.error("Redis client not initialized") - self.stats['operations_failed'] += 1 - return None - - # Get image data from Redis - image_data = await self.redis_client.get(key) - if image_data is None: - logger.debug(f"Image not found for key: {key}") - return None - - # Decode image - image_array = np.frombuffer(image_data, np.uint8) - image = cv2.imdecode(image_array, cv2.IMREAD_COLOR) - - if image is not None: - logger.debug(f"Retrieved image from Redis with key: {key}") - self.stats['operations_successful'] += 1 - return image - else: - logger.error(f"Failed to decode image for key: {key}") - self.stats['operations_failed'] += 1 - return None - - except Exception as e: - logger.error(f"Error retrieving image from Redis: {e}", exc_info=True) - self.stats['operations_failed'] += 1 - return None - - async def delete_image(self, key: str) -> bool: - """ - Delete image from Redis. - - Args: - key: Redis key for the image - - Returns: - True if successful, False otherwise - """ - try: - if not self.redis_client: - logger.error("Redis client not initialized") - self.stats['operations_failed'] += 1 - return False - - result = await self.redis_client.delete(key) - if result > 0: - logger.debug(f"Deleted image from Redis with key: {key}") - self.stats['operations_successful'] += 1 - return True - else: - logger.debug(f"Image not found for deletion: {key}") - return False - - except Exception as e: - logger.error(f"Error deleting image from Redis: {e}", exc_info=True) - self.stats['operations_failed'] += 1 - return False - - async def publish_message(self, channel: str, message: Union[str, Dict]) -> int: - """ - Publish message to Redis channel. - - Args: - channel: Redis channel name - message: Message to publish (string or dict) - - Returns: - Number of subscribers that received the message, -1 on error - """ - try: - if not self.redis_client: - logger.error("Redis client not initialized") - self.stats['operations_failed'] += 1 - return -1 - - # Convert dict to JSON string if needed - if isinstance(message, dict): - message_str = json.dumps(message) - else: - message_str = str(message) - - # Test connection before publishing - await self.redis_client.ping() - - # Publish message - result = await self.redis_client.publish(channel, message_str) - - logger.info(f"Published message to Redis channel '{channel}': {message_str}") - logger.info(f"Redis publish result (subscribers count): {result}") - - if result == 0: - logger.warning(f"No subscribers listening to channel '{channel}'") - else: - logger.info(f"Message delivered to {result} subscriber(s)") - - self.stats['messages_published'] += 1 - self.stats['operations_successful'] += 1 - return result - - except Exception as e: - logger.error(f"Error publishing message to Redis: {e}", exc_info=True) - self.stats['operations_failed'] += 1 - return -1 - - async def subscribe_to_channel(self, channel: str, callback=None): - """ - Subscribe to Redis channel (for future use). 
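As a hedged sketch of the other side of this pub/sub flow: an upstream LPR service could publish a result that the license plate manager earlier in this diff consumes from the 'license_results' channel. The channel name and payload shape come from that handler; the host, port, and session values here are assumptions.

import asyncio
import json
import redis.asyncio as redis

async def publish_license_result() -> None:
    client = redis.Redis(host='localhost', port=6379)
    # "Format 1" accepted by the license plate handler above.
    payload = {"session_id": "511", "license_character": "ข3184"}
    receivers = await client.publish('license_results', json.dumps(payload))
    print(f"delivered to {receivers} subscriber(s)")
    await client.aclose()

# asyncio.run(publish_license_result())
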
- - Args: - channel: Redis channel name - callback: Optional callback function for messages - """ - try: - if not self.redis_client: - logger.error("Redis client not initialized") - return - - pubsub = self.redis_client.pubsub() - await pubsub.subscribe(channel) - - logger.info(f"Subscribed to Redis channel: {channel}") - - if callback: - async for message in pubsub.listen(): - if message['type'] == 'message': - try: - await callback(message['data']) - except Exception as e: - logger.error(f"Error in message callback: {e}") - - except Exception as e: - logger.error(f"Error subscribing to Redis channel: {e}", exc_info=True) - - async def set_key(self, key: str, value: Union[str, bytes], expire_seconds: Optional[int] = None) -> bool: - """ - Set a key-value pair in Redis. - - Args: - key: Redis key - value: Value to store - expire_seconds: Optional expiration time in seconds - - Returns: - True if successful, False otherwise - """ - try: - if not self.redis_client: - logger.error("Redis client not initialized") - self.stats['operations_failed'] += 1 - return False - - if expire_seconds: - await self.redis_client.setex(key, expire_seconds, value) - else: - await self.redis_client.set(key, value) - - logger.debug(f"Set Redis key: {key}") - self.stats['operations_successful'] += 1 - return True - - except Exception as e: - logger.error(f"Error setting Redis key: {e}", exc_info=True) - self.stats['operations_failed'] += 1 - return False - - async def get_key(self, key: str) -> Optional[Union[str, bytes]]: - """ - Get value for a Redis key. - - Args: - key: Redis key - - Returns: - Value or None if not found - """ - try: - if not self.redis_client: - logger.error("Redis client not initialized") - self.stats['operations_failed'] += 1 - return None - - value = await self.redis_client.get(key) - if value is not None: - logger.debug(f"Retrieved Redis key: {key}") - self.stats['operations_successful'] += 1 - - return value - - except Exception as e: - logger.error(f"Error getting Redis key: {e}", exc_info=True) - self.stats['operations_failed'] += 1 - return None - - async def delete_key(self, key: str) -> bool: - """ - Delete a Redis key. - - Args: - key: Redis key - - Returns: - True if successful, False otherwise - """ - try: - if not self.redis_client: - logger.error("Redis client not initialized") - self.stats['operations_failed'] += 1 - return False - - result = await self.redis_client.delete(key) - if result > 0: - logger.debug(f"Deleted Redis key: {key}") - self.stats['operations_successful'] += 1 - return True - else: - logger.debug(f"Redis key not found: {key}") - return False - - except Exception as e: - logger.error(f"Error deleting Redis key: {e}", exc_info=True) - self.stats['operations_failed'] += 1 - return False - - def _encode_image(self, image: np.ndarray, image_format: str, quality: int) -> Optional[bytes]: - """ - Encode image to bytes for Redis storage. 
- - Args: - image: Image array - image_format: Image format ('jpeg' or 'png') - quality: JPEG quality (1-100) - - Returns: - Encoded image bytes or None on error - """ - try: - format_lower = image_format.lower() - - if format_lower == 'jpeg' or format_lower == 'jpg': - encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality] - success, buffer = cv2.imencode('.jpg', image, encode_params) - elif format_lower == 'png': - success, buffer = cv2.imencode('.png', image) - else: - logger.warning(f"Unknown image format '{image_format}', using JPEG") - encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality] - success, buffer = cv2.imencode('.jpg', image, encode_params) - - if success: - return buffer.tobytes() - else: - logger.error(f"Failed to encode image as {image_format}") - return None - - except Exception as e: - logger.error(f"Error encoding image: {e}", exc_info=True) - return None - - def get_statistics(self) -> Dict[str, Any]: - """ - Get Redis manager statistics. - - Returns: - Dictionary with statistics - """ - return { - **self.stats, - 'connected': self.redis_client is not None, - 'host': self.host, - 'port': self.port, - 'db': self.db - } - - def cleanup(self): - """Cleanup Redis connection.""" - if self.redis_client: - # Note: redis.asyncio doesn't have a synchronous close method - # The connection will be closed when the event loop shuts down - self.redis_client = None - logger.info("Redis connection cleaned up") - - async def aclose(self): - """Async cleanup for Redis connection.""" - if self.redis_client: - await self.redis_client.aclose() - self.redis_client = None - logger.info("Redis connection closed") \ No newline at end of file diff --git a/core/streaming/__init__.py b/core/streaming/__init__.py deleted file mode 100644 index 93005ab..0000000 --- a/core/streaming/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -Streaming system for RTSP and HTTP camera feeds. -Provides modular frame readers, buffers, and stream management. -""" -from .readers import HTTPSnapshotReader, FFmpegRTSPReader -from .buffers import FrameBuffer, CacheBuffer, shared_frame_buffer, shared_cache_buffer -from .manager import StreamManager, StreamConfig, SubscriptionInfo, shared_stream_manager, initialize_stream_manager - -__all__ = [ - # Readers - 'HTTPSnapshotReader', - 'FFmpegRTSPReader', - - # Buffers - 'FrameBuffer', - 'CacheBuffer', - 'shared_frame_buffer', - 'shared_cache_buffer', - - # Manager - 'StreamManager', - 'StreamConfig', - 'SubscriptionInfo', - 'shared_stream_manager', - 'initialize_stream_manager' -] \ No newline at end of file diff --git a/core/streaming/buffers.py b/core/streaming/buffers.py deleted file mode 100644 index f2c5787..0000000 --- a/core/streaming/buffers.py +++ /dev/null @@ -1,295 +0,0 @@ -""" -Frame buffering and caching system optimized for different stream formats. -Supports 1280x720 RTSP streams and 2560x1440 HTTP snapshots. 
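Rough arithmetic behind the two formats this buffer handles: an uncompressed BGR frame is about 2.6 MB at 1280x720 and about 10.5 MB at 2560x1440, both well under the 50 MB per-frame cap enforced below. A quick check (plain arithmetic, not from the original):

import numpy as np

rtsp_frame = np.zeros((720, 1280, 3), dtype=np.uint8)
snapshot_frame = np.zeros((1440, 2560, 3), dtype=np.uint8)
print(rtsp_frame.nbytes / (1024 * 1024))       # ~2.64 MB
print(snapshot_frame.nbytes / (1024 * 1024))   # ~10.55 MB
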
-""" -import threading -import time -import cv2 -import logging -import numpy as np -from typing import Optional, Dict, Any, Tuple -from collections import defaultdict - - -logger = logging.getLogger(__name__) - - -class FrameBuffer: - """Thread-safe frame buffer for all camera streams.""" - - def __init__(self, max_age_seconds: int = 5): - self.max_age_seconds = max_age_seconds - self._frames: Dict[str, Dict[str, Any]] = {} - self._lock = threading.RLock() - - def put_frame(self, camera_id: str, frame: np.ndarray): - """Store a frame for the given camera ID.""" - with self._lock: - # Validate frame - if not self._validate_frame(frame): - logger.warning(f"Frame validation failed for camera {camera_id}") - return - - self._frames[camera_id] = { - 'frame': frame.copy(), - 'timestamp': time.time(), - 'shape': frame.shape, - 'dtype': str(frame.dtype), - 'size_mb': frame.nbytes / (1024 * 1024) - } - - def get_frame(self, camera_id: str) -> Optional[np.ndarray]: - """Get the latest frame for the given camera ID.""" - with self._lock: - if camera_id not in self._frames: - return None - - frame_data = self._frames[camera_id] - - # Return frame regardless of age - frames persist until replaced - return frame_data['frame'].copy() - - def get_frame_info(self, camera_id: str) -> Optional[Dict[str, Any]]: - """Get frame metadata without copying the frame data.""" - with self._lock: - if camera_id not in self._frames: - return None - - frame_data = self._frames[camera_id] - age = time.time() - frame_data['timestamp'] - - # Return frame info regardless of age - frames persist until replaced - return { - 'timestamp': frame_data['timestamp'], - 'age': age, - 'shape': frame_data['shape'], - 'dtype': frame_data['dtype'], - 'size_mb': frame_data.get('size_mb', 0) - } - - def has_frame(self, camera_id: str) -> bool: - """Check if a valid frame exists for the camera.""" - return self.get_frame_info(camera_id) is not None - - def clear_camera(self, camera_id: str): - """Remove all frames for a specific camera.""" - with self._lock: - if camera_id in self._frames: - del self._frames[camera_id] - logger.debug(f"Cleared frames for camera {camera_id}") - - def clear_all(self): - """Clear all stored frames.""" - with self._lock: - count = len(self._frames) - self._frames.clear() - logger.debug(f"Cleared all frames ({count} cameras)") - - def get_camera_list(self) -> list: - """Get list of cameras with frames - all frames persist until replaced.""" - with self._lock: - # Return all cameras that have frames - no age-based filtering - return list(self._frames.keys()) - - def get_stats(self) -> Dict[str, Any]: - """Get buffer statistics.""" - with self._lock: - current_time = time.time() - stats = { - 'total_cameras': len(self._frames), - 'recent_cameras': 0, - 'stale_cameras': 0, - 'total_memory_mb': 0, - 'cameras': {} - } - - for camera_id, frame_data in self._frames.items(): - age = current_time - frame_data['timestamp'] - size_mb = frame_data.get('size_mb', 0) - - # All frames are valid/available, but categorize by freshness for monitoring - if age <= self.max_age_seconds: - stats['recent_cameras'] += 1 - else: - stats['stale_cameras'] += 1 - - stats['total_memory_mb'] += size_mb - - stats['cameras'][camera_id] = { - 'age': age, - 'recent': age <= self.max_age_seconds, # Recent but all frames available - 'shape': frame_data['shape'], - 'dtype': frame_data['dtype'], - 'size_mb': size_mb - } - - return stats - - def _validate_frame(self, frame: np.ndarray) -> bool: - """Validate frame - basic validation for any stream 
type.""" - if frame is None or frame.size == 0: - return False - - h, w = frame.shape[:2] - size_mb = frame.nbytes / (1024 * 1024) - - # Basic size validation - reject extremely large frames regardless of type - max_size_mb = 50 # Generous limit for any frame type - if size_mb > max_size_mb: - logger.warning(f"Frame too large: {size_mb:.2f}MB (max {max_size_mb}MB) for {w}x{h}") - return False - - # Basic dimension validation - if w < 100 or h < 100: - logger.warning(f"Frame too small: {w}x{h}") - return False - - return True - - -class CacheBuffer: - """Enhanced frame cache with support for cropping.""" - - def __init__(self, max_age_seconds: int = 10): - self.frame_buffer = FrameBuffer(max_age_seconds) - self._crop_cache: Dict[str, Dict[str, Any]] = {} - self._cache_lock = threading.RLock() - self.jpeg_quality = 95 # High quality for all frames - - def put_frame(self, camera_id: str, frame: np.ndarray): - """Store a frame and clear any associated crop cache.""" - self.frame_buffer.put_frame(camera_id, frame) - - # Clear crop cache for this camera since we have a new frame - with self._cache_lock: - keys_to_remove = [key for key in self._crop_cache.keys() if key.startswith(f"{camera_id}_")] - for key in keys_to_remove: - del self._crop_cache[key] - - def get_frame(self, camera_id: str, crop_coords: Optional[Tuple[int, int, int, int]] = None) -> Optional[np.ndarray]: - """Get frame with optional cropping.""" - if crop_coords is None: - return self.frame_buffer.get_frame(camera_id) - - # Check crop cache first - crop_key = f"{camera_id}_{crop_coords}" - with self._cache_lock: - if crop_key in self._crop_cache: - cache_entry = self._crop_cache[crop_key] - age = time.time() - cache_entry['timestamp'] - if age <= self.frame_buffer.max_age_seconds: - return cache_entry['cropped_frame'].copy() - else: - del self._crop_cache[crop_key] - - # Get original frame and crop it - original_frame = self.frame_buffer.get_frame(camera_id) - if original_frame is None: - return None - - try: - x1, y1, x2, y2 = crop_coords - - # Ensure coordinates are within frame bounds - h, w = original_frame.shape[:2] - x1 = max(0, min(x1, w)) - y1 = max(0, min(y1, h)) - x2 = max(x1, min(x2, w)) - y2 = max(y1, min(y2, h)) - - cropped_frame = original_frame[y1:y2, x1:x2] - - # Cache the cropped frame - with self._cache_lock: - # Limit cache size to prevent memory issues - if len(self._crop_cache) > 100: - # Remove oldest entries - oldest_keys = sorted(self._crop_cache.keys(), - key=lambda k: self._crop_cache[k]['timestamp'])[:50] - for key in oldest_keys: - del self._crop_cache[key] - - self._crop_cache[crop_key] = { - 'cropped_frame': cropped_frame.copy(), - 'timestamp': time.time(), - 'crop_coords': (x1, y1, x2, y2) - } - - return cropped_frame - - except Exception as e: - logger.error(f"Error cropping frame for camera {camera_id}: {e}") - return original_frame - - def get_frame_as_jpeg(self, camera_id: str, crop_coords: Optional[Tuple[int, int, int, int]] = None, - quality: Optional[int] = None) -> Optional[bytes]: - """Get frame as JPEG bytes.""" - frame = self.get_frame(camera_id, crop_coords) - if frame is None: - return None - - try: - # Use specified quality or default - if quality is None: - quality = self.jpeg_quality - - # Encode as JPEG with specified quality - encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality] - success, encoded_img = cv2.imencode('.jpg', frame, encode_params) - - if success: - jpeg_bytes = encoded_img.tobytes() - logger.debug(f"Encoded JPEG for camera {camera_id}: quality={quality}, 
size={len(jpeg_bytes)} bytes") - return jpeg_bytes - - return None - - except Exception as e: - logger.error(f"Error encoding frame as JPEG for camera {camera_id}: {e}") - return None - - def has_frame(self, camera_id: str) -> bool: - """Check if a valid frame exists for the camera.""" - return self.frame_buffer.has_frame(camera_id) - - def clear_camera(self, camera_id: str): - """Remove all frames and cache for a specific camera.""" - self.frame_buffer.clear_camera(camera_id) - with self._cache_lock: - # Clear crop cache entries for this camera - keys_to_remove = [key for key in self._crop_cache.keys() if key.startswith(f"{camera_id}_")] - for key in keys_to_remove: - del self._crop_cache[key] - - def clear_all(self): - """Clear all stored frames and cache.""" - self.frame_buffer.clear_all() - with self._cache_lock: - self._crop_cache.clear() - - def get_stats(self) -> Dict[str, Any]: - """Get comprehensive buffer and cache statistics.""" - buffer_stats = self.frame_buffer.get_stats() - - with self._cache_lock: - cache_stats = { - 'crop_cache_entries': len(self._crop_cache), - 'crop_cache_cameras': len(set(key.split('_')[0] for key in self._crop_cache.keys() if '_' in key)), - 'crop_cache_memory_mb': sum( - entry['cropped_frame'].nbytes / (1024 * 1024) - for entry in self._crop_cache.values() - ) - } - - return { - 'buffer': buffer_stats, - 'cache': cache_stats, - 'total_memory_mb': buffer_stats.get('total_memory_mb', 0) + cache_stats.get('crop_cache_memory_mb', 0) - } - - -# Global shared instances for application use -shared_frame_buffer = FrameBuffer(max_age_seconds=5) -shared_cache_buffer = CacheBuffer(max_age_seconds=10) - - diff --git a/core/streaming/manager.py b/core/streaming/manager.py deleted file mode 100644 index c4ebd77..0000000 --- a/core/streaming/manager.py +++ /dev/null @@ -1,722 +0,0 @@ -""" -Stream coordination and lifecycle management. -Optimized for 1280x720@6fps RTSP and 2560x1440 HTTP snapshots. 
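A minimal usage sketch for the shared cache buffer above, assuming the pre-removal core.streaming.buffers import path; the camera id and crop rectangle are illustrative.

from core.streaming.buffers import shared_cache_buffer

camera_id = 'cam-001'
crop = (100, 100, 740, 580)   # (x1, y1, x2, y2)

frame = shared_cache_buffer.get_frame(camera_id, crop_coords=crop)
if frame is not None:
    jpeg_bytes = shared_cache_buffer.get_frame_as_jpeg(camera_id, crop_coords=crop, quality=90)
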
-""" -import logging -import threading -import time -import queue -import asyncio -from typing import Dict, Set, Optional, List, Any -from dataclasses import dataclass -from collections import defaultdict - -from .readers import HTTPSnapshotReader, FFmpegRTSPReader -from .buffers import shared_cache_buffer -from ..tracking.integration import TrackingPipelineIntegration - - -logger = logging.getLogger(__name__) - - -@dataclass -class StreamConfig: - """Configuration for a stream.""" - camera_id: str - rtsp_url: Optional[str] = None - snapshot_url: Optional[str] = None - snapshot_interval: int = 5000 # milliseconds - max_retries: int = 3 - - -@dataclass -class SubscriptionInfo: - """Information about a subscription.""" - subscription_id: str - camera_id: str - stream_config: StreamConfig - created_at: float - crop_coords: Optional[tuple] = None - model_id: Optional[str] = None - model_url: Optional[str] = None - tracking_integration: Optional[TrackingPipelineIntegration] = None - - -class StreamManager: - """Manages multiple camera streams with shared optimization.""" - - def __init__(self, max_streams: int = 10): - self.max_streams = max_streams - self._streams: Dict[str, Any] = {} # camera_id -> reader instance - self._subscriptions: Dict[str, SubscriptionInfo] = {} # subscription_id -> info - self._camera_subscribers: Dict[str, Set[str]] = defaultdict(set) # camera_id -> set of subscription_ids - self._lock = threading.RLock() - - # Fair tracking queue system - per camera queues - self._tracking_queues: Dict[str, queue.Queue] = {} # camera_id -> queue - self._tracking_workers = [] - self._stop_workers = threading.Event() - self._dropped_frame_counts: Dict[str, int] = {} # per-camera drop counts - - # Round-robin scheduling state - self._camera_list = [] # Ordered list of active cameras - self._camera_round_robin_index = 0 - self._round_robin_lock = threading.Lock() - - # Start worker threads for tracking processing - num_workers = min(4, max_streams // 2 + 1) # Scale with streams - for i in range(num_workers): - worker = threading.Thread( - target=self._tracking_worker_loop, - name=f"TrackingWorker-{i}", - daemon=True - ) - worker.start() - self._tracking_workers.append(worker) - - logger.info(f"Started {num_workers} tracking worker threads") - - def _ensure_camera_queue(self, camera_id: str): - """Ensure a tracking queue exists for the camera.""" - if camera_id not in self._tracking_queues: - self._tracking_queues[camera_id] = queue.Queue(maxsize=10) # 10 frames per camera - self._dropped_frame_counts[camera_id] = 0 - - with self._round_robin_lock: - if camera_id not in self._camera_list: - self._camera_list.append(camera_id) - logger.info(f"Created tracking queue for camera {camera_id}") - else: - logger.debug(f"Camera {camera_id} already has tracking queue") - - def _remove_camera_queue(self, camera_id: str): - """Remove tracking queue for a camera that's no longer active.""" - if camera_id in self._tracking_queues: - # Clear any remaining items - while not self._tracking_queues[camera_id].empty(): - try: - self._tracking_queues[camera_id].get_nowait() - except queue.Empty: - break - - del self._tracking_queues[camera_id] - del self._dropped_frame_counts[camera_id] - - with self._round_robin_lock: - if camera_id in self._camera_list: - self._camera_list.remove(camera_id) - # Reset index if needed - if self._camera_round_robin_index >= len(self._camera_list): - self._camera_round_robin_index = 0 - - logger.info(f"Removed tracking queue for camera {camera_id}") - - def 
add_subscription(self, subscription_id: str, stream_config: StreamConfig, - crop_coords: Optional[tuple] = None, - model_id: Optional[str] = None, - model_url: Optional[str] = None, - tracking_integration: Optional[TrackingPipelineIntegration] = None) -> bool: - """Add a new subscription. Returns True if successful.""" - with self._lock: - if subscription_id in self._subscriptions: - logger.warning(f"Subscription {subscription_id} already exists") - return False - - camera_id = stream_config.camera_id - - # Create subscription info - subscription_info = SubscriptionInfo( - subscription_id=subscription_id, - camera_id=camera_id, - stream_config=stream_config, - created_at=time.time(), - crop_coords=crop_coords, - model_id=model_id, - model_url=model_url, - tracking_integration=tracking_integration - ) - - # Pass subscription info to tracking integration for snapshot access - if tracking_integration: - tracking_integration.set_subscription_info(subscription_info) - - self._subscriptions[subscription_id] = subscription_info - self._camera_subscribers[camera_id].add(subscription_id) - - # Start stream if not already running - if camera_id not in self._streams: - if len(self._streams) >= self.max_streams: - logger.error(f"Maximum streams ({self.max_streams}) reached, cannot add {camera_id}") - self._remove_subscription_internal(subscription_id) - return False - - success = self._start_stream(camera_id, stream_config) - if not success: - self._remove_subscription_internal(subscription_id) - return False - else: - # Stream already exists, but ensure queue exists too - logger.info(f"Stream already exists for {camera_id}, ensuring queue exists") - self._ensure_camera_queue(camera_id) - - logger.info(f"Added subscription {subscription_id} for camera {camera_id} " - f"({len(self._camera_subscribers[camera_id])} total subscribers)") - return True - - def remove_subscription(self, subscription_id: str) -> bool: - """Remove a subscription. 
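A sketch of driving add_subscription directly, assuming the pre-removal core.streaming.manager module; URLs and ids are placeholders. Subscription identifiers follow the 'displayId;cameraId' convention used elsewhere in this module.

from core.streaming.manager import StreamConfig, shared_stream_manager

config = StreamConfig(
    camera_id='cam-001',
    rtsp_url='rtsp://camera.example.local/stream1',
    max_retries=3,
)
ok = shared_stream_manager.add_subscription('display-001;cam-001', config)
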
Returns True if found and removed.""" - with self._lock: - return self._remove_subscription_internal(subscription_id) - - def _remove_subscription_internal(self, subscription_id: str) -> bool: - """Internal method to remove subscription (assumes lock is held).""" - if subscription_id not in self._subscriptions: - logger.warning(f"Subscription {subscription_id} not found") - return False - - subscription_info = self._subscriptions[subscription_id] - camera_id = subscription_info.camera_id - - # Remove from tracking - del self._subscriptions[subscription_id] - self._camera_subscribers[camera_id].discard(subscription_id) - - # Stop stream if no more subscribers - if not self._camera_subscribers[camera_id]: - self._stop_stream(camera_id) - del self._camera_subscribers[camera_id] - - logger.info(f"Removed subscription {subscription_id} for camera {camera_id} " - f"({len(self._camera_subscribers[camera_id])} remaining subscribers)") - return True - - def _start_stream(self, camera_id: str, stream_config: StreamConfig) -> bool: - """Start a stream for the given camera.""" - try: - if stream_config.rtsp_url: - # RTSP stream using FFmpeg subprocess with CUDA acceleration - logger.info(f"\033[94m[RTSP] Starting {camera_id}\033[0m") - reader = FFmpegRTSPReader( - camera_id=camera_id, - rtsp_url=stream_config.rtsp_url, - max_retries=stream_config.max_retries - ) - reader.set_frame_callback(self._frame_callback) - reader.start() - self._streams[camera_id] = reader - self._ensure_camera_queue(camera_id) # Create tracking queue - logger.info(f"\033[92m[RTSP] {camera_id} connected\033[0m") - - elif stream_config.snapshot_url: - # HTTP snapshot stream - logger.info(f"\033[95m[HTTP] Starting {camera_id}\033[0m") - reader = HTTPSnapshotReader( - camera_id=camera_id, - snapshot_url=stream_config.snapshot_url, - interval_ms=stream_config.snapshot_interval, - max_retries=stream_config.max_retries - ) - reader.set_frame_callback(self._frame_callback) - reader.start() - self._streams[camera_id] = reader - self._ensure_camera_queue(camera_id) # Create tracking queue - logger.info(f"\033[92m[HTTP] {camera_id} connected\033[0m") - - else: - logger.error(f"No valid URL provided for camera {camera_id}") - return False - - return True - - except Exception as e: - logger.error(f"Error starting stream for camera {camera_id}: {e}") - return False - - def _stop_stream(self, camera_id: str): - """Stop a stream for the given camera.""" - if camera_id in self._streams: - try: - self._streams[camera_id].stop() - del self._streams[camera_id] - self._remove_camera_queue(camera_id) # Remove tracking queue - # DON'T clear frames - they should persist until replaced - # shared_cache_buffer.clear_camera(camera_id) # REMOVED - frames should persist - logger.info(f"Stopped stream for camera {camera_id} (frames preserved in buffer)") - except Exception as e: - logger.error(f"Error stopping stream for camera {camera_id}: {e}") - - def _frame_callback(self, camera_id: str, frame): - """Callback for when a new frame is available.""" - try: - # Store frame in shared buffer - shared_cache_buffer.put_frame(camera_id, frame) - # Quieter frame callback logging - only log occasionally - if hasattr(self, '_frame_log_count'): - self._frame_log_count += 1 - else: - self._frame_log_count = 1 - - # Log every 100 frames to avoid spam - if self._frame_log_count % 100 == 0: - available_cameras = shared_cache_buffer.frame_buffer.get_camera_list() - logger.info(f"\033[96m[BUFFER] {len(available_cameras)} active cameras: {', 
'.join(available_cameras)}\033[0m") - - # Queue for tracking processing (non-blocking) - route to camera-specific queue - if camera_id in self._tracking_queues: - try: - self._tracking_queues[camera_id].put_nowait({ - 'frame': frame, - 'timestamp': time.time() - }) - except queue.Full: - # Drop frame if camera queue is full (maintain real-time) - self._dropped_frame_counts[camera_id] += 1 - - if self._dropped_frame_counts[camera_id] % 50 == 0: - logger.warning(f"Dropped {self._dropped_frame_counts[camera_id]} frames for camera {camera_id} due to full queue") - - except Exception as e: - logger.error(f"Error in frame callback for camera {camera_id}: {e}") - - def _process_tracking_for_camera(self, camera_id: str, frame): - """Process tracking for all subscriptions of a camera.""" - try: - with self._lock: - for subscription_id in self._camera_subscribers[camera_id]: - subscription_info = self._subscriptions[subscription_id] - - # Skip if no tracking integration - if not subscription_info.tracking_integration: - continue - - # Extract display_id from subscription_id - display_id = subscription_id.split(';')[0] if ';' in subscription_id else subscription_id - - # Process frame through tracking asynchronously - # Note: This is synchronous for now, can be made async in future - try: - # Create a simple asyncio event loop for this frame - import asyncio - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete( - subscription_info.tracking_integration.process_frame( - frame, display_id, subscription_id - ) - ) - # Log tracking results - if result: - tracked_count = len(result.get('tracked_vehicles', [])) - validated_vehicle = result.get('validated_vehicle') - pipeline_result = result.get('pipeline_result') - - if tracked_count > 0: - logger.info(f"[Tracking] {camera_id}: {tracked_count} vehicles tracked") - - if validated_vehicle: - logger.info(f"[Tracking] {camera_id}: Vehicle {validated_vehicle['track_id']} " - f"validated as {validated_vehicle['state']} " - f"(confidence: {validated_vehicle['confidence']:.2f})") - - if pipeline_result: - logger.info(f"[Pipeline] {camera_id}: {pipeline_result.get('status', 'unknown')} - " - f"{pipeline_result.get('message', 'no message')}") - finally: - loop.close() - except Exception as track_e: - logger.error(f"Error in tracking for {subscription_id}: {track_e}") - - except Exception as e: - logger.error(f"Error processing tracking for camera {camera_id}: {e}") - - def _tracking_worker_loop(self): - """Worker thread loop for round-robin processing of camera queues.""" - logger.info(f"Tracking worker {threading.current_thread().name} started") - - consecutive_empty = 0 - max_consecutive_empty = 10 # Sleep if all cameras empty this many times - - while not self._stop_workers.is_set(): - try: - # Get next camera in round-robin fashion - camera_id, item = self._get_next_camera_item() - - if camera_id is None: - # No cameras have items, sleep briefly - consecutive_empty += 1 - if consecutive_empty >= max_consecutive_empty: - time.sleep(0.1) # Sleep 100ms if nothing to process - consecutive_empty = 0 - continue - - consecutive_empty = 0 # Reset counter when we find work - - frame = item['frame'] - timestamp = item['timestamp'] - - # Check if frame is too old (drop if > 1 second old) - age = time.time() - timestamp - if age > 1.0: - logger.debug(f"Dropping old frame for {camera_id} (age: {age:.2f}s)") - continue - - # Process tracking for this camera's frame - self._process_tracking_for_camera_sync(camera_id, frame) 
- - except Exception as e: - logger.error(f"Error in tracking worker: {e}", exc_info=True) - - logger.info(f"Tracking worker {threading.current_thread().name} stopped") - - def _get_next_camera_item(self): - """Get next item from camera queues using round-robin scheduling.""" - with self._round_robin_lock: - # Get current list of cameras from actual tracking queues (central state) - camera_list = list(self._tracking_queues.keys()) - - if not camera_list: - return None, None - - attempts = 0 - max_attempts = len(camera_list) - - while attempts < max_attempts: - # Get current camera using round-robin index - if self._camera_round_robin_index >= len(camera_list): - self._camera_round_robin_index = 0 - - camera_id = camera_list[self._camera_round_robin_index] - - # Move to next camera for next call - self._camera_round_robin_index = (self._camera_round_robin_index + 1) % len(camera_list) - - # Try to get item from this camera's queue - try: - item = self._tracking_queues[camera_id].get_nowait() - return camera_id, item - except queue.Empty: - pass # Try next camera - - attempts += 1 - - return None, None # All cameras empty - - def _process_tracking_for_camera_sync(self, camera_id: str, frame): - """Synchronous version of tracking processing for worker threads.""" - try: - with self._lock: - subscription_ids = list(self._camera_subscribers.get(camera_id, [])) - - for subscription_id in subscription_ids: - subscription_info = self._subscriptions.get(subscription_id) - - if not subscription_info: - logger.warning(f"No subscription info found for {subscription_id}") - continue - - if not subscription_info.tracking_integration: - logger.debug(f"No tracking integration for {subscription_id} (camera {camera_id}), skipping inference") - continue - - display_id = subscription_id.split(';')[0] if ';' in subscription_id else subscription_id - - try: - # Run async tracking in thread's event loop - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete( - subscription_info.tracking_integration.process_frame( - frame, display_id, subscription_id - ) - ) - - # Log tracking results - if result: - tracked_count = len(result.get('tracked_vehicles', [])) - validated_vehicle = result.get('validated_vehicle') - pipeline_result = result.get('pipeline_result') - - if tracked_count > 0: - logger.info(f"[Tracking] {camera_id}: {tracked_count} vehicles tracked") - - if validated_vehicle: - logger.info(f"[Tracking] {camera_id}: Vehicle {validated_vehicle['track_id']} " - f"validated as {validated_vehicle['state']} " - f"(confidence: {validated_vehicle['confidence']:.2f})") - - if pipeline_result: - logger.info(f"[Pipeline] {camera_id}: {pipeline_result.get('status', 'unknown')} - " - f"{pipeline_result.get('message', 'no message')}") - finally: - loop.close() - - except Exception as track_e: - logger.error(f"Error in tracking for {subscription_id}: {track_e}") - - except Exception as e: - logger.error(f"Error processing tracking for camera {camera_id}: {e}") - - def get_frame(self, camera_id: str, crop_coords: Optional[tuple] = None): - """Get the latest frame for a camera with optional cropping.""" - return shared_cache_buffer.get_frame(camera_id, crop_coords) - - def get_frame_as_jpeg(self, camera_id: str, crop_coords: Optional[tuple] = None, - quality: int = 100) -> Optional[bytes]: - """Get frame as JPEG bytes for HTTP responses with highest quality by default.""" - return shared_cache_buffer.get_frame_as_jpeg(camera_id, crop_coords, quality) - - def has_frame(self, 
camera_id: str) -> bool: - """Check if a frame is available for the camera.""" - return shared_cache_buffer.has_frame(camera_id) - - def get_subscription_info(self, subscription_id: str) -> Optional[SubscriptionInfo]: - """Get information about a subscription.""" - with self._lock: - return self._subscriptions.get(subscription_id) - - def get_camera_subscribers(self, camera_id: str) -> Set[str]: - """Get all subscription IDs for a camera.""" - with self._lock: - return self._camera_subscribers[camera_id].copy() - - def get_active_cameras(self) -> List[str]: - """Get list of cameras with active streams.""" - with self._lock: - return list(self._streams.keys()) - - def get_all_subscriptions(self) -> List[SubscriptionInfo]: - """Get all active subscriptions.""" - with self._lock: - return list(self._subscriptions.values()) - - def reconcile_subscriptions(self, target_subscriptions: List[Dict[str, Any]]) -> Dict[str, Any]: - """ - Reconcile current subscriptions with target list. - Returns summary of changes made. - """ - with self._lock: - current_subscription_ids = set(self._subscriptions.keys()) - target_subscription_ids = {sub['subscriptionIdentifier'] for sub in target_subscriptions} - - # Find subscriptions to remove and add - to_remove = current_subscription_ids - target_subscription_ids - to_add = target_subscription_ids - current_subscription_ids - - # Remove old subscriptions - removed_count = 0 - for subscription_id in to_remove: - if self._remove_subscription_internal(subscription_id): - removed_count += 1 - - # Add new subscriptions - added_count = 0 - failed_count = 0 - for target_sub in target_subscriptions: - subscription_id = target_sub['subscriptionIdentifier'] - if subscription_id in to_add: - success = self._add_subscription_from_payload(subscription_id, target_sub) - if success: - added_count += 1 - else: - failed_count += 1 - - result = { - 'removed': removed_count, - 'added': added_count, - 'failed': failed_count, - 'total_active': len(self._subscriptions), - 'active_streams': len(self._streams) - } - - logger.info(f"Subscription reconciliation: {result}") - return result - - def _add_subscription_from_payload(self, subscription_id: str, payload: Dict[str, Any]) -> bool: - """Add subscription from WebSocket payload format.""" - try: - # Extract camera ID from subscription identifier - # Format: "display-001;cam-001" -> camera_id = "cam-001" - camera_id = subscription_id.split(';')[-1] - - # Extract crop coordinates if present - crop_coords = None - if all(key in payload for key in ['cropX1', 'cropY1', 'cropX2', 'cropY2']): - crop_coords = ( - payload['cropX1'], - payload['cropY1'], - payload['cropX2'], - payload['cropY2'] - ) - - # Create stream configuration - stream_config = StreamConfig( - camera_id=camera_id, - rtsp_url=payload.get('rtspUrl'), - snapshot_url=payload.get('snapshotUrl'), - snapshot_interval=payload.get('snapshotInterval', 5000), - max_retries=3, - ) - - return self.add_subscription( - subscription_id, - stream_config, - crop_coords, - model_id=payload.get('modelId'), - model_url=payload.get('modelUrl') - ) - - except Exception as e: - logger.error(f"Error adding subscription from payload {subscription_id}: {e}") - return False - - def stop_all(self): - """Stop all streams and clear all subscriptions.""" - # Signal workers to stop - self._stop_workers.set() - - # Clear all camera queues - for camera_id, camera_queue in list(self._tracking_queues.items()): - while not camera_queue.empty(): - try: - camera_queue.get_nowait() - except queue.Empty: - break 
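For reference, the reconciliation path above accepts the WebSocket payload format parsed by _add_subscription_from_payload; a hedged sketch with placeholder values (import path assumed pre-removal):

from core.streaming.manager import shared_stream_manager

target = [{
    'subscriptionIdentifier': 'display-001;cam-001',
    'rtspUrl': 'rtsp://camera.example.local/stream1',
    'snapshotInterval': 5000,
    'cropX1': 100, 'cropY1': 100, 'cropX2': 740, 'cropY2': 580,
    'modelId': 'model-1',
    'modelUrl': 'https://models.example.local/detector.pt',
}]
summary = shared_stream_manager.reconcile_subscriptions(target)
# summary -> {'removed': ..., 'added': ..., 'failed': ..., 'total_active': ..., 'active_streams': ...}
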
- - # Wait for workers to finish - for worker in self._tracking_workers: - worker.join(timeout=2.0) - - # Clear queue management structures - self._tracking_queues.clear() - self._dropped_frame_counts.clear() - with self._round_robin_lock: - self._camera_list.clear() - self._camera_round_robin_index = 0 - - logger.info("Stopped all tracking worker threads") - - with self._lock: - # Stop all streams - for camera_id in list(self._streams.keys()): - self._stop_stream(camera_id) - - # Clear all tracking - self._subscriptions.clear() - self._camera_subscribers.clear() - shared_cache_buffer.clear_all() - - logger.info("Stopped all streams and cleared all subscriptions") - - def set_session_id(self, display_id: str, session_id: str): - """Set session ID for tracking integration.""" - # Ensure session_id is always a string for consistent type handling - session_id = str(session_id) if session_id is not None else None - with self._lock: - for subscription_info in self._subscriptions.values(): - # Check if this subscription matches the display_id - subscription_display_id = subscription_info.subscription_id.split(';')[0] - if subscription_display_id == display_id and subscription_info.tracking_integration: - # Pass the full subscription_id (displayId;cameraId) to the tracking integration - subscription_info.tracking_integration.set_session_id( - display_id, - session_id, - subscription_id=subscription_info.subscription_id - ) - logger.debug(f"Set session {session_id} for display {display_id} with subscription {subscription_info.subscription_id}") - - def clear_session_id(self, session_id: str): - """Clear session ID from the specific tracking integration handling this session.""" - with self._lock: - # Find the subscription that's handling this session - session_subscription = None - for subscription_info in self._subscriptions.values(): - if subscription_info.tracking_integration: - # Check if this integration is handling the given session_id - integration = subscription_info.tracking_integration - if session_id in integration.session_vehicles: - session_subscription = subscription_info - break - - if session_subscription and session_subscription.tracking_integration: - session_subscription.tracking_integration.clear_session_id(session_id) - logger.debug(f"Cleared session {session_id} from subscription {session_subscription.subscription_id}") - else: - logger.warning(f"No tracking integration found for session {session_id}, broadcasting to all subscriptions") - # Fallback: broadcast to all (original behavior) - for subscription_info in self._subscriptions.values(): - if subscription_info.tracking_integration: - subscription_info.tracking_integration.clear_session_id(session_id) - - def set_progression_stage(self, session_id: str, stage: str): - """Set progression stage for the specific tracking integration handling this session.""" - with self._lock: - # Find the subscription that's handling this session - session_subscription = None - for subscription_info in self._subscriptions.values(): - if subscription_info.tracking_integration: - # Check if this integration is handling the given session_id - # We need to check the integration's active sessions - integration = subscription_info.tracking_integration - if session_id in integration.session_vehicles: - session_subscription = subscription_info - break - - if session_subscription and session_subscription.tracking_integration: - session_subscription.tracking_integration.set_progression_stage(session_id, stage) - logger.debug(f"Set progression stage 
for session {session_id}: {stage} on subscription {session_subscription.subscription_id}") - else: - logger.warning(f"No tracking integration found for session {session_id}, broadcasting to all subscriptions") - # Fallback: broadcast to all (original behavior) - for subscription_info in self._subscriptions.values(): - if subscription_info.tracking_integration: - subscription_info.tracking_integration.set_progression_stage(session_id, stage) - - def get_tracking_stats(self) -> Dict[str, Any]: - """Get tracking statistics from all subscriptions.""" - stats = {} - with self._lock: - for subscription_id, subscription_info in self._subscriptions.items(): - if subscription_info.tracking_integration: - stats[subscription_id] = subscription_info.tracking_integration.get_statistics() - return stats - - - def get_stats(self) -> Dict[str, Any]: - """Get comprehensive streaming statistics.""" - with self._lock: - buffer_stats = shared_cache_buffer.get_stats() - tracking_stats = self.get_tracking_stats() - - return { - 'active_subscriptions': len(self._subscriptions), - 'active_streams': len(self._streams), - 'cameras_with_subscribers': len(self._camera_subscribers), - 'max_streams': self.max_streams, - 'subscriptions_by_camera': { - camera_id: len(subscribers) - for camera_id, subscribers in self._camera_subscribers.items() - }, - 'buffer_stats': buffer_stats, - 'tracking_stats': tracking_stats, - 'memory_usage_mb': buffer_stats.get('total_memory_mb', 0) - } - - -# Global shared instance for application use -# Default initialization, will be updated with config value in app.py -shared_stream_manager = StreamManager(max_streams=20) - -def initialize_stream_manager(max_streams: int = 10): - """Re-initialize the global stream manager with config value.""" - global shared_stream_manager - # Release old manager if exists - if shared_stream_manager: - try: - # Stop all existing streams gracefully - shared_stream_manager.stop_all() - except Exception as e: - logger.warning(f"Error stopping previous stream manager: {e}") - shared_stream_manager = StreamManager(max_streams=max_streams) - return shared_stream_manager \ No newline at end of file diff --git a/core/streaming/readers/__init__.py b/core/streaming/readers/__init__.py deleted file mode 100644 index 0903d6d..0000000 --- a/core/streaming/readers/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -Stream readers for RTSP and HTTP camera feeds. -""" -from .base import VideoReader -from .ffmpeg_rtsp import FFmpegRTSPReader -from .http_snapshot import HTTPSnapshotReader -from .utils import log_success, log_warning, log_error, log_info, Colors - -__all__ = [ - 'VideoReader', - 'FFmpegRTSPReader', - 'HTTPSnapshotReader', - 'log_success', - 'log_warning', - 'log_error', - 'log_info', - 'Colors' -] \ No newline at end of file diff --git a/core/streaming/readers/base.py b/core/streaming/readers/base.py deleted file mode 100644 index 56c41cb..0000000 --- a/core/streaming/readers/base.py +++ /dev/null @@ -1,65 +0,0 @@ -""" -Abstract base class for video stream readers. -""" -from abc import ABC, abstractmethod -from typing import Optional, Callable -import numpy as np - - -class VideoReader(ABC): - """Abstract base class for video stream readers.""" - - def __init__(self, camera_id: str, source_url: str, max_retries: int = 3): - """ - Initialize the video reader. 
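A small sketch of the re-initialization hook exported from the streaming package above (import path assumed pre-removal); typically called once at application startup with a configured stream limit.

from core.streaming import initialize_stream_manager

manager = initialize_stream_manager(max_streams=20)   # stops and replaces the default shared instance
print(manager.get_stats()['max_streams'])
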
- - Args: - camera_id: Unique identifier for the camera - source_url: URL or path to the video source - max_retries: Maximum number of retry attempts - """ - self.camera_id = camera_id - self.source_url = source_url - self.max_retries = max_retries - self.frame_callback: Optional[Callable[[str, np.ndarray], None]] = None - - @abstractmethod - def start(self) -> None: - """Start the video reader.""" - pass - - @abstractmethod - def stop(self) -> None: - """Stop the video reader.""" - pass - - @abstractmethod - def set_frame_callback(self, callback: Callable[[str, np.ndarray], None]) -> None: - """ - Set callback function to handle captured frames. - - Args: - callback: Function that takes (camera_id, frame) as arguments - """ - pass - - @property - @abstractmethod - def is_running(self) -> bool: - """Check if the reader is currently running.""" - pass - - @property - @abstractmethod - def reader_type(self) -> str: - """Get the type of reader (e.g., 'rtsp', 'http_snapshot').""" - pass - - def __enter__(self): - """Context manager entry.""" - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Context manager exit.""" - self.stop() \ No newline at end of file diff --git a/core/streaming/readers/ffmpeg_rtsp.py b/core/streaming/readers/ffmpeg_rtsp.py deleted file mode 100644 index 88f45ae..0000000 --- a/core/streaming/readers/ffmpeg_rtsp.py +++ /dev/null @@ -1,436 +0,0 @@ -""" -FFmpeg RTSP stream reader using subprocess piping frames directly to buffer. -Enhanced with comprehensive health monitoring and automatic recovery. -""" -import cv2 -import time -import threading -import numpy as np -import subprocess -import struct -from typing import Optional, Callable, Dict, Any - -from .base import VideoReader -from .utils import log_success, log_warning, log_error, log_info -from ...monitoring.stream_health import stream_health_tracker -from ...monitoring.thread_health import thread_health_monitor -from ...monitoring.recovery import recovery_manager, RecoveryAction - - -class FFmpegRTSPReader(VideoReader): - """RTSP stream reader using subprocess FFmpeg piping frames directly to buffer.""" - - def __init__(self, camera_id: str, rtsp_url: str, max_retries: int = 3): - super().__init__(camera_id, rtsp_url, max_retries) - self.rtsp_url = rtsp_url - self.process = None - self.stop_event = threading.Event() - self.thread = None - self.stderr_thread = None - - # Expected stream specs (for reference, actual dimensions read from PPM header) - self.width = 1280 - self.height = 720 - - # Watchdog timers for stream reliability - self.process_start_time = None - self.last_frame_time = None - self.is_restart = False # Track if this is a restart (shorter timeout) - self.first_start_timeout = 30.0 # 30s timeout on first start - self.restart_timeout = 15.0 # 15s timeout after restart - - # Health monitoring setup - self.last_heartbeat = time.time() - self.consecutive_errors = 0 - self.ffmpeg_restart_count = 0 - - # Register recovery handlers - recovery_manager.register_recovery_handler( - RecoveryAction.RESTART_STREAM, - self._handle_restart_recovery - ) - recovery_manager.register_recovery_handler( - RecoveryAction.RECONNECT, - self._handle_reconnect_recovery - ) - - @property - def is_running(self) -> bool: - """Check if the reader is currently running.""" - return self.thread is not None and self.thread.is_alive() - - @property - def reader_type(self) -> str: - """Get the type of reader.""" - return "rtsp_ffmpeg" - - def set_frame_callback(self, callback: Callable[[str, 
np.ndarray], None]): - """Set callback function to handle captured frames.""" - self.frame_callback = callback - - def start(self): - """Start the FFmpeg subprocess reader.""" - if self.thread and self.thread.is_alive(): - log_warning(self.camera_id, "FFmpeg reader already running") - return - - self.stop_event.clear() - self.thread = threading.Thread(target=self._read_frames, daemon=True) - self.thread.start() - - # Register with health monitoring - stream_health_tracker.register_stream(self.camera_id, "rtsp_ffmpeg", self.rtsp_url) - thread_health_monitor.register_thread(self.thread, self._heartbeat_callback) - - log_success(self.camera_id, "Stream started with health monitoring") - - def stop(self): - """Stop the FFmpeg subprocess reader.""" - self.stop_event.set() - - # Unregister from health monitoring - if self.thread: - thread_health_monitor.unregister_thread(self.thread.ident) - - if self.process: - self.process.terminate() - try: - self.process.wait(timeout=5) - except subprocess.TimeoutExpired: - self.process.kill() - - if self.thread: - self.thread.join(timeout=5.0) - if self.stderr_thread: - self.stderr_thread.join(timeout=2.0) - - stream_health_tracker.unregister_stream(self.camera_id) - - log_info(self.camera_id, "Stream stopped") - - def _start_ffmpeg_process(self): - """Start FFmpeg subprocess outputting BMP frames to stdout pipe.""" - cmd = [ - 'ffmpeg', - # DO NOT REMOVE - '-hwaccel', 'cuda', - '-hwaccel_device', '0', - # Real-time input flags - '-fflags', 'nobuffer+genpts', - '-flags', 'low_delay', - '-max_delay', '0', # No reordering delay - # RTSP configuration - '-rtsp_transport', 'tcp', - '-i', self.rtsp_url, - # Output configuration (keeping BMP) - '-f', 'image2pipe', # Output images to pipe - '-vcodec', 'bmp', # BMP format with header containing dimensions - '-vsync', 'passthrough', # Pass frames as-is - # Use native stream resolution and framerate - '-an', # No audio - '-' # Output to stdout - ] - - try: - # Start FFmpeg with stdout pipe to read frames directly - self.process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, # Capture stdout for frame data - stderr=subprocess.PIPE, # Capture stderr for error logging - bufsize=0 # Unbuffered for real-time processing - ) - - # Start stderr reading thread - if self.stderr_thread and self.stderr_thread.is_alive(): - # Stop previous stderr thread - try: - self.stderr_thread.join(timeout=1.0) - except: - pass - - self.stderr_thread = threading.Thread(target=self._read_stderr, daemon=True) - self.stderr_thread.start() - - # Set process start time for watchdog - self.process_start_time = time.time() - self.last_frame_time = None # Reset frame time - - # After successful restart, next timeout will be back to 30s - if self.is_restart: - log_info(self.camera_id, f"FFmpeg restarted successfully, next timeout: {self.first_start_timeout}s") - self.is_restart = False - - return True - except Exception as e: - log_error(self.camera_id, f"FFmpeg startup failed: {e}") - return False - - def _read_bmp_frame(self, pipe): - """Read BMP frame from pipe - BMP header contains dimensions.""" - try: - # Read BMP header (14 bytes file header + 40 bytes info header = 54 bytes minimum) - header_data = b'' - bytes_to_read = 54 - - while len(header_data) < bytes_to_read: - chunk = pipe.read(bytes_to_read - len(header_data)) - if not chunk: - return None # Silent end of stream - header_data += chunk - - # Parse BMP header - if header_data[:2] != b'BM': - return None # Invalid format, skip frame silently - - # Extract file size from header 
(bytes 2-5) - file_size = struct.unpack(' bool: - """Check if watchdog timeout has been exceeded.""" - if not self.process_start_time: - return False - - current_time = time.time() - time_since_start = current_time - self.process_start_time - - # Determine timeout based on whether this is a restart - timeout = self.restart_timeout if self.is_restart else self.first_start_timeout - - # If no frames received yet, check against process start time - if not self.last_frame_time: - if time_since_start > timeout: - log_warning(self.camera_id, f"Watchdog timeout: No frames for {time_since_start:.1f}s (limit: {timeout}s)") - return True - else: - # Check time since last frame - time_since_frame = current_time - self.last_frame_time - if time_since_frame > timeout: - log_warning(self.camera_id, f"Watchdog timeout: No frames for {time_since_frame:.1f}s (limit: {timeout}s)") - return True - - return False - - def _restart_ffmpeg_process(self): - """Restart FFmpeg process due to watchdog timeout.""" - log_warning(self.camera_id, "Watchdog triggered FFmpeg restart") - - # Terminate current process - if self.process: - try: - self.process.terminate() - self.process.wait(timeout=3) - except subprocess.TimeoutExpired: - self.process.kill() - except Exception: - pass - self.process = None - - # Mark as restart for shorter timeout - self.is_restart = True - - # Small delay before restart - time.sleep(1.0) - - def _read_frames(self): - """Read frames directly from FFmpeg stdout pipe.""" - frame_count = 0 - last_log_time = time.time() - - while not self.stop_event.is_set(): - try: - # Send heartbeat for thread health monitoring - self._send_heartbeat("reading_frames") - - # Check watchdog timeout if process is running - if self.process and self.process.poll() is None: - if self._check_watchdog_timeout(): - self._restart_ffmpeg_process() - continue - - # Start FFmpeg if not running - if not self.process or self.process.poll() is not None: - if self.process and self.process.poll() is not None: - log_warning(self.camera_id, "Stream disconnected, reconnecting...") - stream_health_tracker.report_error( - self.camera_id, - "FFmpeg process disconnected" - ) - - if not self._start_ffmpeg_process(): - self.consecutive_errors += 1 - stream_health_tracker.report_error( - self.camera_id, - "Failed to start FFmpeg process" - ) - time.sleep(5.0) - continue - - # Read frames directly from FFmpeg stdout - try: - if self.process and self.process.stdout: - # Read BMP frame data - frame = self._read_bmp_frame(self.process.stdout) - if frame is None: - continue - - # Update watchdog - we got a frame - self.last_frame_time = time.time() - - # Reset error counter on successful frame - self.consecutive_errors = 0 - - # Report successful frame to health monitoring - frame_size = frame.nbytes - stream_health_tracker.report_frame_received(self.camera_id, frame_size) - - # Call frame callback - if self.frame_callback: - try: - self.frame_callback(self.camera_id, frame) - except Exception as e: - stream_health_tracker.report_error( - self.camera_id, - f"Frame callback error: {e}" - ) - - frame_count += 1 - - # Log progress every 60 seconds (quieter) - current_time = time.time() - if current_time - last_log_time >= 60: - log_success(self.camera_id, f"{frame_count} frames captured ({frame.shape[1]}x{frame.shape[0]})") - last_log_time = current_time - - except Exception as e: - # Process might have died, let it restart on next iteration - stream_health_tracker.report_error( - self.camera_id, - f"Frame reading error: {e}" - ) - if 
self.process: - self.process.terminate() - self.process = None - time.sleep(1.0) - - except Exception as e: - stream_health_tracker.report_error( - self.camera_id, - f"Main loop error: {e}" - ) - time.sleep(1.0) - - # Cleanup - if self.process: - self.process.terminate() - - # Health monitoring methods - def _send_heartbeat(self, activity: str = "running"): - """Send heartbeat to thread health monitor.""" - self.last_heartbeat = time.time() - thread_health_monitor.heartbeat(activity=activity) - - def _heartbeat_callback(self) -> bool: - """Heartbeat callback for thread responsiveness testing.""" - try: - # Check if thread is responsive by checking recent heartbeat - current_time = time.time() - age = current_time - self.last_heartbeat - - # Thread is responsive if heartbeat is recent - return age < 30.0 # 30 second responsiveness threshold - - except Exception: - return False - - def _handle_restart_recovery(self, component: str, details: Dict[str, Any]) -> bool: - """Handle restart recovery action.""" - try: - log_info(self.camera_id, "Restarting FFmpeg RTSP reader for health recovery") - - # Stop current instance - self.stop() - - # Small delay - time.sleep(2.0) - - # Restart - self.start() - - # Report successful restart - stream_health_tracker.report_reconnect(self.camera_id, "health_recovery_restart") - self.ffmpeg_restart_count += 1 - - return True - - except Exception as e: - log_error(self.camera_id, f"Failed to restart FFmpeg RTSP reader: {e}") - return False - - def _handle_reconnect_recovery(self, component: str, details: Dict[str, Any]) -> bool: - """Handle reconnect recovery action.""" - try: - log_info(self.camera_id, "Reconnecting FFmpeg RTSP reader for health recovery") - - # Force restart FFmpeg process - self._restart_ffmpeg_process() - - # Reset error counters - self.consecutive_errors = 0 - stream_health_tracker.report_reconnect(self.camera_id, "health_recovery_reconnect") - - return True - - except Exception as e: - log_error(self.camera_id, f"Failed to reconnect FFmpeg RTSP reader: {e}") - return False \ No newline at end of file diff --git a/core/streaming/readers/http_snapshot.py b/core/streaming/readers/http_snapshot.py deleted file mode 100644 index bbbf943..0000000 --- a/core/streaming/readers/http_snapshot.py +++ /dev/null @@ -1,378 +0,0 @@ -""" -HTTP snapshot reader optimized for 2560x1440 (2K) high quality images. -Enhanced with comprehensive health monitoring and automatic recovery. 
-""" -import cv2 -import logging -import time -import threading -import requests -import numpy as np -from typing import Optional, Callable, Dict, Any - -from .base import VideoReader -from .utils import log_success, log_warning, log_error, log_info -from ...monitoring.stream_health import stream_health_tracker -from ...monitoring.thread_health import thread_health_monitor -from ...monitoring.recovery import recovery_manager, RecoveryAction - -logger = logging.getLogger(__name__) - - -class HTTPSnapshotReader(VideoReader): - """HTTP snapshot reader optimized for 2560x1440 (2K) high quality images.""" - - def __init__(self, camera_id: str, snapshot_url: str, interval_ms: int = 5000, max_retries: int = 3): - super().__init__(camera_id, snapshot_url, max_retries) - self.snapshot_url = snapshot_url - self.interval_ms = interval_ms - self.stop_event = threading.Event() - self.thread = None - - # Expected snapshot specifications - self.expected_width = 2560 - self.expected_height = 1440 - self.max_file_size = 10 * 1024 * 1024 # 10MB max for 2K image - - # Health monitoring setup - self.last_heartbeat = time.time() - self.consecutive_errors = 0 - self.connection_test_interval = 300 # Test connection every 5 minutes - self.last_connection_test = None - - # Register recovery handlers - recovery_manager.register_recovery_handler( - RecoveryAction.RESTART_STREAM, - self._handle_restart_recovery - ) - recovery_manager.register_recovery_handler( - RecoveryAction.RECONNECT, - self._handle_reconnect_recovery - ) - - @property - def is_running(self) -> bool: - """Check if the reader is currently running.""" - return self.thread is not None and self.thread.is_alive() - - @property - def reader_type(self) -> str: - """Get the type of reader.""" - return "http_snapshot" - - def set_frame_callback(self, callback: Callable[[str, np.ndarray], None]): - """Set callback function to handle captured frames.""" - self.frame_callback = callback - - def start(self): - """Start the snapshot reader thread.""" - if self.thread and self.thread.is_alive(): - logger.warning(f"Snapshot reader for {self.camera_id} already running") - return - - self.stop_event.clear() - self.thread = threading.Thread(target=self._read_snapshots, daemon=True) - self.thread.start() - - # Register with health monitoring - stream_health_tracker.register_stream(self.camera_id, "http_snapshot", self.snapshot_url) - thread_health_monitor.register_thread(self.thread, self._heartbeat_callback) - - logger.info(f"Started snapshot reader for camera {self.camera_id} with health monitoring") - - def stop(self): - """Stop the snapshot reader thread.""" - self.stop_event.set() - - # Unregister from health monitoring - if self.thread: - thread_health_monitor.unregister_thread(self.thread.ident) - self.thread.join(timeout=5.0) - - stream_health_tracker.unregister_stream(self.camera_id) - - logger.info(f"Stopped snapshot reader for camera {self.camera_id}") - - def _read_snapshots(self): - """Main snapshot reading loop for high quality 2K images.""" - retries = 0 - frame_count = 0 - last_log_time = time.time() - last_connection_test = time.time() - interval_seconds = self.interval_ms / 1000.0 - - logger.info(f"Snapshot interval for camera {self.camera_id}: {interval_seconds}s") - - while not self.stop_event.is_set(): - try: - # Send heartbeat for thread health monitoring - self._send_heartbeat("fetching_snapshot") - - start_time = time.time() - frame = self._fetch_snapshot() - - if frame is None: - retries += 1 - self.consecutive_errors += 1 - - # Report error 
to health monitoring - stream_health_tracker.report_error( - self.camera_id, - f"Failed to fetch snapshot (retry {retries}/{self.max_retries})" - ) - - logger.warning(f"Failed to fetch snapshot for camera {self.camera_id}, retry {retries}/{self.max_retries}") - - if self.max_retries != -1 and retries > self.max_retries: - logger.error(f"Max retries reached for snapshot camera {self.camera_id}") - break - - time.sleep(min(2.0, interval_seconds)) - continue - - # Accept any valid image dimensions - don't force specific resolution - if frame.shape[1] <= 0 or frame.shape[0] <= 0: - logger.warning(f"Camera {self.camera_id}: Invalid frame dimensions {frame.shape[1]}x{frame.shape[0]}") - stream_health_tracker.report_error( - self.camera_id, - f"Invalid frame dimensions: {frame.shape[1]}x{frame.shape[0]}" - ) - continue - - # Reset retry counter on successful fetch - retries = 0 - self.consecutive_errors = 0 - frame_count += 1 - - # Report successful frame to health monitoring - frame_size = frame.nbytes - stream_health_tracker.report_frame_received(self.camera_id, frame_size) - - # Call frame callback - if self.frame_callback: - try: - self.frame_callback(self.camera_id, frame) - except Exception as e: - logger.error(f"Camera {self.camera_id}: Frame callback error: {e}") - stream_health_tracker.report_error(self.camera_id, f"Frame callback error: {e}") - - # Periodic connection health test - current_time = time.time() - if current_time - last_connection_test >= self.connection_test_interval: - self._test_connection_health() - last_connection_test = current_time - - # Log progress every 30 seconds - if current_time - last_log_time >= 30: - logger.info(f"Camera {self.camera_id}: {frame_count} snapshots processed") - last_log_time = current_time - - # Wait for next interval - elapsed = time.time() - start_time - sleep_time = max(0, interval_seconds - elapsed) - if sleep_time > 0: - self.stop_event.wait(sleep_time) - - except Exception as e: - logger.error(f"Error in snapshot loop for camera {self.camera_id}: {e}") - stream_health_tracker.report_error(self.camera_id, f"Snapshot loop error: {e}") - retries += 1 - if self.max_retries != -1 and retries > self.max_retries: - break - time.sleep(min(2.0, interval_seconds)) - - logger.info(f"Snapshot reader thread ended for camera {self.camera_id}") - - def _fetch_snapshot(self) -> Optional[np.ndarray]: - """Fetch a single high quality snapshot from HTTP URL.""" - try: - # Parse URL for authentication - from urllib.parse import urlparse - parsed_url = urlparse(self.snapshot_url) - - headers = { - 'User-Agent': 'Python-Detector-Worker/1.0', - 'Accept': 'image/jpeg, image/png, image/*' - } - auth = None - - if parsed_url.username and parsed_url.password: - from requests.auth import HTTPBasicAuth, HTTPDigestAuth - auth = HTTPBasicAuth(parsed_url.username, parsed_url.password) - - # Reconstruct URL without credentials - clean_url = f"{parsed_url.scheme}://{parsed_url.hostname}" - if parsed_url.port: - clean_url += f":{parsed_url.port}" - clean_url += parsed_url.path - if parsed_url.query: - clean_url += f"?{parsed_url.query}" - - # Try Basic Auth first - response = requests.get(clean_url, auth=auth, timeout=15, headers=headers, - stream=True, verify=False) - - # If Basic Auth fails, try Digest Auth - if response.status_code == 401: - auth = HTTPDigestAuth(parsed_url.username, parsed_url.password) - response = requests.get(clean_url, auth=auth, timeout=15, headers=headers, - stream=True, verify=False) - else: - response = requests.get(self.snapshot_url, 
timeout=15, headers=headers, - stream=True, verify=False) - - if response.status_code == 200: - # Check content size - content_length = int(response.headers.get('content-length', 0)) - if content_length > self.max_file_size: - logger.warning(f"Snapshot too large for camera {self.camera_id}: {content_length} bytes") - return None - - # Read content - content = response.content - - # Convert to numpy array - image_array = np.frombuffer(content, np.uint8) - - # Decode as high quality image - frame = cv2.imdecode(image_array, cv2.IMREAD_COLOR) - - if frame is None: - logger.error(f"Failed to decode snapshot for camera {self.camera_id}") - return None - - logger.debug(f"Fetched snapshot for camera {self.camera_id}: {frame.shape[1]}x{frame.shape[0]}") - return frame - else: - logger.warning(f"HTTP {response.status_code} from {self.camera_id}") - return None - - except requests.RequestException as e: - logger.error(f"Request error fetching snapshot for {self.camera_id}: {e}") - return None - except Exception as e: - logger.error(f"Error decoding snapshot for {self.camera_id}: {e}") - return None - - def fetch_single_snapshot(self) -> Optional[np.ndarray]: - """ - Fetch a single high-quality snapshot on demand for pipeline processing. - This method is for one-time fetch from HTTP URL, not continuous streaming. - - Returns: - High quality 2K snapshot frame or None if failed - """ - logger.info(f"[SNAPSHOT] Fetching snapshot for {self.camera_id} from {self.snapshot_url}") - - # Try to fetch snapshot with retries - for attempt in range(self.max_retries): - frame = self._fetch_snapshot() - - if frame is not None: - logger.info(f"[SNAPSHOT] Successfully fetched {frame.shape[1]}x{frame.shape[0]} snapshot for {self.camera_id}") - return frame - - if attempt < self.max_retries - 1: - logger.warning(f"[SNAPSHOT] Attempt {attempt + 1}/{self.max_retries} failed for {self.camera_id}, retrying...") - time.sleep(0.5) - - logger.error(f"[SNAPSHOT] Failed to fetch snapshot for {self.camera_id} after {self.max_retries} attempts") - return None - - def _resize_maintain_aspect(self, frame: np.ndarray, target_width: int, target_height: int) -> np.ndarray: - """Resize image while maintaining aspect ratio for high quality.""" - h, w = frame.shape[:2] - aspect = w / h - target_aspect = target_width / target_height - - if aspect > target_aspect: - # Image is wider - new_width = target_width - new_height = int(target_width / aspect) - else: - # Image is taller - new_height = target_height - new_width = int(target_height * aspect) - - # Use INTER_LANCZOS4 for high quality downsampling - resized = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4) - - # Pad to target size if needed - if new_width < target_width or new_height < target_height: - top = (target_height - new_height) // 2 - bottom = target_height - new_height - top - left = (target_width - new_width) // 2 - right = target_width - new_width - left - resized = cv2.copyMakeBorder(resized, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0]) - - return resized - - # Health monitoring methods - def _send_heartbeat(self, activity: str = "running"): - """Send heartbeat to thread health monitor.""" - self.last_heartbeat = time.time() - thread_health_monitor.heartbeat(activity=activity) - - def _heartbeat_callback(self) -> bool: - """Heartbeat callback for thread responsiveness testing.""" - try: - # Check if thread is responsive by checking recent heartbeat - current_time = time.time() - age = current_time - self.last_heartbeat - - # 
Thread is responsive if heartbeat is recent - return age < 30.0 # 30 second responsiveness threshold - - except Exception: - return False - - def _test_connection_health(self): - """Test HTTP connection health.""" - try: - stream_health_tracker.test_http_connection(self.camera_id, self.snapshot_url) - except Exception as e: - logger.error(f"Error testing connection health for {self.camera_id}: {e}") - - def _handle_restart_recovery(self, component: str, details: Dict[str, Any]) -> bool: - """Handle restart recovery action.""" - try: - logger.info(f"Restarting HTTP snapshot reader for {self.camera_id}") - - # Stop current instance - self.stop() - - # Small delay - time.sleep(2.0) - - # Restart - self.start() - - # Report successful restart - stream_health_tracker.report_reconnect(self.camera_id, "health_recovery_restart") - - return True - - except Exception as e: - logger.error(f"Failed to restart HTTP snapshot reader for {self.camera_id}: {e}") - return False - - def _handle_reconnect_recovery(self, component: str, details: Dict[str, Any]) -> bool: - """Handle reconnect recovery action.""" - try: - logger.info(f"Reconnecting HTTP snapshot reader for {self.camera_id}") - - # Test connection first - success = stream_health_tracker.test_http_connection(self.camera_id, self.snapshot_url) - - if success: - # Reset error counters - self.consecutive_errors = 0 - stream_health_tracker.report_reconnect(self.camera_id, "health_recovery_reconnect") - return True - else: - logger.warning(f"Connection test failed during recovery for {self.camera_id}") - return False - - except Exception as e: - logger.error(f"Failed to reconnect HTTP snapshot reader for {self.camera_id}: {e}") - return False \ No newline at end of file diff --git a/core/streaming/readers/utils.py b/core/streaming/readers/utils.py deleted file mode 100644 index 813f49f..0000000 --- a/core/streaming/readers/utils.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -Utility functions for stream readers. 
-""" -import logging -import os - -# Keep OpenCV errors visible but allow FFmpeg stderr logging -os.environ["OPENCV_LOG_LEVEL"] = "ERROR" - -logger = logging.getLogger(__name__) - -# Color codes for pretty logging -class Colors: - GREEN = '\033[92m' - YELLOW = '\033[93m' - RED = '\033[91m' - BLUE = '\033[94m' - PURPLE = '\033[95m' - CYAN = '\033[96m' - WHITE = '\033[97m' - BOLD = '\033[1m' - END = '\033[0m' - -def log_success(camera_id: str, message: str): - """Log success messages in green""" - logger.info(f"{Colors.GREEN}[{camera_id}] {message}{Colors.END}") - -def log_warning(camera_id: str, message: str): - """Log warnings in yellow""" - logger.warning(f"{Colors.YELLOW}[{camera_id}] {message}{Colors.END}") - -def log_error(camera_id: str, message: str): - """Log errors in red""" - logger.error(f"{Colors.RED}[{camera_id}] {message}{Colors.END}") - -def log_info(camera_id: str, message: str): - """Log info in cyan""" - logger.info(f"{Colors.CYAN}[{camera_id}] {message}{Colors.END}") \ No newline at end of file diff --git a/core/tracking/__init__.py b/core/tracking/__init__.py deleted file mode 100644 index a493062..0000000 --- a/core/tracking/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Tracking module for vehicle tracking and validation - -from .tracker import VehicleTracker, TrackedVehicle -from .validator import StableCarValidator, ValidationResult, VehicleState -from .integration import TrackingPipelineIntegration - -__all__ = [ - 'VehicleTracker', - 'TrackedVehicle', - 'StableCarValidator', - 'ValidationResult', - 'VehicleState', - 'TrackingPipelineIntegration' -] \ No newline at end of file diff --git a/core/tracking/bot_sort_tracker.py b/core/tracking/bot_sort_tracker.py deleted file mode 100644 index f487a6a..0000000 --- a/core/tracking/bot_sort_tracker.py +++ /dev/null @@ -1,408 +0,0 @@ -""" -BoT-SORT Multi-Object Tracker with Camera Isolation -Based on BoT-SORT: Robust Associations Multi-Pedestrian Tracking -""" - -import logging -import time -import numpy as np -from typing import Dict, List, Optional, Tuple, Any -from dataclasses import dataclass -from scipy.optimize import linear_sum_assignment -from filterpy.kalman import KalmanFilter -import cv2 - -logger = logging.getLogger(__name__) - - -@dataclass -class TrackState: - """Track state enumeration""" - TENTATIVE = "tentative" # New track, not confirmed yet - CONFIRMED = "confirmed" # Confirmed track - DELETED = "deleted" # Track to be deleted - - -class Track: - """ - Individual track representation with Kalman filter for motion prediction - """ - - def __init__(self, detection, track_id: int, camera_id: str): - """ - Initialize a new track - - Args: - detection: Initial detection (bbox, confidence, class) - track_id: Unique track identifier within camera - camera_id: Camera identifier - """ - self.track_id = track_id - self.camera_id = camera_id - self.state = TrackState.TENTATIVE - - # Time tracking - self.start_time = time.time() - self.last_update_time = time.time() - - # Appearance and motion - self.bbox = detection.bbox # [x1, y1, x2, y2] - self.confidence = detection.confidence - self.class_name = detection.class_name - - # Track management - self.hit_streak = 1 - self.time_since_update = 0 - self.age = 1 - - # Kalman filter for motion prediction - self.kf = self._create_kalman_filter() - self._update_kalman_filter(detection.bbox) - - # Track history - self.history = [detection.bbox] - self.max_history = 10 - - def _create_kalman_filter(self) -> KalmanFilter: - """Create Kalman filter for bbox tracking (x, y, w, h, 
vx, vy, vw, vh)""" - kf = KalmanFilter(dim_x=8, dim_z=4) - - # State transition matrix (constant velocity model) - kf.F = np.array([ - [1, 0, 0, 0, 1, 0, 0, 0], - [0, 1, 0, 0, 0, 1, 0, 0], - [0, 0, 1, 0, 0, 0, 1, 0], - [0, 0, 0, 1, 0, 0, 0, 1], - [0, 0, 0, 0, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 1] - ]) - - # Measurement matrix (observe x, y, w, h) - kf.H = np.array([ - [1, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0] - ]) - - # Process noise - kf.Q *= 0.01 - - # Measurement noise - kf.R *= 10 - - # Initial covariance - kf.P *= 100 - - return kf - - def _update_kalman_filter(self, bbox: List[float]): - """Update Kalman filter with new bbox""" - # Convert [x1, y1, x2, y2] to [cx, cy, w, h] - x1, y1, x2, y2 = bbox - cx = (x1 + x2) / 2 - cy = (y1 + y2) / 2 - w = x2 - x1 - h = y2 - y1 - - # Properly assign to column vector - self.kf.x[:4, 0] = [cx, cy, w, h] - - def predict(self) -> np.ndarray: - """Predict next position using Kalman filter""" - self.kf.predict() - - # Convert back to [x1, y1, x2, y2] format - cx, cy, w, h = self.kf.x[:4, 0] # Extract from column vector - x1 = cx - w/2 - y1 = cy - h/2 - x2 = cx + w/2 - y2 = cy + h/2 - - return np.array([x1, y1, x2, y2]) - - def update(self, detection): - """Update track with new detection""" - self.last_update_time = time.time() - self.time_since_update = 0 - self.hit_streak += 1 - self.age += 1 - - # Update track properties - self.bbox = detection.bbox - self.confidence = detection.confidence - - # Update Kalman filter - x1, y1, x2, y2 = detection.bbox - cx = (x1 + x2) / 2 - cy = (y1 + y2) / 2 - w = x2 - x1 - h = y2 - y1 - - self.kf.update([cx, cy, w, h]) - - # Update history - self.history.append(detection.bbox) - if len(self.history) > self.max_history: - self.history.pop(0) - - # Update state - if self.state == TrackState.TENTATIVE and self.hit_streak >= 3: - self.state = TrackState.CONFIRMED - - def mark_missed(self): - """Mark track as missed in this frame""" - self.time_since_update += 1 - self.age += 1 - - if self.time_since_update > 5: # Delete after 5 missed frames - self.state = TrackState.DELETED - - def is_confirmed(self) -> bool: - """Check if track is confirmed""" - return self.state == TrackState.CONFIRMED - - def is_deleted(self) -> bool: - """Check if track should be deleted""" - return self.state == TrackState.DELETED - - -class CameraTracker: - """ - BoT-SORT tracker for a single camera - """ - - def __init__(self, camera_id: str, max_disappeared: int = 10): - """ - Initialize camera tracker - - Args: - camera_id: Unique camera identifier - max_disappeared: Maximum frames a track can be missed before deletion - """ - self.camera_id = camera_id - self.max_disappeared = max_disappeared - - # Track management - self.tracks: Dict[int, Track] = {} - self.next_id = 1 - self.frame_count = 0 - - logger.info(f"Initialized BoT-SORT tracker for camera {camera_id}") - - def update(self, detections: List) -> List[Track]: - """ - Update tracker with new detections - - Args: - detections: List of Detection objects - - Returns: - List of active confirmed tracks - """ - self.frame_count += 1 - - # Predict all existing tracks - for track in self.tracks.values(): - track.predict() - - # Associate detections to tracks - matched_tracks, unmatched_detections, unmatched_tracks = self._associate(detections) - - # Update matched tracks - for track_id, detection in matched_tracks: - self.tracks[track_id].update(detection) - - # Mark 
unmatched tracks as missed - for track_id in unmatched_tracks: - self.tracks[track_id].mark_missed() - - # Create new tracks for unmatched detections - for detection in unmatched_detections: - track = Track(detection, self.next_id, self.camera_id) - self.tracks[self.next_id] = track - self.next_id += 1 - - # Remove deleted tracks - tracks_to_remove = [tid for tid, track in self.tracks.items() if track.is_deleted()] - for tid in tracks_to_remove: - del self.tracks[tid] - - # Return confirmed tracks - confirmed_tracks = [track for track in self.tracks.values() if track.is_confirmed()] - - return confirmed_tracks - - def _associate(self, detections: List) -> Tuple[List[Tuple[int, Any]], List[Any], List[int]]: - """ - Associate detections to existing tracks using IoU distance - - Returns: - (matched_tracks, unmatched_detections, unmatched_tracks) - """ - if not detections or not self.tracks: - return [], detections, list(self.tracks.keys()) - - # Calculate IoU distance matrix - track_ids = list(self.tracks.keys()) - cost_matrix = np.zeros((len(track_ids), len(detections))) - - for i, track_id in enumerate(track_ids): - track = self.tracks[track_id] - predicted_bbox = track.predict() - - for j, detection in enumerate(detections): - iou = self._calculate_iou(predicted_bbox, detection.bbox) - cost_matrix[i, j] = 1 - iou # Convert IoU to distance - - # Solve assignment problem - row_indices, col_indices = linear_sum_assignment(cost_matrix) - - # Filter matches by IoU threshold - iou_threshold = 0.3 - matched_tracks = [] - matched_detection_indices = set() - matched_track_indices = set() - - for row, col in zip(row_indices, col_indices): - if cost_matrix[row, col] <= (1 - iou_threshold): - track_id = track_ids[row] - detection = detections[col] - matched_tracks.append((track_id, detection)) - matched_detection_indices.add(col) - matched_track_indices.add(row) - - # Find unmatched detections and tracks - unmatched_detections = [detections[i] for i in range(len(detections)) - if i not in matched_detection_indices] - unmatched_tracks = [track_ids[i] for i in range(len(track_ids)) - if i not in matched_track_indices] - - return matched_tracks, unmatched_detections, unmatched_tracks - - def _calculate_iou(self, bbox1: np.ndarray, bbox2: List[float]) -> float: - """Calculate IoU between two bounding boxes""" - x1_1, y1_1, x2_1, y2_1 = bbox1 - x1_2, y1_2, x2_2, y2_2 = bbox2 - - # Calculate intersection area - x1_i = max(x1_1, x1_2) - y1_i = max(y1_1, y1_2) - x2_i = min(x2_1, x2_2) - y2_i = min(y2_1, y2_2) - - if x2_i <= x1_i or y2_i <= y1_i: - return 0.0 - - intersection = (x2_i - x1_i) * (y2_i - y1_i) - - # Calculate union area - area1 = (x2_1 - x1_1) * (y2_1 - y1_1) - area2 = (x2_2 - x1_2) * (y2_2 - y1_2) - union = area1 + area2 - intersection - - return intersection / union if union > 0 else 0.0 - - -class MultiCameraBoTSORT: - """ - Multi-camera BoT-SORT tracker with complete camera isolation - """ - - def __init__(self, trigger_classes: List[str], min_confidence: float = 0.6): - """ - Initialize multi-camera tracker - - Args: - trigger_classes: List of class names to track - min_confidence: Minimum detection confidence threshold - """ - self.trigger_classes = trigger_classes - self.min_confidence = min_confidence - - # Camera-specific trackers - self.camera_trackers: Dict[str, CameraTracker] = {} - - logger.info(f"Initialized MultiCameraBoTSORT with classes={trigger_classes}, " - f"min_confidence={min_confidence}") - - def get_or_create_tracker(self, camera_id: str) -> CameraTracker: - """Get or 
create tracker for specific camera""" - if camera_id not in self.camera_trackers: - self.camera_trackers[camera_id] = CameraTracker(camera_id) - logger.info(f"Created new tracker for camera {camera_id}") - - return self.camera_trackers[camera_id] - - def update(self, camera_id: str, inference_result) -> List[Dict]: - """ - Update tracker for specific camera with detections - - Args: - camera_id: Camera identifier - inference_result: InferenceResult with detections - - Returns: - List of track information dictionaries - """ - # Filter detections by confidence and trigger classes - filtered_detections = [] - - if hasattr(inference_result, 'detections') and inference_result.detections: - for detection in inference_result.detections: - if (detection.confidence >= self.min_confidence and - detection.class_name in self.trigger_classes): - filtered_detections.append(detection) - - # Get camera tracker and update - tracker = self.get_or_create_tracker(camera_id) - confirmed_tracks = tracker.update(filtered_detections) - - # Convert tracks to output format - track_results = [] - for track in confirmed_tracks: - track_results.append({ - 'track_id': track.track_id, - 'camera_id': track.camera_id, - 'bbox': track.bbox, - 'confidence': track.confidence, - 'class_name': track.class_name, - 'hit_streak': track.hit_streak, - 'age': track.age - }) - - return track_results - - def get_statistics(self) -> Dict[str, Any]: - """Get tracking statistics across all cameras""" - stats = {} - total_tracks = 0 - - for camera_id, tracker in self.camera_trackers.items(): - camera_stats = { - 'active_tracks': len([t for t in tracker.tracks.values() if t.is_confirmed()]), - 'total_tracks': len(tracker.tracks), - 'frame_count': tracker.frame_count - } - stats[camera_id] = camera_stats - total_tracks += camera_stats['active_tracks'] - - stats['summary'] = { - 'total_cameras': len(self.camera_trackers), - 'total_active_tracks': total_tracks - } - - return stats - - def reset_camera(self, camera_id: str): - """Reset tracking for specific camera""" - if camera_id in self.camera_trackers: - del self.camera_trackers[camera_id] - logger.info(f"Reset tracking for camera {camera_id}") - - def reset_all(self): - """Reset all camera trackers""" - self.camera_trackers.clear() - logger.info("Reset all camera trackers") \ No newline at end of file diff --git a/core/tracking/integration.py b/core/tracking/integration.py deleted file mode 100644 index 6ff2ee7..0000000 --- a/core/tracking/integration.py +++ /dev/null @@ -1,883 +0,0 @@ -""" -Tracking-Pipeline Integration Module. -Connects the tracking system with the main detection pipeline and manages the flow. -""" -import logging -import time -import uuid -from typing import Dict, Optional, Any, List, Tuple -from concurrent.futures import ThreadPoolExecutor -import asyncio -import numpy as np - -from .tracker import VehicleTracker, TrackedVehicle -from .validator import StableCarValidator -from ..models.inference import YOLOWrapper -from ..models.pipeline import PipelineParser -from ..detection.pipeline import DetectionPipeline - -logger = logging.getLogger(__name__) - - -class TrackingPipelineIntegration: - """ - Integrates vehicle tracking with the detection pipeline. - Manages tracking state transitions and pipeline execution triggers. - """ - - def __init__(self, pipeline_parser: PipelineParser, model_manager: Any, model_id: int, message_sender=None): - """ - Initialize tracking-pipeline integration. 
- - Args: - pipeline_parser: Pipeline parser with loaded configuration - model_manager: Model manager for loading models - model_id: The model ID to use for loading models - message_sender: Optional callback function for sending WebSocket messages - """ - self.pipeline_parser = pipeline_parser - self.model_manager = model_manager - self.model_id = model_id - self.message_sender = message_sender - - # Store subscription info for snapshot access - self.subscription_info = None - - # Initialize tracking components - tracking_config = pipeline_parser.tracking_config.__dict__ if pipeline_parser.tracking_config else {} - self.tracker = VehicleTracker(tracking_config) - self.validator = StableCarValidator() - - # Tracking model - self.tracking_model: Optional[YOLOWrapper] = None - self.tracking_model_id = None - - # Detection pipeline (Phase 5) - self.detection_pipeline: Optional[DetectionPipeline] = None - - # Session management - self.active_sessions: Dict[str, str] = {} # display_id -> session_id - self.session_vehicles: Dict[str, int] = {} # session_id -> track_id - self.cleared_sessions: Dict[str, float] = {} # session_id -> clear_time - self.pending_vehicles: Dict[str, int] = {} # display_id -> track_id (waiting for session ID) - self.pending_processing_data: Dict[str, Dict] = {} # display_id -> processing data (waiting for session ID) - self.display_to_subscription: Dict[str, str] = {} # display_id -> subscription_id (for fallback) - - # Additional validators for enhanced flow control - self.permanently_processed: Dict[str, float] = {} # "camera_id:track_id" -> process_time (never process again) - self.progression_stages: Dict[str, str] = {} # session_id -> current_stage - self.last_detection_time: Dict[str, float] = {} # display_id -> last_detection_timestamp - self.abandonment_timeout = 3.0 # seconds to wait before declaring car abandoned - - # Thread pool for pipeline execution - # Increased to 8 workers to handle 8 concurrent cameras without queuing - self.executor = ThreadPoolExecutor(max_workers=8) - - # Min bbox filtering configuration - # TODO: Make this configurable via pipeline.json in the future - self.min_bbox_area_percentage = 3.5 # 3.5% of frame area minimum - - # Statistics - self.stats = { - 'frames_processed': 0, - 'vehicles_detected': 0, - 'vehicles_validated': 0, - 'pipelines_executed': 0, - 'frontals_filtered_small': 0 # Track filtered detections - } - - - logger.info("TrackingPipelineIntegration initialized") - - async def initialize_tracking_model(self) -> bool: - """ - Load and initialize the tracking model. 
- - Returns: - True if successful, False otherwise - """ - try: - if not self.pipeline_parser.tracking_config: - logger.warning("No tracking configuration found in pipeline") - return False - - model_file = self.pipeline_parser.tracking_config.model_file - model_id = self.pipeline_parser.tracking_config.model_id - - if not model_file: - logger.warning("No tracking model file specified") - return False - - # Load tracking model - logger.info(f"Loading tracking model: {model_id} ({model_file})") - self.tracking_model = self.model_manager.get_yolo_model(self.model_id, model_file) - if not self.tracking_model: - logger.error(f"Failed to load tracking model {model_file} from model {self.model_id}") - return False - self.tracking_model_id = model_id - - if self.tracking_model: - logger.info(f"Tracking model {model_id} loaded successfully") - - # Initialize detection pipeline (Phase 5) - await self._initialize_detection_pipeline() - - return True - else: - logger.error(f"Failed to load tracking model {model_id}") - return False - - except Exception as e: - logger.error(f"Error initializing tracking model: {e}", exc_info=True) - return False - - async def _initialize_detection_pipeline(self) -> bool: - """ - Initialize the detection pipeline for main detection processing. - - Returns: - True if successful, False otherwise - """ - try: - if not self.pipeline_parser: - logger.warning("No pipeline parser available for detection pipeline") - return False - - # Create detection pipeline with message sender capability - self.detection_pipeline = DetectionPipeline(self.pipeline_parser, self.model_manager, self.model_id, self.message_sender) - - # Initialize detection pipeline - if await self.detection_pipeline.initialize(): - logger.info("Detection pipeline initialized successfully") - return True - else: - logger.error("Failed to initialize detection pipeline") - return False - - except Exception as e: - logger.error(f"Error initializing detection pipeline: {e}", exc_info=True) - return False - - async def process_frame(self, - frame: np.ndarray, - display_id: str, - subscription_id: str, - session_id: Optional[str] = None) -> Dict[str, Any]: - """ - Process a frame through tracking and potentially the detection pipeline. 
- - Args: - frame: Input frame to process - display_id: Display identifier - subscription_id: Full subscription identifier - session_id: Optional session ID from backend - - Returns: - Dictionary with processing results - """ - start_time = time.time() - result = { - 'tracked_vehicles': [], - 'validated_vehicle': None, - 'pipeline_result': None, - 'session_id': session_id, - 'processing_time': 0.0 - } - - try: - # Update stats - self.stats['frames_processed'] += 1 - - # Run tracking model - if self.tracking_model: - # Run detection-only (tracking handled by our own tracker) - tracking_results = self.tracking_model.track( - frame, - confidence_threshold=self.tracker.min_confidence, - trigger_classes=self.tracker.trigger_classes, - persist=True - ) - - # Debug: Log raw detection results - if tracking_results and hasattr(tracking_results, 'detections'): - raw_detections = len(tracking_results.detections) - if raw_detections > 0: - class_names = [detection.class_name for detection in tracking_results.detections] - logger.debug(f"Raw detections: {raw_detections}, classes: {class_names}") - else: - logger.debug(f"No raw detections found") - else: - logger.debug(f"No tracking results or detections attribute") - - # Filter out small frontal detections (neighboring pumps/distant cars) - if tracking_results and hasattr(tracking_results, 'detections'): - tracking_results = self._filter_small_frontals(tracking_results, frame) - - # Process tracking results - tracked_vehicles = self.tracker.process_detections( - tracking_results, - display_id, - frame - ) - - # Update last detection time for abandonment detection - # Update when vehicles ARE detected, so when they leave, timestamp ages - if tracked_vehicles: - self.last_detection_time[display_id] = time.time() - logger.debug(f"Updated last_detection_time for {display_id}: {len(tracked_vehicles)} vehicles") - - # Check for car abandonment (vehicle left after getting car_wait_staff stage) - await self._check_car_abandonment(display_id, subscription_id) - - result['tracked_vehicles'] = [ - { - 'track_id': v.track_id, - 'bbox': v.bbox, - 'confidence': v.confidence, - 'is_stable': v.is_stable, - 'session_id': v.session_id - } - for v in tracked_vehicles - ] - - # Log tracking info periodically - if self.stats['frames_processed'] % 30 == 0: # Every 30 frames - logger.debug(f"Tracking: {len(tracked_vehicles)} vehicles, " - f"display={display_id}") - - # Get stable vehicles for validation - stable_vehicles = self.tracker.get_stable_vehicles(display_id) - - # Validate and potentially process stable vehicles - for vehicle in stable_vehicles: - # Check if vehicle is already processed or has session - if vehicle.processed_pipeline: - continue - - # Check for session cleared (post-fueling) - if session_id and vehicle.session_id == session_id: - # Same vehicle with same session, skip - continue - - # Check if this was a recently cleared session - session_cleared = False - if vehicle.session_id in self.cleared_sessions: - clear_time = self.cleared_sessions[vehicle.session_id] - if (time.time() - clear_time) < 30: # 30 second cooldown - session_cleared = True - - # Skip same car after session clear or if permanently processed - if self.validator.should_skip_same_car(vehicle, session_cleared, self.permanently_processed): - continue - - # Validate vehicle - validation_result = self.validator.validate_vehicle(vehicle, frame.shape) - - if validation_result.is_valid and validation_result.should_process: - logger.info(f"Vehicle {vehicle.track_id} validated for processing: " 
- f"{validation_result.reason}") - - result['validated_vehicle'] = { - 'track_id': vehicle.track_id, - 'state': validation_result.state.value, - 'confidence': validation_result.confidence - } - - # Execute detection pipeline - this will send real imageDetection when detection is found - - # Mark vehicle as pending session ID assignment - self.pending_vehicles[display_id] = vehicle.track_id - logger.info(f"Vehicle {vehicle.track_id} waiting for session ID from backend") - - # Execute detection pipeline (placeholder for Phase 5) - pipeline_result = await self._execute_pipeline( - frame, - vehicle, - display_id, - None, # No session ID yet - subscription_id - ) - - result['pipeline_result'] = pipeline_result - # No session_id in result yet - backend will provide it - self.stats['pipelines_executed'] += 1 - - # Only process one vehicle per frame - break - - self.stats['vehicles_detected'] = len(tracked_vehicles) - self.stats['vehicles_validated'] = len(stable_vehicles) - - else: - logger.warning("No tracking model available") - - except Exception as e: - logger.error(f"Error in tracking pipeline: {e}", exc_info=True) - - - result['processing_time'] = time.time() - start_time - return result - - async def _execute_pipeline(self, - frame: np.ndarray, - vehicle: TrackedVehicle, - display_id: str, - session_id: str, - subscription_id: str) -> Dict[str, Any]: - """ - Execute the main detection pipeline for a validated vehicle. - - Args: - frame: Input frame - vehicle: Validated tracked vehicle - display_id: Display identifier - session_id: Session identifier - subscription_id: Full subscription identifier - - Returns: - Pipeline execution results - """ - logger.info(f"Executing detection pipeline for vehicle {vehicle.track_id}, " - f"session={session_id}, display={display_id}") - - try: - # Check if detection pipeline is available - if not self.detection_pipeline: - logger.warning("Detection pipeline not initialized, using fallback") - return { - 'status': 'error', - 'message': 'Detection pipeline not available', - 'vehicle_id': vehicle.track_id, - 'session_id': session_id - } - - # Fetch high-quality 2K snapshot for detection phase (not RTSP frame) - # This ensures bbox coordinates match the frame used in processing phase - logger.info(f"[DETECTION PHASE] Fetching 2K snapshot for vehicle {vehicle.track_id}") - snapshot_frame = self._fetch_snapshot() - - if snapshot_frame is None: - logger.warning(f"[DETECTION PHASE] Failed to fetch snapshot, falling back to RTSP frame") - snapshot_frame = frame # Fallback to RTSP if snapshot fails - else: - logger.info(f"[DETECTION PHASE] Using {snapshot_frame.shape[1]}x{snapshot_frame.shape[0]} snapshot for detection") - - # Execute only the detection phase (first phase) - # This will run detection and send imageDetection message to backend - detection_result = await self.detection_pipeline.execute_detection_phase( - frame=snapshot_frame, # Use 2K snapshot instead of RTSP frame - display_id=display_id, - subscription_id=subscription_id - ) - - # Add vehicle information to result - detection_result['vehicle_id'] = vehicle.track_id - detection_result['vehicle_bbox'] = vehicle.bbox - detection_result['vehicle_confidence'] = vehicle.confidence - detection_result['phase'] = 'detection' - - logger.info(f"Detection phase executed for vehicle {vehicle.track_id}: " - f"status={detection_result.get('status', 'unknown')}, " - f"message_sent={detection_result.get('message_sent', False)}, " - f"processing_time={detection_result.get('processing_time', 0):.3f}s") - - # Store 
frame and detection results for processing phase - if detection_result['message_sent']: - # Store for later processing when sessionId is received - self.pending_processing_data[display_id] = { - 'frame': snapshot_frame.copy(), # Store copy of 2K snapshot (not RTSP frame!) - 'vehicle': vehicle, - 'subscription_id': subscription_id, - 'detection_result': detection_result, - 'timestamp': time.time() - } - logger.info(f"Stored processing data ({snapshot_frame.shape[1]}x{snapshot_frame.shape[0]} frame) for {display_id}, waiting for sessionId from backend") - - return detection_result - - except Exception as e: - logger.error(f"Error executing detection pipeline: {e}", exc_info=True) - return { - 'status': 'error', - 'message': str(e), - 'vehicle_id': vehicle.track_id, - 'session_id': session_id, - 'processing_time': 0.0 - } - - async def _execute_processing_phase(self, - processing_data: Dict[str, Any], - session_id: str, - display_id: str) -> None: - """ - Execute the processing phase after receiving sessionId from backend. - This includes branch processing and database operations. - - Args: - processing_data: Stored processing data from detection phase - session_id: Session ID from backend - display_id: Display identifier - """ - try: - vehicle = processing_data['vehicle'] - subscription_id = processing_data['subscription_id'] - detection_result = processing_data['detection_result'] - - logger.info(f"Executing processing phase for session {session_id}, vehicle {vehicle.track_id}") - - # Reuse the snapshot from detection phase OR fetch fresh one if detection used RTSP fallback - detection_frame = processing_data['frame'] - frame_height = detection_frame.shape[0] - - # Check if detection phase used 2K snapshot (height > 1000) or RTSP fallback (height = 720) - if frame_height >= 1000: - # Detection used 2K snapshot - reuse it for consistent coordinates - logger.info(f"[PROCESSING PHASE] Reusing 2K snapshot from detection phase ({detection_frame.shape[1]}x{detection_frame.shape[0]})") - frame = detection_frame - else: - # Detection used RTSP fallback - need to fetch fresh 2K snapshot - logger.warning(f"[PROCESSING PHASE] Detection used RTSP fallback ({detection_frame.shape[1]}x{detection_frame.shape[0]}), fetching fresh 2K snapshot") - frame = self._fetch_snapshot() - - if frame is None: - logger.error(f"[PROCESSING PHASE] Failed to fetch snapshot and detection used RTSP - coordinate mismatch will occur!") - logger.error(f"[PROCESSING PHASE] Cannot proceed with mismatched coordinates. 
Aborting processing phase.") - return # Cannot process safely - bbox coordinates won't match frame resolution - else: - logger.warning(f"[PROCESSING PHASE] Fetched fresh 2K snapshot ({frame.shape[1]}x{frame.shape[0]}), but coordinates may not match exactly") - logger.warning(f"[PROCESSING PHASE] Re-running detection on fresh snapshot is recommended but not implemented yet") - - # Extract detected regions from detection phase result if available - detected_regions = detection_result.get('detected_regions', {}) - logger.info(f"[INTEGRATION] Passing detected_regions to processing phase: {list(detected_regions.keys())}") - - # Execute processing phase with detection pipeline - if self.detection_pipeline: - processing_result = await self.detection_pipeline.execute_processing_phase( - frame=frame, - display_id=display_id, - session_id=session_id, - subscription_id=subscription_id, - detected_regions=detected_regions - ) - - logger.info(f"Processing phase completed for session {session_id}: " - f"status={processing_result.get('status', 'unknown')}, " - f"branches={len(processing_result.get('branch_results', {}))}, " - f"actions={len(processing_result.get('actions_executed', []))}, " - f"processing_time={processing_result.get('processing_time', 0):.3f}s") - - # Update stats - self.stats['pipelines_executed'] += 1 - - else: - logger.error("Detection pipeline not available for processing phase") - - except Exception as e: - logger.error(f"Error in processing phase for session {session_id}: {e}", exc_info=True) - - - def set_subscription_info(self, subscription_info): - """ - Set subscription info to access snapshot URL and other stream details. - - Args: - subscription_info: SubscriptionInfo object containing stream config - """ - self.subscription_info = subscription_info - logger.debug(f"Set subscription info with snapshot_url: {subscription_info.stream_config.snapshot_url if subscription_info else None}") - - def set_session_id(self, display_id: str, session_id: str, subscription_id: str = None): - """ - Set session ID for a display (from backend). - This is called when backend sends setSessionId after receiving imageDetection. 
- - Args: - display_id: Display identifier - session_id: Session identifier - subscription_id: Subscription identifier (displayId;cameraId) - needed for fallback - """ - # Ensure session_id is always a string for consistent type handling - session_id = str(session_id) if session_id is not None else None - self.active_sessions[display_id] = session_id - - # Store subscription_id for fallback usage - if subscription_id: - self.display_to_subscription[display_id] = subscription_id - logger.info(f"Set session {session_id} for display {display_id} with subscription {subscription_id}") - else: - logger.info(f"Set session {session_id} for display {display_id}") - - # Check if we have a pending vehicle for this display - if display_id in self.pending_vehicles: - track_id = self.pending_vehicles[display_id] - - # Mark vehicle as processed with the session ID - self.tracker.mark_processed(track_id, session_id) - self.session_vehicles[session_id] = track_id - - # Mark vehicle as permanently processed (won't process again even after session clear) - # Use composite key to distinguish same track IDs across different cameras - camera_id = display_id # Using display_id as camera_id for isolation - permanent_key = f"{camera_id}:{track_id}" - self.permanently_processed[permanent_key] = time.time() - - # Remove from pending - del self.pending_vehicles[display_id] - - logger.info(f"Assigned session {session_id} to vehicle {track_id}, marked as permanently processed") - else: - logger.warning(f"No pending vehicle found for display {display_id} when setting session {session_id}") - - # Check if we have pending processing data for this display - if display_id in self.pending_processing_data: - processing_data = self.pending_processing_data[display_id] - - # Trigger the processing phase asynchronously - asyncio.create_task(self._execute_processing_phase( - processing_data=processing_data, - session_id=session_id, - display_id=display_id - )) - - # Remove from pending processing - del self.pending_processing_data[display_id] - - logger.info(f"Triggered processing phase for session {session_id} on display {display_id}") - else: - logger.warning(f"No pending processing data found for display {display_id} when setting session {session_id}") - - # FALLBACK: Execute pipeline for POS-initiated sessions - # Skip if session_id is None (no car present or car has left) - if session_id is not None: - # Use stored subscription_id instead of creating fake one - stored_subscription_id = self.display_to_subscription.get(display_id) - if stored_subscription_id: - logger.info(f"[FALLBACK] Triggering fallback pipeline for session {session_id} on display {display_id} with subscription {stored_subscription_id}") - - # Trigger the fallback pipeline asynchronously with real subscription_id - asyncio.create_task(self._execute_fallback_pipeline( - display_id=display_id, - session_id=session_id, - subscription_id=stored_subscription_id - )) - else: - logger.error(f"[FALLBACK] No subscription_id stored for display {display_id}, cannot execute fallback pipeline") - else: - logger.debug(f"[FALLBACK] Skipping pipeline execution for session_id=None on display {display_id}") - - def clear_session_id(self, session_id: str): - """ - Clear session ID (post-fueling). 
- - Args: - session_id: Session identifier to clear - """ - # Mark session as cleared - self.cleared_sessions[session_id] = time.time() - - # Clear from tracker - self.tracker.clear_session(session_id) - - # Remove from active sessions - display_to_remove = None - for display_id, sess_id in self.active_sessions.items(): - if sess_id == session_id: - display_to_remove = display_id - break - - if display_to_remove: - del self.active_sessions[display_to_remove] - - if session_id in self.session_vehicles: - del self.session_vehicles[session_id] - - logger.info(f"Cleared session {session_id}") - - # Clean old cleared sessions (older than 5 minutes) - current_time = time.time() - old_sessions = [ - sid for sid, clear_time in self.cleared_sessions.items() - if (current_time - clear_time) > 300 - ] - for sid in old_sessions: - del self.cleared_sessions[sid] - - def get_session_for_display(self, display_id: str) -> Optional[str]: - """Get active session for a display.""" - return self.active_sessions.get(display_id) - - def reset_tracking(self): - """Reset all tracking state.""" - self.tracker.reset_tracking() - self.active_sessions.clear() - self.session_vehicles.clear() - self.cleared_sessions.clear() - self.pending_vehicles.clear() - self.pending_processing_data.clear() - self.display_to_subscription.clear() - self.permanently_processed.clear() - self.progression_stages.clear() - self.last_detection_time.clear() - logger.info("Tracking pipeline integration reset") - - def get_statistics(self) -> Dict[str, Any]: - """Get comprehensive statistics.""" - tracker_stats = self.tracker.get_statistics() - validator_stats = self.validator.get_statistics() - - return { - 'integration': self.stats, - 'tracker': tracker_stats, - 'validator': validator_stats, - 'active_sessions': len(self.active_sessions), - 'cleared_sessions': len(self.cleared_sessions) - } - - async def _check_car_abandonment(self, display_id: str, subscription_id: str): - """ - Check if a car has abandoned the fueling process (left after getting car_wait_staff stage). 
- - Args: - display_id: Display identifier - subscription_id: Subscription identifier - """ - current_time = time.time() - - # Check all sessions in car_wait_staff stage - abandoned_sessions = [] - for session_id, stage in self.progression_stages.items(): - if stage == "car_wait_staff": - # Check if we have recent detections for this session's display - session_display = None - for disp_id, sess_id in self.active_sessions.items(): - if sess_id == session_id: - session_display = disp_id - break - - if session_display: - last_detection = self.last_detection_time.get(session_display, 0) - time_since_detection = current_time - last_detection - - logger.info(f"[ABANDON CHECK] Session {session_id} (display: {session_display}): " - f"time_since_detection={time_since_detection:.1f}s, " - f"timeout={self.abandonment_timeout}s") - - if time_since_detection > self.abandonment_timeout: - logger.warning(f"🚨 Car abandonment detected: session {session_id}, " - f"no detection for {time_since_detection:.1f}s") - abandoned_sessions.append(session_id) - else: - logger.debug(f"[ABANDON CHECK] Session {session_id} has no associated display") - - # Send abandonment detection for each abandoned session - for session_id in abandoned_sessions: - await self._send_abandonment_detection(subscription_id, session_id) - # Remove from progression stages to avoid repeated detection - if session_id in self.progression_stages: - del self.progression_stages[session_id] - logger.info(f"[ABANDON] Removed session {session_id} from progression_stages after notification") - - async def _send_abandonment_detection(self, subscription_id: str, session_id: str): - """ - Send imageDetection with null detection to indicate car abandonment. - - Args: - subscription_id: Subscription identifier - session_id: Session ID of the abandoned car - """ - try: - # Import here to avoid circular imports - from ..communication.messages import create_image_detection - - # Create abandonment detection message with null detection - detection_message = create_image_detection( - subscription_identifier=subscription_id, - detection_data=None, # Null detection indicates abandonment - model_id=self.model_id, - model_name=self.pipeline_parser.tracking_config.model_id if self.pipeline_parser.tracking_config else "tracking_model" - ) - - # Send to backend via WebSocket if sender is available - if self.message_sender: - await self.message_sender(detection_message) - logger.info(f"[CAR ABANDONMENT] Sent null detection for session {session_id}") - else: - logger.info(f"[CAR ABANDONMENT] No message sender available, would send: {detection_message}") - - except Exception as e: - logger.error(f"Error sending abandonment detection: {e}", exc_info=True) - - def set_progression_stage(self, session_id: str, stage: str): - """ - Set progression stage for a session (from backend setProgessionStage message). - - Args: - session_id: Session identifier - stage: Progression stage (e.g., "car_wait_staff") - """ - self.progression_stages[session_id] = stage - logger.info(f"Set progression stage for session {session_id}: {stage}") - - # If car reaches car_wait_staff, start monitoring for abandonment - if stage == "car_wait_staff": - logger.info(f"Started monitoring session {session_id} for car abandonment") - - def _fetch_snapshot(self) -> Optional[np.ndarray]: - """ - Fetch high-quality snapshot from camera's snapshot URL. - Reusable method for both processing phase and fallback pipeline. 
- - Returns: - Snapshot frame or None if unavailable - """ - if not (self.subscription_info and self.subscription_info.stream_config.snapshot_url): - logger.warning("[SNAPSHOT] No subscription info or snapshot URL available") - return None - - try: - from ..streaming.readers import HTTPSnapshotReader - - logger.info(f"[SNAPSHOT] Fetching snapshot for {self.subscription_info.camera_id}") - snapshot_reader = HTTPSnapshotReader( - camera_id=self.subscription_info.camera_id, - snapshot_url=self.subscription_info.stream_config.snapshot_url, - max_retries=3 - ) - - frame = snapshot_reader.fetch_single_snapshot() - - if frame is not None: - logger.info(f"[SNAPSHOT] Successfully fetched {frame.shape[1]}x{frame.shape[0]} snapshot") - return frame - else: - logger.warning("[SNAPSHOT] Failed to fetch snapshot") - return None - - except Exception as e: - logger.error(f"[SNAPSHOT] Error fetching snapshot: {e}", exc_info=True) - return None - - async def _execute_fallback_pipeline(self, display_id: str, session_id: str, subscription_id: str): - """ - Execute fallback pipeline when sessionId is received without prior detection. - This handles POS-initiated sessions where backend starts transaction before car detection. - - Args: - display_id: Display identifier - session_id: Session ID from backend - subscription_id: Subscription identifier for pipeline execution - """ - try: - logger.info(f"[FALLBACK PIPELINE] Executing for session {session_id}, display {display_id}") - - # Fetch fresh snapshot from camera - frame = self._fetch_snapshot() - - if frame is None: - logger.error(f"[FALLBACK] Failed to fetch snapshot for session {session_id}, cannot execute pipeline") - return - - logger.info(f"[FALLBACK] Using snapshot frame {frame.shape[1]}x{frame.shape[0]} for session {session_id}") - - # Check if detection pipeline is available - if not self.detection_pipeline: - logger.error(f"[FALLBACK] Detection pipeline not available for session {session_id}") - return - - # Execute detection phase to get detected regions - detection_result = await self.detection_pipeline.execute_detection_phase( - frame=frame, - display_id=display_id, - subscription_id=subscription_id - ) - - logger.info(f"[FALLBACK] Detection phase completed for session {session_id}: " - f"status={detection_result.get('status', 'unknown')}, " - f"regions={list(detection_result.get('detected_regions', {}).keys())}") - - # If detection found regions, execute processing phase - detected_regions = detection_result.get('detected_regions', {}) - if detected_regions: - processing_result = await self.detection_pipeline.execute_processing_phase( - frame=frame, - display_id=display_id, - session_id=session_id, - subscription_id=subscription_id, - detected_regions=detected_regions - ) - - logger.info(f"[FALLBACK] Processing phase completed for session {session_id}: " - f"status={processing_result.get('status', 'unknown')}, " - f"branches={len(processing_result.get('branch_results', {}))}, " - f"actions={len(processing_result.get('actions_executed', []))}") - - # Update statistics - self.stats['pipelines_executed'] += 1 - - else: - logger.warning(f"[FALLBACK] No detections found in snapshot for session {session_id}") - - except Exception as e: - logger.error(f"[FALLBACK] Error executing fallback pipeline for session {session_id}: {e}", exc_info=True) - - def _filter_small_frontals(self, tracking_results, frame): - """ - Filter out frontal detections that are smaller than minimum bbox area percentage. 
- This prevents processing of cars from neighboring pumps that appear in camera view. - - Args: - tracking_results: YOLO tracking results with detections - frame: Input frame for calculating frame area - - Returns: - Modified tracking_results with small frontals removed - """ - if not hasattr(tracking_results, 'detections') or not tracking_results.detections: - return tracking_results - - # Calculate frame area and minimum bbox area threshold - frame_area = frame.shape[0] * frame.shape[1] # height * width - min_bbox_area = frame_area * (self.min_bbox_area_percentage / 100.0) - - # Filter detections - filtered_detections = [] - filtered_count = 0 - - for detection in tracking_results.detections: - # Calculate detection bbox area - bbox = detection.bbox # Assuming bbox is [x1, y1, x2, y2] - bbox_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - - if bbox_area >= min_bbox_area: - # Keep detection - bbox is large enough - filtered_detections.append(detection) - else: - # Filter out small detection - filtered_count += 1 - area_percentage = (bbox_area / frame_area) * 100 - logger.debug(f"Filtered small frontal: area={bbox_area:.0f}px² ({area_percentage:.1f}% of frame, " - f"min required: {self.min_bbox_area_percentage}%)") - - # Update tracking results with filtered detections - tracking_results.detections = filtered_detections - - # Update statistics - if filtered_count > 0: - self.stats['frontals_filtered_small'] += filtered_count - logger.info(f"Filtered {filtered_count} small frontal detections, " - f"{len(filtered_detections)} remaining (total filtered: {self.stats['frontals_filtered_small']})") - - return tracking_results - - def cleanup(self): - """Cleanup resources.""" - self.executor.shutdown(wait=False) - self.reset_tracking() - - - # Cleanup detection pipeline - if self.detection_pipeline: - self.detection_pipeline.cleanup() - - logger.info("Tracking pipeline integration cleaned up") \ No newline at end of file diff --git a/core/tracking/tracker.py b/core/tracking/tracker.py deleted file mode 100644 index 63d0299..0000000 --- a/core/tracking/tracker.py +++ /dev/null @@ -1,293 +0,0 @@ -""" -Vehicle Tracking Module - BoT-SORT based tracking with camera isolation -Implements vehicle identification, persistence, and motion analysis using external tracker. 
-""" -import logging -import time -import uuid -from typing import Dict, List, Optional, Tuple, Any -from dataclasses import dataclass, field -import numpy as np -from threading import Lock - -from .bot_sort_tracker import MultiCameraBoTSORT - -logger = logging.getLogger(__name__) - - -@dataclass -class TrackedVehicle: - """Represents a tracked vehicle with all its state information.""" - track_id: int - camera_id: str - first_seen: float - last_seen: float - session_id: Optional[str] = None - display_id: Optional[str] = None - confidence: float = 0.0 - bbox: Tuple[int, int, int, int] = (0, 0, 0, 0) # x1, y1, x2, y2 - center: Tuple[float, float] = (0.0, 0.0) - stable_frames: int = 0 - total_frames: int = 0 - is_stable: bool = False - processed_pipeline: bool = False - last_position_history: List[Tuple[float, float]] = field(default_factory=list) - avg_confidence: float = 0.0 - hit_streak: int = 0 - age: int = 0 - - def update_position(self, bbox: Tuple[int, int, int, int], confidence: float): - """Update vehicle position and confidence.""" - self.bbox = bbox - self.center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2) - self.last_seen = time.time() - self.confidence = confidence - self.total_frames += 1 - - # Update confidence average - self.avg_confidence = ((self.avg_confidence * (self.total_frames - 1)) + confidence) / self.total_frames - - # Maintain position history (last 10 positions) - self.last_position_history.append(self.center) - if len(self.last_position_history) > 10: - self.last_position_history.pop(0) - - def calculate_stability(self) -> float: - """Calculate stability score based on position history.""" - if len(self.last_position_history) < 2: - return 0.0 - - # Calculate movement variance - positions = np.array(self.last_position_history) - if len(positions) < 2: - return 0.0 - - # Calculate standard deviation of positions - std_x = np.std(positions[:, 0]) - std_y = np.std(positions[:, 1]) - - # Lower variance means more stable (inverse relationship) - # Normalize to 0-1 range (assuming max reasonable std is 50 pixels) - stability = max(0, 1 - (std_x + std_y) / 100) - return stability - - def is_expired(self, timeout_seconds: float = 2.0) -> bool: - """Check if vehicle tracking has expired.""" - return (time.time() - self.last_seen) > timeout_seconds - - -class VehicleTracker: - """ - Main vehicle tracking implementation using BoT-SORT with camera isolation. - Manages continuous tracking, vehicle identification, and state persistence. - """ - - def __init__(self, tracking_config: Optional[Dict] = None): - """ - Initialize the vehicle tracker. 
- - Args: - tracking_config: Configuration from pipeline.json tracking section - """ - self.config = tracking_config or {} - self.trigger_classes = self.config.get('trigger_classes', self.config.get('triggerClasses', ['frontal'])) - self.min_confidence = self.config.get('minConfidence', 0.6) - - # BoT-SORT multi-camera tracker - self.bot_sort = MultiCameraBoTSORT(self.trigger_classes, self.min_confidence) - - # Tracking state - maintain compatibility with existing code - self.tracked_vehicles: Dict[str, Dict[int, TrackedVehicle]] = {} # camera_id -> {track_id: vehicle} - self.lock = Lock() - - # Tracking parameters - self.stability_threshold = 0.7 - self.min_stable_frames = 5 - self.timeout_seconds = 2.0 - - logger.info(f"VehicleTracker initialized with BoT-SORT: trigger_classes={self.trigger_classes}, " - f"min_confidence={self.min_confidence}") - - def process_detections(self, - results: Any, - display_id: str, - frame: np.ndarray) -> List[TrackedVehicle]: - """ - Process detection results using BoT-SORT tracking. - - Args: - results: Detection results (InferenceResult) - display_id: Display identifier for this stream - frame: Current frame being processed - - Returns: - List of currently tracked vehicles - """ - current_time = time.time() - - # Extract camera_id from display_id for tracking isolation - camera_id = display_id # Using display_id as camera_id for isolation - - with self.lock: - # Update BoT-SORT tracker - track_results = self.bot_sort.update(camera_id, results) - - # Ensure camera tracking dict exists - if camera_id not in self.tracked_vehicles: - self.tracked_vehicles[camera_id] = {} - - # Update tracked vehicles based on BoT-SORT results - current_tracks = {} - active_tracks = [] - - for track_result in track_results: - track_id = track_result['track_id'] - - # Create or update TrackedVehicle - if track_id in self.tracked_vehicles[camera_id]: - # Update existing vehicle - vehicle = self.tracked_vehicles[camera_id][track_id] - vehicle.update_position(track_result['bbox'], track_result['confidence']) - vehicle.hit_streak = track_result['hit_streak'] - vehicle.age = track_result['age'] - - # Update stability based on hit_streak - if vehicle.hit_streak >= self.min_stable_frames: - vehicle.is_stable = True - vehicle.stable_frames = vehicle.hit_streak - - logger.debug(f"Updated track {track_id}: conf={vehicle.confidence:.2f}, " - f"stable={vehicle.is_stable}, hit_streak={vehicle.hit_streak}") - else: - # Create new vehicle - x1, y1, x2, y2 = track_result['bbox'] - vehicle = TrackedVehicle( - track_id=track_id, - camera_id=camera_id, - first_seen=current_time, - last_seen=current_time, - display_id=display_id, - confidence=track_result['confidence'], - bbox=tuple(track_result['bbox']), - center=((x1 + x2) / 2, (y1 + y2) / 2), - total_frames=1, - hit_streak=track_result['hit_streak'], - age=track_result['age'] - ) - vehicle.last_position_history.append(vehicle.center) - logger.info(f"New vehicle tracked: ID={track_id}, camera={camera_id}, display={display_id}") - - current_tracks[track_id] = vehicle - active_tracks.append(vehicle) - - # Update the camera's tracked vehicles - self.tracked_vehicles[camera_id] = current_tracks - - return active_tracks - - def get_stable_vehicles(self, display_id: Optional[str] = None) -> List[TrackedVehicle]: - """ - Get all stable vehicles, optionally filtered by display. 
- - Args: - display_id: Optional display ID to filter by - - Returns: - List of stable tracked vehicles - """ - with self.lock: - stable = [] - camera_id = display_id # Using display_id as camera_id - - if camera_id in self.tracked_vehicles: - for vehicle in self.tracked_vehicles[camera_id].values(): - if (vehicle.is_stable and not vehicle.is_expired(self.timeout_seconds) and - (display_id is None or vehicle.display_id == display_id)): - stable.append(vehicle) - - return stable - - def get_vehicle_by_session(self, session_id: str) -> Optional[TrackedVehicle]: - """ - Get a tracked vehicle by its session ID. - - Args: - session_id: Session ID to look up - - Returns: - Tracked vehicle if found, None otherwise - """ - with self.lock: - # Search across all cameras - for camera_vehicles in self.tracked_vehicles.values(): - for vehicle in camera_vehicles.values(): - if vehicle.session_id == session_id: - return vehicle - return None - - def mark_processed(self, track_id: int, session_id: str): - """ - Mark a vehicle as processed through the pipeline. - - Args: - track_id: Track ID of the vehicle - session_id: Session ID assigned to this vehicle - """ - with self.lock: - # Search across all cameras for the track_id - for camera_vehicles in self.tracked_vehicles.values(): - if track_id in camera_vehicles: - vehicle = camera_vehicles[track_id] - vehicle.processed_pipeline = True - vehicle.session_id = session_id - logger.info(f"Marked vehicle {track_id} as processed with session {session_id}") - return - - def clear_session(self, session_id: str): - """ - Clear session ID from a tracked vehicle (post-fueling). - - Args: - session_id: Session ID to clear - """ - with self.lock: - # Search across all cameras - for camera_vehicles in self.tracked_vehicles.values(): - for vehicle in camera_vehicles.values(): - if vehicle.session_id == session_id: - logger.info(f"Clearing session {session_id} from vehicle {vehicle.track_id}") - vehicle.session_id = None - # Keep processed_pipeline=True to prevent re-processing - - def reset_tracking(self): - """Reset all tracking state.""" - with self.lock: - self.tracked_vehicles.clear() - self.bot_sort.reset_all() - logger.info("Vehicle tracking state reset") - - def get_statistics(self) -> Dict: - """Get tracking statistics.""" - with self.lock: - total = 0 - stable = 0 - processed = 0 - all_confidences = [] - - # Aggregate stats across all cameras - for camera_vehicles in self.tracked_vehicles.values(): - total += len(camera_vehicles) - for vehicle in camera_vehicles.values(): - if vehicle.is_stable: - stable += 1 - if vehicle.processed_pipeline: - processed += 1 - all_confidences.append(vehicle.avg_confidence) - - return { - 'total_tracked': total, - 'stable_vehicles': stable, - 'processed_vehicles': processed, - 'avg_confidence': np.mean(all_confidences) if all_confidences else 0.0, - 'bot_sort_stats': self.bot_sort.get_statistics() - } \ No newline at end of file diff --git a/core/tracking/validator.py b/core/tracking/validator.py deleted file mode 100644 index 0c1dca4..0000000 --- a/core/tracking/validator.py +++ /dev/null @@ -1,407 +0,0 @@ -""" -Vehicle Validation Module - Stable car detection and validation logic. -Differentiates between stable (fueling) cars and passing-by vehicles. 
-""" -import logging -import time -import numpy as np -from typing import List, Optional, Tuple, Dict, Any -from dataclasses import dataclass -from enum import Enum - -from .tracker import TrackedVehicle - -logger = logging.getLogger(__name__) - - -class VehicleState(Enum): - """Vehicle state classification.""" - UNKNOWN = "unknown" - ENTERING = "entering" - STABLE = "stable" - LEAVING = "leaving" - PASSING_BY = "passing_by" - - -@dataclass -class ValidationResult: - """Result of vehicle validation.""" - is_valid: bool - state: VehicleState - confidence: float - reason: str - should_process: bool = False - track_id: Optional[int] = None - - -class StableCarValidator: - """ - Validates whether a tracked vehicle should be processed through the pipeline. - - Updated for BoT-SORT integration: Trusts the sophisticated BoT-SORT tracking algorithm - for stability determination and focuses on business logic validation: - - Duration requirements for processing - - Confidence thresholds - - Session management and cooldowns - - Camera isolation with composite keys - """ - - def __init__(self, config: Optional[Dict] = None): - """ - Initialize the validator with configuration. - - Args: - config: Optional configuration dictionary - """ - self.config = config or {} - - # Validation thresholds - # Optimized for 6 FPS RTSP source with 8 concurrent cameras on GPU - # GPU contention reduces effective FPS to ~3-5 per camera - # Reduced from 3.0s to 1.5s to achieve ~2.75s total validation time (was ~4.25s) - self.min_stable_duration = self.config.get('min_stable_duration', 1.5) # seconds - # Reduced from 10 to 5 to align with tracker requirement and reduce validation time - self.min_stable_frames = self.config.get('min_stable_frames', 5) - self.position_variance_threshold = self.config.get('position_variance_threshold', 25.0) # pixels - # Reduced from 0.7 to 0.45 to be more permissive under GPU load - self.min_confidence = self.config.get('min_confidence', 0.45) - self.velocity_threshold = self.config.get('velocity_threshold', 5.0) # pixels/frame - self.entering_zone_ratio = self.config.get('entering_zone_ratio', 0.3) # 30% of frame - self.leaving_zone_ratio = self.config.get('leaving_zone_ratio', 0.3) - - # Frame dimensions (will be updated on first frame) - self.frame_width = 1920 - self.frame_height = 1080 - - # History for validation - self.validation_history: Dict[int, List[VehicleState]] = {} - self.last_processed_vehicles: Dict[int, float] = {} # track_id -> last_process_time - - logger.info(f"StableCarValidator initialized with min_duration={self.min_stable_duration}s, " - f"min_frames={self.min_stable_frames}, position_variance={self.position_variance_threshold}") - - def update_frame_dimensions(self, width: int, height: int): - """Update frame dimensions for zone calculations.""" - self.frame_width = width - self.frame_height = height - # Commented out verbose frame dimension logging - # logger.debug(f"Updated frame dimensions: {width}x{height}") - - def validate_vehicle(self, vehicle: TrackedVehicle, frame_shape: Optional[Tuple] = None) -> ValidationResult: - """ - Validate whether a tracked vehicle is stable and should be processed. 
- - Args: - vehicle: The tracked vehicle to validate - frame_shape: Optional frame shape (height, width, channels) - - Returns: - ValidationResult with validation status and reasoning - """ - # Update frame dimensions if provided - if frame_shape: - self.update_frame_dimensions(frame_shape[1], frame_shape[0]) - - # Initialize validation history for new vehicles - if vehicle.track_id not in self.validation_history: - self.validation_history[vehicle.track_id] = [] - - # Check if already processed - if vehicle.processed_pipeline: - return ValidationResult( - is_valid=False, - state=VehicleState.STABLE, - confidence=1.0, - reason="Already processed through pipeline", - should_process=False, - track_id=vehicle.track_id - ) - - # Check if recently processed (cooldown period) - if vehicle.track_id in self.last_processed_vehicles: - time_since_process = time.time() - self.last_processed_vehicles[vehicle.track_id] - if time_since_process < 10.0: # 10 second cooldown - return ValidationResult( - is_valid=False, - state=VehicleState.STABLE, - confidence=1.0, - reason=f"Recently processed ({time_since_process:.1f}s ago)", - should_process=False, - track_id=vehicle.track_id - ) - - # Determine vehicle state - state = self._determine_vehicle_state(vehicle) - - # Update history - self.validation_history[vehicle.track_id].append(state) - if len(self.validation_history[vehicle.track_id]) > 20: - self.validation_history[vehicle.track_id].pop(0) - - # Validate based on state - if state == VehicleState.STABLE: - return self._validate_stable_vehicle(vehicle) - elif state == VehicleState.PASSING_BY: - return ValidationResult( - is_valid=False, - state=state, - confidence=0.8, - reason="Vehicle is passing by", - should_process=False, - track_id=vehicle.track_id - ) - elif state == VehicleState.ENTERING: - return ValidationResult( - is_valid=False, - state=state, - confidence=0.5, - reason="Vehicle is entering, waiting for stability", - should_process=False, - track_id=vehicle.track_id - ) - elif state == VehicleState.LEAVING: - return ValidationResult( - is_valid=False, - state=state, - confidence=0.5, - reason="Vehicle is leaving", - should_process=False, - track_id=vehicle.track_id - ) - else: - return ValidationResult( - is_valid=False, - state=state, - confidence=0.0, - reason="Unknown vehicle state", - should_process=False, - track_id=vehicle.track_id - ) - - def _determine_vehicle_state(self, vehicle: TrackedVehicle) -> VehicleState: - """ - Determine the current state of the vehicle based on BoT-SORT tracking results. - - BoT-SORT provides sophisticated tracking, so we trust its stability determination - and focus on business logic validation. 
- - Args: - vehicle: The tracked vehicle - - Returns: - Current vehicle state - """ - # Trust BoT-SORT's stability determination - if vehicle.is_stable: - # Check if it's been stable long enough for processing - duration = time.time() - vehicle.first_seen - if duration >= self.min_stable_duration: - return VehicleState.STABLE - else: - return VehicleState.ENTERING - - # For non-stable vehicles, use simplified state determination - if len(vehicle.last_position_history) < 2: - return VehicleState.UNKNOWN - - # Calculate velocity for movement classification - velocity = self._calculate_velocity(vehicle) - - # Basic movement classification - if velocity > self.velocity_threshold: - # Vehicle is moving - classify as passing by or entering/leaving - x_position = vehicle.center[0] / self.frame_width - - # Simple heuristic: vehicles near edges are entering/leaving, center vehicles are passing - if x_position < 0.2 or x_position > 0.8: - return VehicleState.ENTERING - else: - return VehicleState.PASSING_BY - - # Low velocity but not marked stable by tracker - likely entering - return VehicleState.ENTERING - - def _validate_stable_vehicle(self, vehicle: TrackedVehicle) -> ValidationResult: - """ - Perform business logic validation of a stable vehicle. - - Since BoT-SORT already determined the vehicle is stable, we focus on: - - Duration requirements for processing - - Confidence thresholds - - Business logic constraints - - Args: - vehicle: The stable vehicle to validate - - Returns: - Detailed validation result - """ - # Check duration (business requirement) - duration = time.time() - vehicle.first_seen - if duration < self.min_stable_duration: - return ValidationResult( - is_valid=False, - state=VehicleState.STABLE, - confidence=0.6, - reason=f"Not stable long enough ({duration:.1f}s < {self.min_stable_duration}s)", - should_process=False, - track_id=vehicle.track_id - ) - - # Check confidence (business requirement) - if vehicle.avg_confidence < self.min_confidence: - return ValidationResult( - is_valid=False, - state=VehicleState.STABLE, - confidence=vehicle.avg_confidence, - reason=f"Confidence too low ({vehicle.avg_confidence:.2f} < {self.min_confidence})", - should_process=False, - track_id=vehicle.track_id - ) - - # Trust BoT-SORT's stability determination - skip position variance check - # BoT-SORT's sophisticated tracking already ensures consistent positioning - - # Simplified state history check - just ensure recent stability - if vehicle.track_id in self.validation_history: - history = self.validation_history[vehicle.track_id][-3:] # Last 3 states - stable_count = sum(1 for s in history if s == VehicleState.STABLE) - if len(history) >= 2 and stable_count == 0: # Only fail if clear instability - return ValidationResult( - is_valid=False, - state=VehicleState.STABLE, - confidence=0.7, - reason="Recent state history shows instability", - should_process=False, - track_id=vehicle.track_id - ) - - # All checks passed - vehicle is valid for processing - self.last_processed_vehicles[vehicle.track_id] = time.time() - - return ValidationResult( - is_valid=True, - state=VehicleState.STABLE, - confidence=vehicle.avg_confidence, - reason="Vehicle is stable and ready for processing (BoT-SORT validated)", - should_process=True, - track_id=vehicle.track_id - ) - - def _calculate_velocity(self, vehicle: TrackedVehicle) -> float: - """ - Calculate the velocity of the vehicle based on position history. 
- - Args: - vehicle: The tracked vehicle - - Returns: - Velocity in pixels per frame - """ - if len(vehicle.last_position_history) < 2: - return 0.0 - - positions = np.array(vehicle.last_position_history) - if len(positions) < 2: - return 0.0 - - # Calculate velocity over last 3 frames - recent_positions = positions[-min(3, len(positions)):] - velocities = [] - - for i in range(1, len(recent_positions)): - dx = recent_positions[i][0] - recent_positions[i-1][0] - dy = recent_positions[i][1] - recent_positions[i-1][1] - velocity = np.sqrt(dx**2 + dy**2) - velocities.append(velocity) - - return np.mean(velocities) if velocities else 0.0 - - def _calculate_position_variance(self, vehicle: TrackedVehicle) -> float: - """ - Calculate the position variance of the vehicle. - - Args: - vehicle: The tracked vehicle - - Returns: - Position variance in pixels - """ - if len(vehicle.last_position_history) < 2: - return 0.0 - - positions = np.array(vehicle.last_position_history) - variance_x = np.var(positions[:, 0]) - variance_y = np.var(positions[:, 1]) - - return np.sqrt(variance_x + variance_y) - - def should_skip_same_car(self, - vehicle: TrackedVehicle, - session_cleared: bool = False, - permanently_processed: Dict[str, float] = None) -> bool: - """ - Determine if we should skip processing for the same car after session clear. - - Args: - vehicle: The tracked vehicle - session_cleared: Whether the session was recently cleared - permanently_processed: Dict of permanently processed vehicles (camera_id:track_id -> time) - - Returns: - True if we should skip this vehicle - """ - # Check if this vehicle was permanently processed (never process again) - if permanently_processed: - # Create composite key using camera_id and track_id - permanent_key = f"{vehicle.camera_id}:{vehicle.track_id}" - if permanent_key in permanently_processed: - process_time = permanently_processed[permanent_key] - time_since = time.time() - process_time - logger.debug(f"Skipping permanently processed vehicle {vehicle.track_id} on camera {vehicle.camera_id} " - f"(processed {time_since:.1f}s ago)") - return True - - # If vehicle has a session_id but it was cleared, skip for a period - if vehicle.session_id is None and vehicle.processed_pipeline and session_cleared: - # Check if enough time has passed since processing - if vehicle.track_id in self.last_processed_vehicles: - time_since = time.time() - self.last_processed_vehicles[vehicle.track_id] - if time_since < 30.0: # 30 second cooldown after session clear - logger.debug(f"Skipping same car {vehicle.track_id} after session clear " - f"({time_since:.1f}s since processing)") - return True - - return False - - def reset_vehicle(self, track_id: int): - """ - Reset validation state for a specific vehicle. 
- - Args: - track_id: Track ID of the vehicle to reset - """ - if track_id in self.validation_history: - del self.validation_history[track_id] - if track_id in self.last_processed_vehicles: - del self.last_processed_vehicles[track_id] - logger.debug(f"Reset validation state for vehicle {track_id}") - - def get_statistics(self) -> Dict: - """Get validation statistics.""" - return { - 'vehicles_in_history': len(self.validation_history), - 'recently_processed': len(self.last_processed_vehicles), - 'state_distribution': self._get_state_distribution() - } - - def _get_state_distribution(self) -> Dict[str, int]: - """Get distribution of current vehicle states.""" - distribution = {state.value: 0 for state in VehicleState} - - for history in self.validation_history.values(): - if history: - current_state = history[-1] - distribution[current_state.value] += 1 - - return distribution \ No newline at end of file diff --git a/core/utils/ffmpeg_detector.py b/core/utils/ffmpeg_detector.py deleted file mode 100644 index 565713c..0000000 --- a/core/utils/ffmpeg_detector.py +++ /dev/null @@ -1,214 +0,0 @@ -""" -FFmpeg hardware acceleration detection and configuration -""" - -import subprocess -import logging -import re -from typing import Dict, List, Optional - -logger = logging.getLogger("detector_worker") - - -class FFmpegCapabilities: - """Detect and configure FFmpeg hardware acceleration capabilities.""" - - def __init__(self): - """Initialize FFmpeg capabilities detector.""" - self.hwaccels = [] - self.codecs = {} - self.nvidia_support = False - self.vaapi_support = False - self.qsv_support = False - - self._detect_capabilities() - - def _detect_capabilities(self): - """Detect available hardware acceleration methods.""" - try: - # Get hardware accelerators - result = subprocess.run( - ['ffmpeg', '-hide_banner', '-hwaccels'], - capture_output=True, text=True, timeout=10 - ) - if result.returncode == 0: - self.hwaccels = [line.strip() for line in result.stdout.strip().split('\n')[1:] if line.strip()] - logger.info(f"Available FFmpeg hardware accelerators: {', '.join(self.hwaccels)}") - - # Check for NVIDIA support - self.nvidia_support = any(hw in self.hwaccels for hw in ['cuda', 'cuvid', 'nvdec']) - self.vaapi_support = 'vaapi' in self.hwaccels - self.qsv_support = 'qsv' in self.hwaccels - - # Get decoder information - self._detect_decoders() - - # Log capabilities - if self.nvidia_support: - logger.info("NVIDIA hardware acceleration available (CUDA/CUVID/NVDEC)") - logger.info(f"Detected hardware codecs: {self.codecs}") - if self.vaapi_support: - logger.info("VAAPI hardware acceleration available") - if self.qsv_support: - logger.info("Intel QuickSync hardware acceleration available") - - except Exception as e: - logger.warning(f"Failed to detect FFmpeg capabilities: {e}") - - def _detect_decoders(self): - """Detect available hardware decoders.""" - try: - result = subprocess.run( - ['ffmpeg', '-hide_banner', '-decoders'], - capture_output=True, text=True, timeout=10 - ) - if result.returncode == 0: - # Parse decoder output to find hardware decoders - for line in result.stdout.split('\n'): - if 'cuvid' in line or 'nvdec' in line: - match = re.search(r'(\w+)\s+.*?(\w+(?:_cuvid|_nvdec))', line) - if match: - codec_type, decoder = match.groups() - if 'h264' in decoder: - self.codecs['h264_hw'] = decoder - elif 'hevc' in decoder or 'h265' in decoder: - self.codecs['h265_hw'] = decoder - elif 'vaapi' in line: - match = re.search(r'(\w+)\s+.*?(\w+_vaapi)', line) - if match: - codec_type, decoder = 
match.groups() - if 'h264' in decoder: - self.codecs['h264_vaapi'] = decoder - - except Exception as e: - logger.debug(f"Failed to detect decoders: {e}") - - def get_optimal_capture_options(self, codec: str = 'h264') -> Dict[str, str]: - """ - Get optimal FFmpeg capture options for the given codec. - - Args: - codec: Video codec (h264, h265, etc.) - - Returns: - Dictionary of FFmpeg options - """ - options = { - 'rtsp_transport': 'tcp', - 'buffer_size': '1024k', - 'max_delay': '500000', # 500ms - 'fflags': '+genpts', - 'flags': '+low_delay', - 'probesize': '32', - 'analyzeduration': '0' - } - - # Add hardware acceleration if available - if self.nvidia_support: - # Force enable CUDA hardware acceleration for H.264 if CUDA is available - if codec == 'h264': - options.update({ - 'hwaccel': 'cuda', - 'hwaccel_device': '0' - }) - logger.info("Using NVIDIA NVDEC hardware acceleration for H.264") - elif codec == 'h265': - options.update({ - 'hwaccel': 'cuda', - 'hwaccel_device': '0', - 'video_codec': 'hevc_cuvid', - 'hwaccel_output_format': 'cuda' - }) - logger.info("Using NVIDIA CUVID hardware acceleration for H.265") - - elif self.vaapi_support: - if codec == 'h264': - options.update({ - 'hwaccel': 'vaapi', - 'hwaccel_device': '/dev/dri/renderD128', - 'video_codec': 'h264_vaapi' - }) - logger.debug("Using VAAPI hardware acceleration") - - return options - - def format_opencv_options(self, options: Dict[str, str]) -> str: - """ - Format options for OpenCV FFmpeg backend. - - Args: - options: Dictionary of FFmpeg options - - Returns: - Formatted options string for OpenCV - """ - return '|'.join(f"{key};{value}" for key, value in options.items()) - - def get_hardware_encoder_options(self, codec: str = 'h264', quality: str = 'fast') -> Dict[str, str]: - """ - Get optimal hardware encoding options. - - Args: - codec: Video codec for encoding - quality: Quality preset (fast, medium, slow) - - Returns: - Dictionary of encoding options - """ - options = {} - - if self.nvidia_support: - if codec == 'h264': - options.update({ - 'video_codec': 'h264_nvenc', - 'preset': quality, - 'tune': 'zerolatency', - 'gpu': '0', - 'rc': 'cbr_hq', - 'surfaces': '64' - }) - elif codec == 'h265': - options.update({ - 'video_codec': 'hevc_nvenc', - 'preset': quality, - 'tune': 'zerolatency', - 'gpu': '0' - }) - - elif self.vaapi_support: - if codec == 'h264': - options.update({ - 'video_codec': 'h264_vaapi', - 'vaapi_device': '/dev/dri/renderD128' - }) - - return options - - -# Global instance -_ffmpeg_caps = None - -def get_ffmpeg_capabilities() -> FFmpegCapabilities: - """Get or create the global FFmpeg capabilities instance.""" - global _ffmpeg_caps - if _ffmpeg_caps is None: - _ffmpeg_caps = FFmpegCapabilities() - return _ffmpeg_caps - -def get_optimal_rtsp_options(rtsp_url: str) -> str: - """ - Get optimal OpenCV FFmpeg options for RTSP streaming. 
- - Args: - rtsp_url: RTSP stream URL - - Returns: - Formatted options string for cv2.VideoCapture - """ - caps = get_ffmpeg_capabilities() - - # Detect codec from URL or assume H.264 - codec = 'h265' if any(x in rtsp_url.lower() for x in ['h265', 'hevc']) else 'h264' - - options = caps.get_optimal_capture_options(codec) - return caps.format_opencv_options(options) \ No newline at end of file diff --git a/core/utils/hardware_encoder.py b/core/utils/hardware_encoder.py deleted file mode 100644 index 45bbb35..0000000 --- a/core/utils/hardware_encoder.py +++ /dev/null @@ -1,173 +0,0 @@ -""" -Hardware-accelerated image encoding using NVIDIA NVENC or Intel QuickSync -""" - -import cv2 -import numpy as np -import logging -from typing import Optional, Tuple -import os - -logger = logging.getLogger("detector_worker") - - -class HardwareEncoder: - """Hardware-accelerated JPEG encoder using GPU.""" - - def __init__(self): - """Initialize hardware encoder.""" - self.nvenc_available = False - self.vaapi_available = False - self.turbojpeg_available = False - - # Check for TurboJPEG (fastest CPU-based option) - try: - from turbojpeg import TurboJPEG - self.turbojpeg = TurboJPEG() - self.turbojpeg_available = True - logger.info("TurboJPEG accelerated encoding available") - except ImportError: - logger.debug("TurboJPEG not available") - - # Check for NVIDIA NVENC support - try: - # Test if we can create an NVENC encoder - test_frame = np.zeros((720, 1280, 3), dtype=np.uint8) - fourcc = cv2.VideoWriter_fourcc(*'H264') - test_writer = cv2.VideoWriter( - "test.mp4", - fourcc, - 30, - (1280, 720), - [cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY] - ) - if test_writer.isOpened(): - self.nvenc_available = True - logger.info("NVENC hardware encoding available") - test_writer.release() - if os.path.exists("test.mp4"): - os.remove("test.mp4") - except Exception as e: - logger.debug(f"NVENC not available: {e}") - - def encode_jpeg(self, frame: np.ndarray, quality: int = 85) -> Optional[bytes]: - """ - Encode frame to JPEG using the fastest available method. - - Args: - frame: BGR image frame - quality: JPEG quality (1-100) - - Returns: - Encoded JPEG bytes or None on failure - """ - try: - # Method 1: TurboJPEG (3-5x faster than cv2.imencode) - if self.turbojpeg_available: - # Convert BGR to RGB for TurboJPEG - rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - encoded = self.turbojpeg.encode(rgb_frame, quality=quality) - return encoded - - # Method 2: Hardware-accelerated encoding via GStreamer (if available) - if self.nvenc_available: - return self._encode_with_nvenc(frame, quality) - - # Fallback: Standard OpenCV encoding - encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality] - success, encoded = cv2.imencode('.jpg', frame, encode_params) - if success: - return encoded.tobytes() - - return None - - except Exception as e: - logger.error(f"Failed to encode frame: {e}") - return None - - def _encode_with_nvenc(self, frame: np.ndarray, quality: int) -> Optional[bytes]: - """ - Encode using NVIDIA NVENC hardware encoder. - - This is complex to implement directly, so we'll use a GStreamer pipeline - if available. - """ - try: - # Create a GStreamer pipeline for hardware encoding - height, width = frame.shape[:2] - gst_pipeline = ( - f"appsrc ! " - f"video/x-raw,format=BGR,width={width},height={height},framerate=30/1 ! " - f"videoconvert ! " - f"nvvideoconvert ! " # GPU color conversion - f"nvjpegenc quality={quality} ! 
" # Hardware JPEG encoder - f"appsink" - ) - - # This would require GStreamer Python bindings - # For now, fall back to TurboJPEG or standard encoding - logger.debug("NVENC JPEG encoding not fully implemented, using fallback") - encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality] - success, encoded = cv2.imencode('.jpg', frame, encode_params) - if success: - return encoded.tobytes() - - return None - - except Exception as e: - logger.error(f"NVENC encoding failed: {e}") - return None - - def encode_batch(self, frames: list, quality: int = 85) -> list: - """ - Batch encode multiple frames for better GPU utilization. - - Args: - frames: List of BGR frames - quality: JPEG quality - - Returns: - List of encoded JPEG bytes - """ - encoded_frames = [] - - if self.turbojpeg_available: - # TurboJPEG can handle batch encoding efficiently - for frame in frames: - rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - encoded = self.turbojpeg.encode(rgb_frame, quality=quality) - encoded_frames.append(encoded) - else: - # Fallback to sequential encoding - for frame in frames: - encoded = self.encode_jpeg(frame, quality) - encoded_frames.append(encoded) - - return encoded_frames - - -# Global encoder instance -_hardware_encoder = None - - -def get_hardware_encoder() -> HardwareEncoder: - """Get or create the global hardware encoder instance.""" - global _hardware_encoder - if _hardware_encoder is None: - _hardware_encoder = HardwareEncoder() - return _hardware_encoder - - -def encode_frame_hardware(frame: np.ndarray, quality: int = 85) -> Optional[bytes]: - """ - Convenience function to encode a frame using hardware acceleration. - - Args: - frame: BGR image frame - quality: JPEG quality (1-100) - - Returns: - Encoded JPEG bytes or None on failure - """ - encoder = get_hardware_encoder() - return encoder.encode_jpeg(frame, quality) \ No newline at end of file diff --git a/debug/cuda.py b/debug/cuda.py new file mode 100644 index 0000000..44265e1 --- /dev/null +++ b/debug/cuda.py @@ -0,0 +1,4 @@ +import torch +print(torch.cuda.is_available()) # True if CUDA is available +print(torch.cuda.get_device_name(0)) # GPU name +print(torch.version.cuda) # CUDA version PyTorch was compiled with \ No newline at end of file diff --git a/feeder/note.txt b/feeder/note.txt new file mode 100644 index 0000000..d3b0ef0 --- /dev/null +++ b/feeder/note.txt @@ -0,0 +1 @@ +python simple_track.py --source video/sample.mp4 --show-vid --save-vid --enable-json-log \ No newline at end of file diff --git a/feeder/sender/__init__.py b/feeder/sender/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feeder/sender/base.py b/feeder/sender/base.py new file mode 100644 index 0000000..8824dfe --- /dev/null +++ b/feeder/sender/base.py @@ -0,0 +1,21 @@ + +import numpy as np +import json + +class NumpyArrayEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + else: + return super(NumpyArrayEncoder, self).default(obj) + +class BasSender: + def __init__(self) -> None: + pass + + def send(self, messages): + raise NotImplementedError() \ No newline at end of file diff --git a/feeder/sender/jsonlogger.py b/feeder/sender/jsonlogger.py new file mode 100644 index 0000000..63200cf --- /dev/null +++ b/feeder/sender/jsonlogger.py @@ -0,0 +1,13 @@ +from .base import BasSender +from loguru import logger +import json +from .base import NumpyArrayEncoder + +class 
JsonLogger(BasSender): + def __init__(self, log_filename:str = "tracking.log") -> None: + super().__init__() + self.logger = logger + self.logger.add(log_filename, format="{message}", level="INFO") + + def send(self, messages): + self.logger.info(json.dumps(messages, cls=NumpyArrayEncoder)) \ No newline at end of file diff --git a/feeder/sender/szmq.py b/feeder/sender/szmq.py new file mode 100644 index 0000000..059c81a --- /dev/null +++ b/feeder/sender/szmq.py @@ -0,0 +1,14 @@ +from .base import BasSender, NumpyArrayEncoder +import zmq +import json + + +class ZmqLogger(BasSender): + def __init__(self, ip_addr:str = "localhost", port:int = 5555) -> None: + super().__init__() + self.context = zmq.Context() + self.producer = self.context.socket(zmq.PUB) + self.producer.connect(f"tcp://{ip_addr}:{port}") + + def send(self, messages): + self.producer.send_string(json.dumps(messages, cls = NumpyArrayEncoder)) \ No newline at end of file diff --git a/feeder/simple_track.py b/feeder/simple_track.py new file mode 100644 index 0000000..a8bf61c --- /dev/null +++ b/feeder/simple_track.py @@ -0,0 +1,245 @@ +import argparse +import cv2 +import os +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["OPENBLAS_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["VECLIB_MAXIMUM_THREADS"] = "1" +os.environ["NUMEXPR_NUM_THREADS"] = "1" + +import sys +import numpy as np +from pathlib import Path +import torch + +FILE = Path(__file__).resolve() +ROOT = FILE.parents[0] +WEIGHTS = ROOT / 'weights' + +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) +if str(ROOT / 'trackers' / 'strongsort') not in sys.path: + sys.path.append(str(ROOT / 'trackers' / 'strongsort')) + +from ultralytics.nn.autobackend import AutoBackend +from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages +from ultralytics.yolo.data.utils import VID_FORMATS +from ultralytics.yolo.utils import LOGGER, colorstr +from ultralytics.yolo.utils.checks import check_file, check_imgsz +from ultralytics.yolo.utils.files import increment_path +from ultralytics.yolo.utils.torch_utils import select_device +from ultralytics.yolo.utils.ops import Profile, non_max_suppression, scale_boxes +from ultralytics.yolo.utils.plotting import Annotator, colors + +from trackers.multi_tracker_zoo import create_tracker +from sender.jsonlogger import JsonLogger +from sender.szmq import ZmqLogger + +@torch.no_grad() +def run( + source='0', + yolo_weights=WEIGHTS / 'yolov8n.pt', + reid_weights=WEIGHTS / 'osnet_x0_25_msmt17.pt', + imgsz=(640, 640), + conf_thres=0.7, + iou_thres=0.45, + max_det=1000, + device='', + show_vid=True, + save_vid=True, + project=ROOT / 'runs' / 'track', + name='exp', + exist_ok=False, + line_thickness=2, + hide_labels=False, + hide_conf=False, + half=False, + vid_stride=1, + enable_json_log=False, + enable_zmq=False, + zmq_ip='localhost', + zmq_port=5555, +): + source = str(source) + is_file = Path(source).suffix[1:] in (VID_FORMATS) + + if is_file: + source = check_file(source) + + device = select_device(device) + + model = AutoBackend(yolo_weights, device=device, dnn=False, fp16=half) + stride, names, pt = model.stride, model.names, model.pt + imgsz = check_imgsz(imgsz, stride=stride) + + dataset = LoadImages( + source, + imgsz=imgsz, + stride=stride, + auto=pt, + transforms=getattr(model.model, 'transforms', None), + vid_stride=vid_stride + ) + bs = len(dataset) + + tracking_config = ROOT / 'trackers' / 'strongsort' / 'configs' / 'strongsort.yaml' + tracker = create_tracker('strongsort', tracking_config, 
reid_weights, device, half) + + save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) + (save_dir / 'tracks').mkdir(parents=True, exist_ok=True) + + # Initialize loggers + json_logger = JsonLogger(f"{source}-strongsort.log") if enable_json_log else None + zmq_logger = ZmqLogger(zmq_ip, zmq_port) if enable_zmq else None + + vid_path, vid_writer = [None] * bs, [None] * bs + dt = (Profile(), Profile(), Profile()) + + for frame_idx, (path, im, im0s, vid_cap, s) in enumerate(dataset): + + with dt[0]: + im = torch.from_numpy(im).to(model.device) + im = im.half() if model.fp16 else im.float() + im /= 255.0 + if len(im.shape) == 3: + im = im[None] + + with dt[1]: + pred = model(im, augment=False, visualize=False) + + with dt[2]: + pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=max_det) + + for i, det in enumerate(pred): + seen = 0 + p, im0, _ = path, im0s.copy(), dataset.count + p = Path(p) + + annotator = Annotator(im0, line_width=line_thickness, example=str(names)) + + if len(det): + # Filter detections for 'car' class only (class 2 in COCO dataset) + car_mask = det[:, 5] == 2 # car class index is 2 + det = det[car_mask] + + if len(det): + det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() + + for *xyxy, conf, cls in reversed(det): + c = int(cls) + id = f'{c}' + label = None if hide_labels else (f'{id} {names[c]}' if hide_conf else f'{id} {names[c]} {conf:.2f}') + annotator.box_label(xyxy, label, color=colors(c, True)) + + t_outputs = tracker.update(det.cpu(), im0) + + if len(t_outputs) > 0: + for j, (output) in enumerate(t_outputs): + bbox = output[0:4] + id = output[4] + cls = output[5] + conf = output[6] + + # Log tracking data + if json_logger or zmq_logger: + track_data = { + 'bbox': bbox.tolist() if hasattr(bbox, 'tolist') else list(bbox), + 'id': int(id), + 'cls': int(cls), + 'conf': float(conf), + 'frame_idx': frame_idx, + 'source': source, + 'class_name': names[int(cls)] + } + + if json_logger: + json_logger.send(track_data) + if zmq_logger: + zmq_logger.send(track_data) + + if save_vid or show_vid: + c = int(cls) + id = int(id) + label = f'{id} {names[c]}' if not hide_labels else f'{id}' + if not hide_conf: + label += f' {conf:.2f}' + annotator.box_label(bbox, label, color=colors(c, True)) + + im0 = annotator.result() + + if show_vid: + cv2.imshow(str(p), im0) + if cv2.waitKey(1) == ord('q'): + break + + if save_vid: + if vid_path[i] != str(save_dir / p.name): + vid_path[i] = str(save_dir / p.name) + if isinstance(vid_writer[i], cv2.VideoWriter): + vid_writer[i].release() + + if vid_cap: + fps = vid_cap.get(cv2.CAP_PROP_FPS) + w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + else: + fps, w, h = 30, im0.shape[1], im0.shape[0] + + vid_writer[i] = cv2.VideoWriter(vid_path[i], cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + + vid_writer[i].write(im0) + + LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms") + + for i, vid_writer_obj in enumerate(vid_writer): + if isinstance(vid_writer_obj, cv2.VideoWriter): + vid_writer_obj.release() + + cv2.destroyAllWindows() + + LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") + +def xyxy2xywh(x): + # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center + y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center + y[:, 2] = x[:, 2] - x[:, 0] # width + y[:, 3] = x[:, 
3] - x[:, 1] # height + return y + +def parse_opt(): + parser = argparse.ArgumentParser() + parser.add_argument('--source', type=str, default='0', help='file/dir/URL/glob, 0 for webcam') + parser.add_argument('--yolo-weights', nargs='+', type=str, default=WEIGHTS / 'yolov8n.pt', help='model path') + parser.add_argument('--reid-weights', type=str, default=WEIGHTS / 'osnet_x0_25_msmt17.pt') + parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w') + parser.add_argument('--conf-thres', type=float, default=0.7, help='confidence threshold') + parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold') + parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image') + parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--show-vid', action='store_true', help='display results') + parser.add_argument('--save-vid', action='store_true', help='save video tracking results') + parser.add_argument('--project', default=ROOT / 'runs' / 'track', help='save results to project/name') + parser.add_argument('--name', default='exp', help='save results to project/name') + parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') + parser.add_argument('--line-thickness', default=2, type=int, help='bounding box thickness (pixels)') + parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels') + parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences') + parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') + parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride') + parser.add_argument('--enable-json-log', action='store_true', help='enable JSON file logging') + parser.add_argument('--enable-zmq', action='store_true', help='enable ZMQ messaging') + parser.add_argument('--zmq-ip', type=str, default='localhost', help='ZMQ server IP') + parser.add_argument('--zmq-port', type=int, default=5555, help='ZMQ server port') + + opt = parser.parse_args() + opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 + return opt + +def main(opt): + run(**vars(opt)) + +if __name__ == "__main__": + opt = parse_opt() + main(opt) \ No newline at end of file diff --git a/feeder/trackers/__init__.py b/feeder/trackers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feeder/trackers/botsort/basetrack.py b/feeder/trackers/botsort/basetrack.py new file mode 100644 index 0000000..c8d4c15 --- /dev/null +++ b/feeder/trackers/botsort/basetrack.py @@ -0,0 +1,60 @@ +import numpy as np +from collections import OrderedDict + + +class TrackState(object): + New = 0 + Tracked = 1 + Lost = 2 + LongLost = 3 + Removed = 4 + + +class BaseTrack(object): + _count = 0 + + track_id = 0 + is_activated = False + state = TrackState.New + + history = OrderedDict() + features = [] + curr_feature = None + score = 0 + start_frame = 0 + frame_id = 0 + time_since_update = 0 + + # multi-camera + location = (np.inf, np.inf) + + @property + def end_frame(self): + return self.frame_id + + @staticmethod + def next_id(): + BaseTrack._count += 1 + return BaseTrack._count + + def activate(self, *args): + raise NotImplementedError + + def predict(self): + raise NotImplementedError + + def update(self, *args, **kwargs): + raise NotImplementedError + + def mark_lost(self): + self.state = TrackState.Lost + + 
def mark_long_lost(self): + self.state = TrackState.LongLost + + def mark_removed(self): + self.state = TrackState.Removed + + @staticmethod + def clear_count(): + BaseTrack._count = 0 diff --git a/feeder/trackers/botsort/bot_sort.py b/feeder/trackers/botsort/bot_sort.py new file mode 100644 index 0000000..1144c17 --- /dev/null +++ b/feeder/trackers/botsort/bot_sort.py @@ -0,0 +1,534 @@ +import cv2 +import matplotlib.pyplot as plt +import numpy as np +from collections import deque + +from trackers.botsort import matching +from trackers.botsort.gmc import GMC +from trackers.botsort.basetrack import BaseTrack, TrackState +from trackers.botsort.kalman_filter import KalmanFilter + +# from fast_reid.fast_reid_interfece import FastReIDInterface + +from reid_multibackend import ReIDDetectMultiBackend +from ultralytics.yolo.utils.ops import xyxy2xywh, xywh2xyxy + + +class STrack(BaseTrack): + shared_kalman = KalmanFilter() + + def __init__(self, tlwh, score, cls, feat=None, feat_history=50): + + # wait activate + self._tlwh = np.asarray(tlwh, dtype=np.float32) + self.kalman_filter = None + self.mean, self.covariance = None, None + self.is_activated = False + + self.cls = -1 + self.cls_hist = [] # (cls id, freq) + self.update_cls(cls, score) + + self.score = score + self.tracklet_len = 0 + + self.smooth_feat = None + self.curr_feat = None + if feat is not None: + self.update_features(feat) + self.features = deque([], maxlen=feat_history) + self.alpha = 0.9 + + def update_features(self, feat): + feat /= np.linalg.norm(feat) + self.curr_feat = feat + if self.smooth_feat is None: + self.smooth_feat = feat + else: + self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat + self.features.append(feat) + self.smooth_feat /= np.linalg.norm(self.smooth_feat) + + def update_cls(self, cls, score): + if len(self.cls_hist) > 0: + max_freq = 0 + found = False + for c in self.cls_hist: + if cls == c[0]: + c[1] += score + found = True + + if c[1] > max_freq: + max_freq = c[1] + self.cls = c[0] + if not found: + self.cls_hist.append([cls, score]) + self.cls = cls + else: + self.cls_hist.append([cls, score]) + self.cls = cls + + def predict(self): + mean_state = self.mean.copy() + if self.state != TrackState.Tracked: + mean_state[6] = 0 + mean_state[7] = 0 + + self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) + + @staticmethod + def multi_predict(stracks): + if len(stracks) > 0: + multi_mean = np.asarray([st.mean.copy() for st in stracks]) + multi_covariance = np.asarray([st.covariance for st in stracks]) + for i, st in enumerate(stracks): + if st.state != TrackState.Tracked: + multi_mean[i][6] = 0 + multi_mean[i][7] = 0 + multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance) + for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): + stracks[i].mean = mean + stracks[i].covariance = cov + + @staticmethod + def multi_gmc(stracks, H=np.eye(2, 3)): + if len(stracks) > 0: + multi_mean = np.asarray([st.mean.copy() for st in stracks]) + multi_covariance = np.asarray([st.covariance for st in stracks]) + + R = H[:2, :2] + R8x8 = np.kron(np.eye(4, dtype=float), R) + t = H[:2, 2] + + for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): + mean = R8x8.dot(mean) + mean[:2] += t + cov = R8x8.dot(cov).dot(R8x8.transpose()) + + stracks[i].mean = mean + stracks[i].covariance = cov + + def activate(self, kalman_filter, frame_id): + """Start a new tracklet""" + self.kalman_filter = kalman_filter + self.track_id = 
self.next_id() + + self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xywh(self._tlwh)) + + self.tracklet_len = 0 + self.state = TrackState.Tracked + if frame_id == 1: + self.is_activated = True + self.frame_id = frame_id + self.start_frame = frame_id + + def re_activate(self, new_track, frame_id, new_id=False): + + self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, self.tlwh_to_xywh(new_track.tlwh)) + if new_track.curr_feat is not None: + self.update_features(new_track.curr_feat) + self.tracklet_len = 0 + self.state = TrackState.Tracked + self.is_activated = True + self.frame_id = frame_id + if new_id: + self.track_id = self.next_id() + self.score = new_track.score + + self.update_cls(new_track.cls, new_track.score) + + def update(self, new_track, frame_id): + """ + Update a matched track + :type new_track: STrack + :type frame_id: int + :type update_feature: bool + :return: + """ + self.frame_id = frame_id + self.tracklet_len += 1 + + new_tlwh = new_track.tlwh + + self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, self.tlwh_to_xywh(new_tlwh)) + + if new_track.curr_feat is not None: + self.update_features(new_track.curr_feat) + + self.state = TrackState.Tracked + self.is_activated = True + + self.score = new_track.score + self.update_cls(new_track.cls, new_track.score) + + @property + def tlwh(self): + """Get current position in bounding box format `(top left x, top left y, + width, height)`. + """ + if self.mean is None: + return self._tlwh.copy() + ret = self.mean[:4].copy() + ret[:2] -= ret[2:] / 2 + return ret + + @property + def tlbr(self): + """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., + `(top left, bottom right)`. + """ + ret = self.tlwh.copy() + ret[2:] += ret[:2] + return ret + + @property + def xywh(self): + """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., + `(top left, bottom right)`. + """ + ret = self.tlwh.copy() + ret[:2] += ret[2:] / 2.0 + return ret + + @staticmethod + def tlwh_to_xyah(tlwh): + """Convert bounding box to format `(center x, center y, aspect ratio, + height)`, where the aspect ratio is `width / height`. + """ + ret = np.asarray(tlwh).copy() + ret[:2] += ret[2:] / 2 + ret[2] /= ret[3] + return ret + + @staticmethod + def tlwh_to_xywh(tlwh): + """Convert bounding box to format `(center x, center y, width, + height)`. 
+ """ + ret = np.asarray(tlwh).copy() + ret[:2] += ret[2:] / 2 + return ret + + def to_xywh(self): + return self.tlwh_to_xywh(self.tlwh) + + @staticmethod + def tlbr_to_tlwh(tlbr): + ret = np.asarray(tlbr).copy() + ret[2:] -= ret[:2] + return ret + + @staticmethod + def tlwh_to_tlbr(tlwh): + ret = np.asarray(tlwh).copy() + ret[2:] += ret[:2] + return ret + + def __repr__(self): + return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame) + + +class BoTSORT(object): + def __init__(self, + model_weights, + device, + fp16, + track_high_thresh:float = 0.45, + new_track_thresh:float = 0.6, + track_buffer:int = 30, + match_thresh:float = 0.8, + proximity_thresh:float = 0.5, + appearance_thresh:float = 0.25, + cmc_method:str = 'sparseOptFlow', + frame_rate=30, + lambda_=0.985 + ): + + self.tracked_stracks = [] # type: list[STrack] + self.lost_stracks = [] # type: list[STrack] + self.removed_stracks = [] # type: list[STrack] + BaseTrack.clear_count() + + self.frame_id = 0 + + self.lambda_ = lambda_ + self.track_high_thresh = track_high_thresh + self.new_track_thresh = new_track_thresh + + self.buffer_size = int(frame_rate / 30.0 * track_buffer) + self.max_time_lost = self.buffer_size + self.kalman_filter = KalmanFilter() + + # ReID module + self.proximity_thresh = proximity_thresh + self.appearance_thresh = appearance_thresh + self.match_thresh = match_thresh + + self.model = ReIDDetectMultiBackend(weights=model_weights, device=device, fp16=fp16) + + self.gmc = GMC(method=cmc_method, verbose=[None,False]) + + def update(self, output_results, img): + self.frame_id += 1 + activated_starcks = [] + refind_stracks = [] + lost_stracks = [] + removed_stracks = [] + + xyxys = output_results[:, 0:4] + xywh = xyxy2xywh(xyxys.numpy()) + confs = output_results[:, 4] + clss = output_results[:, 5] + + classes = clss.numpy() + xyxys = xyxys.numpy() + confs = confs.numpy() + + remain_inds = confs > self.track_high_thresh + inds_low = confs > 0.1 + inds_high = confs < self.track_high_thresh + + inds_second = np.logical_and(inds_low, inds_high) + + dets_second = xywh[inds_second] + dets = xywh[remain_inds] + + scores_keep = confs[remain_inds] + scores_second = confs[inds_second] + + classes_keep = classes[remain_inds] + clss_second = classes[inds_second] + + self.height, self.width = img.shape[:2] + + '''Extract embeddings ''' + features_keep = self._get_features(dets, img) + + if len(dets) > 0: + '''Detections''' + + detections = [STrack(xyxy, s, c, f.cpu().numpy()) for + (xyxy, s, c, f) in zip(dets, scores_keep, classes_keep, features_keep)] + else: + detections = [] + + ''' Add newly detected tracklets to tracked_stracks''' + unconfirmed = [] + tracked_stracks = [] # type: list[STrack] + for track in self.tracked_stracks: + if not track.is_activated: + unconfirmed.append(track) + else: + tracked_stracks.append(track) + + ''' Step 2: First association, with high score detection boxes''' + strack_pool = joint_stracks(tracked_stracks, self.lost_stracks) + + # Predict the current location with KF + STrack.multi_predict(strack_pool) + + # Fix camera motion + warp = self.gmc.apply(img, dets) + STrack.multi_gmc(strack_pool, warp) + STrack.multi_gmc(unconfirmed, warp) + + # Associate with high score detection boxes + raw_emb_dists = matching.embedding_distance(strack_pool, detections) + dists = matching.fuse_motion(self.kalman_filter, raw_emb_dists, strack_pool, detections, only_position=False, lambda_=self.lambda_) + + # ious_dists = matching.iou_distance(strack_pool, detections) + # ious_dists_mask 
= (ious_dists > self.proximity_thresh) + + # ious_dists = matching.fuse_score(ious_dists, detections) + + # emb_dists = matching.embedding_distance(strack_pool, detections) / 2.0 + # raw_emb_dists = emb_dists.copy() + # emb_dists[emb_dists > self.appearance_thresh] = 1.0 + # emb_dists[ious_dists_mask] = 1.0 + # dists = np.minimum(ious_dists, emb_dists) + + # Popular ReID method (JDE / FairMOT) + # raw_emb_dists = matching.embedding_distance(strack_pool, detections) + # dists = matching.fuse_motion(self.kalman_filter, raw_emb_dists, strack_pool, detections) + # emb_dists = dists + + # IoU making ReID + # dists = matching.embedding_distance(strack_pool, detections) + # dists[ious_dists_mask] = 1.0 + + matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.match_thresh) + + for itracked, idet in matches: + track = strack_pool[itracked] + det = detections[idet] + if track.state == TrackState.Tracked: + track.update(detections[idet], self.frame_id) + activated_starcks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + + ''' Step 3: Second association, with low score detection boxes''' + # if len(scores): + # inds_high = scores < self.track_high_thresh + # inds_low = scores > self.track_low_thresh + # inds_second = np.logical_and(inds_low, inds_high) + # dets_second = bboxes[inds_second] + # scores_second = scores[inds_second] + # classes_second = classes[inds_second] + # else: + # dets_second = [] + # scores_second = [] + # classes_second = [] + + # association the untrack to the low score detections + if len(dets_second) > 0: + '''Detections''' + detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s, c) for + (tlbr, s, c) in zip(dets_second, scores_second, clss_second)] + else: + detections_second = [] + + r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked] + dists = matching.iou_distance(r_tracked_stracks, detections_second) + matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5) + for itracked, idet in matches: + track = r_tracked_stracks[itracked] + det = detections_second[idet] + if track.state == TrackState.Tracked: + track.update(det, self.frame_id) + activated_starcks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + + for it in u_track: + track = r_tracked_stracks[it] + if not track.state == TrackState.Lost: + track.mark_lost() + lost_stracks.append(track) + + '''Deal with unconfirmed tracks, usually tracks with only one beginning frame''' + detections = [detections[i] for i in u_detection] + ious_dists = matching.iou_distance(unconfirmed, detections) + ious_dists_mask = (ious_dists > self.proximity_thresh) + + ious_dists = matching.fuse_score(ious_dists, detections) + + emb_dists = matching.embedding_distance(unconfirmed, detections) / 2.0 + raw_emb_dists = emb_dists.copy() + emb_dists[emb_dists > self.appearance_thresh] = 1.0 + emb_dists[ious_dists_mask] = 1.0 + dists = np.minimum(ious_dists, emb_dists) + + matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7) + for itracked, idet in matches: + unconfirmed[itracked].update(detections[idet], self.frame_id) + activated_starcks.append(unconfirmed[itracked]) + for it in u_unconfirmed: + track = unconfirmed[it] + track.mark_removed() + removed_stracks.append(track) + + """ Step 4: Init new stracks""" + for inew in u_detection: + track = detections[inew] + if track.score < 
self.new_track_thresh: + continue + + track.activate(self.kalman_filter, self.frame_id) + activated_starcks.append(track) + + """ Step 5: Update state""" + for track in self.lost_stracks: + if self.frame_id - track.end_frame > self.max_time_lost: + track.mark_removed() + removed_stracks.append(track) + + """ Merge """ + self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked] + self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks) + self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks) + self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks) + self.lost_stracks.extend(lost_stracks) + self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks) + self.removed_stracks.extend(removed_stracks) + self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks) + + # output_stracks = [track for track in self.tracked_stracks if track.is_activated] + output_stracks = [track for track in self.tracked_stracks if track.is_activated] + outputs = [] + for t in output_stracks: + output= [] + tlwh = t.tlwh + tid = t.track_id + tlwh = np.expand_dims(tlwh, axis=0) + xyxy = xywh2xyxy(tlwh) + xyxy = np.squeeze(xyxy, axis=0) + output.extend(xyxy) + output.append(tid) + output.append(t.cls) + output.append(t.score) + outputs.append(output) + + return outputs + + def _xywh_to_xyxy(self, bbox_xywh): + x, y, w, h = bbox_xywh + x1 = max(int(x - w / 2), 0) + x2 = min(int(x + w / 2), self.width - 1) + y1 = max(int(y - h / 2), 0) + y2 = min(int(y + h / 2), self.height - 1) + return x1, y1, x2, y2 + + def _get_features(self, bbox_xywh, ori_img): + im_crops = [] + for box in bbox_xywh: + x1, y1, x2, y2 = self._xywh_to_xyxy(box) + im = ori_img[y1:y2, x1:x2] + im_crops.append(im) + if im_crops: + features = self.model(im_crops) + else: + features = np.array([]) + return features + +def joint_stracks(tlista, tlistb): + exists = {} + res = [] + for t in tlista: + exists[t.track_id] = 1 + res.append(t) + for t in tlistb: + tid = t.track_id + if not exists.get(tid, 0): + exists[tid] = 1 + res.append(t) + return res + + +def sub_stracks(tlista, tlistb): + stracks = {} + for t in tlista: + stracks[t.track_id] = t + for t in tlistb: + tid = t.track_id + if stracks.get(tid, 0): + del stracks[tid] + return list(stracks.values()) + + +def remove_duplicate_stracks(stracksa, stracksb): + pdist = matching.iou_distance(stracksa, stracksb) + pairs = np.where(pdist < 0.15) + dupa, dupb = list(), list() + for p, q in zip(*pairs): + timep = stracksa[p].frame_id - stracksa[p].start_frame + timeq = stracksb[q].frame_id - stracksb[q].start_frame + if timep > timeq: + dupb.append(q) + else: + dupa.append(p) + resa = [t for i, t in enumerate(stracksa) if not i in dupa] + resb = [t for i, t in enumerate(stracksb) if not i in dupb] + return resa, resb diff --git a/feeder/trackers/botsort/configs/botsort.yaml b/feeder/trackers/botsort/configs/botsort.yaml new file mode 100644 index 0000000..e5afb91 --- /dev/null +++ b/feeder/trackers/botsort/configs/botsort.yaml @@ -0,0 +1,13 @@ +# Trial number: 232 +# HOTA, MOTA, IDF1: [45.31] +botsort: + appearance_thresh: 0.4818211117541298 + cmc_method: sparseOptFlow + conf_thres: 0.3501265956918775 + frame_rate: 30 + lambda_: 0.9896143462366406 + match_thresh: 0.22734550911325851 + new_track_thresh: 0.21144301345190655 + proximity_thresh: 0.5945380911899254 + track_buffer: 60 + track_high_thresh: 0.33824964456239337 diff --git a/feeder/trackers/botsort/gmc.py 
b/feeder/trackers/botsort/gmc.py new file mode 100644 index 0000000..e7ec207 --- /dev/null +++ b/feeder/trackers/botsort/gmc.py @@ -0,0 +1,316 @@ +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import copy +import time + + +class GMC: + def __init__(self, method='sparseOptFlow', downscale=2, verbose=None): + super(GMC, self).__init__() + + self.method = method + self.downscale = max(1, int(downscale)) + + if self.method == 'orb': + self.detector = cv2.FastFeatureDetector_create(20) + self.extractor = cv2.ORB_create() + self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING) + + elif self.method == 'sift': + self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20) + self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20) + self.matcher = cv2.BFMatcher(cv2.NORM_L2) + + elif self.method == 'ecc': + number_of_iterations = 5000 + termination_eps = 1e-6 + self.warp_mode = cv2.MOTION_EUCLIDEAN + self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps) + + elif self.method == 'sparseOptFlow': + self.feature_params = dict(maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, + useHarrisDetector=False, k=0.04) + # self.gmc_file = open('GMC_results.txt', 'w') + + elif self.method == 'file' or self.method == 'files': + seqName = verbose[0] + ablation = verbose[1] + if ablation: + filePath = r'tracker/GMC_files/MOT17_ablation' + else: + filePath = r'tracker/GMC_files/MOTChallenge' + + if '-FRCNN' in seqName: + seqName = seqName[:-6] + elif '-DPM' in seqName: + seqName = seqName[:-4] + elif '-SDP' in seqName: + seqName = seqName[:-4] + + self.gmcFile = open(filePath + "/GMC-" + seqName + ".txt", 'r') + + if self.gmcFile is None: + raise ValueError("Error: Unable to open GMC file in directory:" + filePath) + elif self.method == 'none' or self.method == 'None': + self.method = 'none' + else: + raise ValueError("Error: Unknown CMC method:" + method) + + self.prevFrame = None + self.prevKeyPoints = None + self.prevDescriptors = None + + self.initializedFirstFrame = False + + def apply(self, raw_frame, detections=None): + if self.method == 'orb' or self.method == 'sift': + return self.applyFeaures(raw_frame, detections) + elif self.method == 'ecc': + return self.applyEcc(raw_frame, detections) + elif self.method == 'sparseOptFlow': + return self.applySparseOptFlow(raw_frame, detections) + elif self.method == 'file': + return self.applyFile(raw_frame, detections) + elif self.method == 'none': + return np.eye(2, 3) + else: + return np.eye(2, 3) + + def applyEcc(self, raw_frame, detections=None): + + # Initialize + height, width, _ = raw_frame.shape + frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) + H = np.eye(2, 3, dtype=np.float32) + + # Downscale image (TODO: consider using pyramids) + if self.downscale > 1.0: + frame = cv2.GaussianBlur(frame, (3, 3), 1.5) + frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) + width = width // self.downscale + height = height // self.downscale + + # Handle first frame + if not self.initializedFirstFrame: + # Initialize data + self.prevFrame = frame.copy() + + # Initialization done + self.initializedFirstFrame = True + + return H + + # Run the ECC algorithm. The results are stored in warp_matrix. 
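+        # Note: cv2.findTransformECC iteratively estimates the 2x3 Euclidean warp H
+        # that aligns self.prevFrame with the current (grayscale, downscaled) frame,
+        # using the iteration-count / epsilon criteria configured in __init__.
+        # The warp returned here is later applied to the predicted track boxes
+        # (STrack.multi_gmc) to compensate for camera motion between frames.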
+ # (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria) + try: + (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1) + except: + print('Warning: find transform failed. Set warp as identity') + + return H + + def applyFeaures(self, raw_frame, detections=None): + + # Initialize + height, width, _ = raw_frame.shape + frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) + H = np.eye(2, 3) + + # Downscale image (TODO: consider using pyramids) + if self.downscale > 1.0: + # frame = cv2.GaussianBlur(frame, (3, 3), 1.5) + frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) + width = width // self.downscale + height = height // self.downscale + + # find the keypoints + mask = np.zeros_like(frame) + # mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255 + mask[int(0.02 * height): int(0.98 * height), int(0.02 * width): int(0.98 * width)] = 255 + if detections is not None: + for det in detections: + tlbr = (det[:4] / self.downscale).astype(np.int_) + mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0 + + keypoints = self.detector.detect(frame, mask) + + # compute the descriptors + keypoints, descriptors = self.extractor.compute(frame, keypoints) + + # Handle first frame + if not self.initializedFirstFrame: + # Initialize data + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + self.prevDescriptors = copy.copy(descriptors) + + # Initialization done + self.initializedFirstFrame = True + + return H + + # Match descriptors. + knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2) + + # Filtered matches based on smallest spatial distance + matches = [] + spatialDistances = [] + + maxSpatialDistance = 0.25 * np.array([width, height]) + + # Handle empty matches case + if len(knnMatches) == 0: + # Store to next iteration + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + self.prevDescriptors = copy.copy(descriptors) + + return H + + for m, n in knnMatches: + if m.distance < 0.9 * n.distance: + prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt + currKeyPointLocation = keypoints[m.trainIdx].pt + + spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0], + prevKeyPointLocation[1] - currKeyPointLocation[1]) + + if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \ + (np.abs(spatialDistance[1]) < maxSpatialDistance[1]): + spatialDistances.append(spatialDistance) + matches.append(m) + + meanSpatialDistances = np.mean(spatialDistances, 0) + stdSpatialDistances = np.std(spatialDistances, 0) + + inliesrs = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances + + goodMatches = [] + prevPoints = [] + currPoints = [] + for i in range(len(matches)): + if inliesrs[i, 0] and inliesrs[i, 1]: + goodMatches.append(matches[i]) + prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt) + currPoints.append(keypoints[matches[i].trainIdx].pt) + + prevPoints = np.array(prevPoints) + currPoints = np.array(currPoints) + + # Draw the keypoint matches on the output image + if 0: + matches_img = np.hstack((self.prevFrame, frame)) + matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR) + W = np.size(self.prevFrame, 1) + for m in goodMatches: + prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_) + curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_) + curr_pt[0] += W + color = np.random.randint(0, 255, (3,)) + color = (int(color[0]), int(color[1]), 
int(color[2])) + + matches_img = cv2.line(matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA) + matches_img = cv2.circle(matches_img, prev_pt, 2, tuple(color), -1) + matches_img = cv2.circle(matches_img, curr_pt, 2, tuple(color), -1) + + plt.figure() + plt.imshow(matches_img) + plt.show() + + # Find rigid matrix + if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)): + H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC) + + # Handle downscale + if self.downscale > 1.0: + H[0, 2] *= self.downscale + H[1, 2] *= self.downscale + else: + print('Warning: not enough matching points') + + # Store to next iteration + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + self.prevDescriptors = copy.copy(descriptors) + + return H + + def applySparseOptFlow(self, raw_frame, detections=None): + + t0 = time.time() + + # Initialize + height, width, _ = raw_frame.shape + frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) + H = np.eye(2, 3) + + # Downscale image + if self.downscale > 1.0: + # frame = cv2.GaussianBlur(frame, (3, 3), 1.5) + frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) + + # find the keypoints + keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params) + + # Handle first frame + if not self.initializedFirstFrame: + # Initialize data + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + + # Initialization done + self.initializedFirstFrame = True + + return H + + # find correspondences + matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None) + + # leave good correspondences only + prevPoints = [] + currPoints = [] + + for i in range(len(status)): + if status[i]: + prevPoints.append(self.prevKeyPoints[i]) + currPoints.append(matchedKeypoints[i]) + + prevPoints = np.array(prevPoints) + currPoints = np.array(currPoints) + + # Find rigid matrix + if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)): + H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC) + + # Handle downscale + if self.downscale > 1.0: + H[0, 2] *= self.downscale + H[1, 2] *= self.downscale + else: + print('Warning: not enough matching points') + + # Store to next iteration + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + + t1 = time.time() + + # gmc_line = str(1000 * (t1 - t0)) + "\t" + str(H[0, 0]) + "\t" + str(H[0, 1]) + "\t" + str( + # H[0, 2]) + "\t" + str(H[1, 0]) + "\t" + str(H[1, 1]) + "\t" + str(H[1, 2]) + "\n" + # self.gmc_file.write(gmc_line) + + return H + + def applyFile(self, raw_frame, detections=None): + line = self.gmcFile.readline() + tokens = line.split("\t") + H = np.eye(2, 3, dtype=np.float_) + H[0, 0] = float(tokens[1]) + H[0, 1] = float(tokens[2]) + H[0, 2] = float(tokens[3]) + H[1, 0] = float(tokens[4]) + H[1, 1] = float(tokens[5]) + H[1, 2] = float(tokens[6]) + + return H \ No newline at end of file diff --git a/feeder/trackers/botsort/kalman_filter.py b/feeder/trackers/botsort/kalman_filter.py new file mode 100644 index 0000000..02a6eb4 --- /dev/null +++ b/feeder/trackers/botsort/kalman_filter.py @@ -0,0 +1,269 @@ +# vim: expandtab:ts=4:sw=4 +import numpy as np +import scipy.linalg + + +""" +Table for the 0.95 quantile of the chi-square distribution with N degrees of +freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv +function and used as Mahalanobis gating threshold. 
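+
+For example, gating on the full (x, y, w, h) measurement uses 4 degrees of
+freedom, so an association is rejected once its squared Mahalanobis distance
+exceeds chi2inv95[4] (9.4877); position-only gating uses chi2inv95[2] instead.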
+""" +chi2inv95 = { + 1: 3.8415, + 2: 5.9915, + 3: 7.8147, + 4: 9.4877, + 5: 11.070, + 6: 12.592, + 7: 14.067, + 8: 15.507, + 9: 16.919} + + +class KalmanFilter(object): + """ + A simple Kalman filter for tracking bounding boxes in image space. + + The 8-dimensional state space + + x, y, w, h, vx, vy, vw, vh + + contains the bounding box center position (x, y), width w, height h, + and their respective velocities. + + Object motion follows a constant velocity model. The bounding box location + (x, y, w, h) is taken as direct observation of the state space (linear + observation model). + + """ + + def __init__(self): + ndim, dt = 4, 1. + + # Create Kalman filter model matrices. + self._motion_mat = np.eye(2 * ndim, 2 * ndim) + for i in range(ndim): + self._motion_mat[i, ndim + i] = dt + self._update_mat = np.eye(ndim, 2 * ndim) + + # Motion and observation uncertainty are chosen relative to the current + # state estimate. These weights control the amount of uncertainty in + # the model. This is a bit hacky. + self._std_weight_position = 1. / 20 + self._std_weight_velocity = 1. / 160 + + def initiate(self, measurement): + """Create track from unassociated measurement. + + Parameters + ---------- + measurement : ndarray + Bounding box coordinates (x, y, w, h) with center position (x, y), + width w, and height h. + + Returns + ------- + (ndarray, ndarray) + Returns the mean vector (8 dimensional) and covariance matrix (8x8 + dimensional) of the new track. Unobserved velocities are initialized + to 0 mean. + + """ + mean_pos = measurement + mean_vel = np.zeros_like(mean_pos) + mean = np.r_[mean_pos, mean_vel] + + std = [ + 2 * self._std_weight_position * measurement[2], + 2 * self._std_weight_position * measurement[3], + 2 * self._std_weight_position * measurement[2], + 2 * self._std_weight_position * measurement[3], + 10 * self._std_weight_velocity * measurement[2], + 10 * self._std_weight_velocity * measurement[3], + 10 * self._std_weight_velocity * measurement[2], + 10 * self._std_weight_velocity * measurement[3]] + covariance = np.diag(np.square(std)) + return mean, covariance + + def predict(self, mean, covariance): + """Run Kalman filter prediction step. + + Parameters + ---------- + mean : ndarray + The 8 dimensional mean vector of the object state at the previous + time step. + covariance : ndarray + The 8x8 dimensional covariance matrix of the object state at the + previous time step. + + Returns + ------- + (ndarray, ndarray) + Returns the mean vector and covariance matrix of the predicted + state. Unobserved velocities are initialized to 0 mean. + + """ + std_pos = [ + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + self._std_weight_position * mean[2], + self._std_weight_position * mean[3]] + std_vel = [ + self._std_weight_velocity * mean[2], + self._std_weight_velocity * mean[3], + self._std_weight_velocity * mean[2], + self._std_weight_velocity * mean[3]] + motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) + + mean = np.dot(mean, self._motion_mat.T) + covariance = np.linalg.multi_dot(( + self._motion_mat, covariance, self._motion_mat.T)) + motion_cov + + return mean, covariance + + def project(self, mean, covariance): + """Project state distribution to measurement space. + + Parameters + ---------- + mean : ndarray + The state's mean vector (8 dimensional array). + covariance : ndarray + The state's covariance matrix (8x8 dimensional). 
+ + Returns + ------- + (ndarray, ndarray) + Returns the projected mean and covariance matrix of the given state + estimate. + + """ + std = [ + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + self._std_weight_position * mean[2], + self._std_weight_position * mean[3]] + innovation_cov = np.diag(np.square(std)) + + mean = np.dot(self._update_mat, mean) + covariance = np.linalg.multi_dot(( + self._update_mat, covariance, self._update_mat.T)) + return mean, covariance + innovation_cov + + def multi_predict(self, mean, covariance): + """Run Kalman filter prediction step (Vectorized version). + Parameters + ---------- + mean : ndarray + The Nx8 dimensional mean matrix of the object states at the previous + time step. + covariance : ndarray + The Nx8x8 dimensional covariance matrics of the object states at the + previous time step. + Returns + ------- + (ndarray, ndarray) + Returns the mean vector and covariance matrix of the predicted + state. Unobserved velocities are initialized to 0 mean. + """ + std_pos = [ + self._std_weight_position * mean[:, 2], + self._std_weight_position * mean[:, 3], + self._std_weight_position * mean[:, 2], + self._std_weight_position * mean[:, 3]] + std_vel = [ + self._std_weight_velocity * mean[:, 2], + self._std_weight_velocity * mean[:, 3], + self._std_weight_velocity * mean[:, 2], + self._std_weight_velocity * mean[:, 3]] + sqr = np.square(np.r_[std_pos, std_vel]).T + + motion_cov = [] + for i in range(len(mean)): + motion_cov.append(np.diag(sqr[i])) + motion_cov = np.asarray(motion_cov) + + mean = np.dot(mean, self._motion_mat.T) + left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2)) + covariance = np.dot(left, self._motion_mat.T) + motion_cov + + return mean, covariance + + def update(self, mean, covariance, measurement): + """Run Kalman filter correction step. + + Parameters + ---------- + mean : ndarray + The predicted state's mean vector (8 dimensional). + covariance : ndarray + The state's covariance matrix (8x8 dimensional). + measurement : ndarray + The 4 dimensional measurement vector (x, y, w, h), where (x, y) + is the center position, w the width, and h the height of the + bounding box. + + Returns + ------- + (ndarray, ndarray) + Returns the measurement-corrected state distribution. + + """ + projected_mean, projected_cov = self.project(mean, covariance) + + chol_factor, lower = scipy.linalg.cho_factor( + projected_cov, lower=True, check_finite=False) + kalman_gain = scipy.linalg.cho_solve( + (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, + check_finite=False).T + innovation = measurement - projected_mean + + new_mean = mean + np.dot(innovation, kalman_gain.T) + new_covariance = covariance - np.linalg.multi_dot(( + kalman_gain, projected_cov, kalman_gain.T)) + return new_mean, new_covariance + + def gating_distance(self, mean, covariance, measurements, + only_position=False, metric='maha'): + """Compute gating distance between state distribution and measurements. + A suitable distance threshold can be obtained from `chi2inv95`. If + `only_position` is False, the chi-square distribution has 4 degrees of + freedom, otherwise 2. + Parameters + ---------- + mean : ndarray + Mean vector over the state distribution (8 dimensional). + covariance : ndarray + Covariance of the state distribution (8x8 dimensional). 
+ measurements : ndarray + An Nx4 dimensional matrix of N measurements, each in + format (x, y, a, h) where (x, y) is the bounding box center + position, a the aspect ratio, and h the height. + only_position : Optional[bool] + If True, distance computation is done with respect to the bounding + box center position only. + Returns + ------- + ndarray + Returns an array of length N, where the i-th element contains the + squared Mahalanobis distance between (mean, covariance) and + `measurements[i]`. + """ + mean, covariance = self.project(mean, covariance) + if only_position: + mean, covariance = mean[:2], covariance[:2, :2] + measurements = measurements[:, :2] + + d = measurements - mean + if metric == 'gaussian': + return np.sum(d * d, axis=1) + elif metric == 'maha': + cholesky_factor = np.linalg.cholesky(covariance) + z = scipy.linalg.solve_triangular( + cholesky_factor, d.T, lower=True, check_finite=False, + overwrite_b=True) + squared_maha = np.sum(z * z, axis=0) + return squared_maha + else: + raise ValueError('invalid distance metric') \ No newline at end of file diff --git a/feeder/trackers/botsort/matching.py b/feeder/trackers/botsort/matching.py new file mode 100644 index 0000000..756dd45 --- /dev/null +++ b/feeder/trackers/botsort/matching.py @@ -0,0 +1,234 @@ +import numpy as np +import scipy +import lap +from scipy.spatial.distance import cdist + +from trackers.botsort import kalman_filter + + +def merge_matches(m1, m2, shape): + O,P,Q = shape + m1 = np.asarray(m1) + m2 = np.asarray(m2) + + M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P)) + M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q)) + + mask = M1*M2 + match = mask.nonzero() + match = list(zip(match[0], match[1])) + unmatched_O = tuple(set(range(O)) - set([i for i, j in match])) + unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match])) + + return match, unmatched_O, unmatched_Q + + +def _indices_to_matches(cost_matrix, indices, thresh): + matched_cost = cost_matrix[tuple(zip(*indices))] + matched_mask = (matched_cost <= thresh) + + matches = indices[matched_mask] + unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0])) + unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1])) + + return matches, unmatched_a, unmatched_b + + +def linear_assignment(cost_matrix, thresh): + if cost_matrix.size == 0: + return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) + matches, unmatched_a, unmatched_b = [], [], [] + cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh) + for ix, mx in enumerate(x): + if mx >= 0: + matches.append([ix, mx]) + unmatched_a = np.where(x < 0)[0] + unmatched_b = np.where(y < 0)[0] + matches = np.asarray(matches) + return matches, unmatched_a, unmatched_b + + +def ious(atlbrs, btlbrs): + """ + Compute cost based on IoU + :type atlbrs: list[tlbr] | np.ndarray + :type atlbrs: list[tlbr] | np.ndarray + + :rtype ious np.ndarray + """ + ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) + if ious.size == 0: + return ious + + ious = bbox_ious( + np.ascontiguousarray(atlbrs, dtype=np.float32), + np.ascontiguousarray(btlbrs, dtype=np.float32) + ) + + return ious + + +def tlbr_expand(tlbr, scale=1.2): + w = tlbr[2] - tlbr[0] + h = tlbr[3] - tlbr[1] + + half_scale = 0.5 * scale + + tlbr[0] -= half_scale * w + tlbr[1] -= half_scale * h + tlbr[2] += half_scale * w + tlbr[3] += half_scale * h + + return tlbr + + +def 
iou_distance(atracks, btracks): + """ + Compute cost based on IoU + :type atracks: list[STrack] + :type btracks: list[STrack] + + :rtype cost_matrix np.ndarray + """ + + if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)): + atlbrs = atracks + btlbrs = btracks + else: + atlbrs = [track.tlbr for track in atracks] + btlbrs = [track.tlbr for track in btracks] + _ious = ious(atlbrs, btlbrs) + cost_matrix = 1 - _ious + + return cost_matrix + + +def v_iou_distance(atracks, btracks): + """ + Compute cost based on IoU + :type atracks: list[STrack] + :type btracks: list[STrack] + + :rtype cost_matrix np.ndarray + """ + + if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)): + atlbrs = atracks + btlbrs = btracks + else: + atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks] + btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks] + _ious = ious(atlbrs, btlbrs) + cost_matrix = 1 - _ious + + return cost_matrix + + +def embedding_distance(tracks, detections, metric='cosine'): + """ + :param tracks: list[STrack] + :param detections: list[BaseTrack] + :param metric: + :return: cost_matrix np.ndarray + """ + + cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32) + if cost_matrix.size == 0: + return cost_matrix + det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32) + track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32) + + cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # / 2.0 # Nomalized features + return cost_matrix + + +def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False): + if cost_matrix.size == 0: + return cost_matrix + gating_dim = 2 if only_position else 4 + gating_threshold = kalman_filter.chi2inv95[gating_dim] + # measurements = np.asarray([det.to_xyah() for det in detections]) + measurements = np.asarray([det.to_xywh() for det in detections]) + for row, track in enumerate(tracks): + gating_distance = kf.gating_distance( + track.mean, track.covariance, measurements, only_position) + cost_matrix[row, gating_distance > gating_threshold] = np.inf + return cost_matrix + + +def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98): + if cost_matrix.size == 0: + return cost_matrix + gating_dim = 2 if only_position else 4 + gating_threshold = kalman_filter.chi2inv95[gating_dim] + # measurements = np.asarray([det.to_xyah() for det in detections]) + measurements = np.asarray([det.to_xywh() for det in detections]) + for row, track in enumerate(tracks): + gating_distance = kf.gating_distance( + track.mean, track.covariance, measurements, only_position, metric='maha') + cost_matrix[row, gating_distance > gating_threshold] = np.inf + cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance + return cost_matrix + + +def fuse_iou(cost_matrix, tracks, detections): + if cost_matrix.size == 0: + return cost_matrix + reid_sim = 1 - cost_matrix + iou_dist = iou_distance(tracks, detections) + iou_sim = 1 - iou_dist + fuse_sim = reid_sim * (1 + iou_sim) / 2 + det_scores = np.array([det.score for det in detections]) + det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) + #fuse_sim = fuse_sim * (1 + det_scores) / 2 + fuse_cost = 1 - fuse_sim + return fuse_cost + + +def fuse_score(cost_matrix, detections): + if cost_matrix.size == 0: + 
return cost_matrix + iou_sim = 1 - cost_matrix + det_scores = np.array([det.score for det in detections]) + det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) + fuse_sim = iou_sim * det_scores + fuse_cost = 1 - fuse_sim + return fuse_cost + +def bbox_ious(boxes, query_boxes): + """ + Parameters + ---------- + boxes: (N, 4) ndarray of float + query_boxes: (K, 4) ndarray of float + Returns + ------- + overlaps: (N, K) ndarray of overlap between boxes and query_boxes + """ + N = boxes.shape[0] + K = query_boxes.shape[0] + overlaps = np.zeros((N, K), dtype=np.float32) + + for k in range(K): + box_area = ( + (query_boxes[k, 2] - query_boxes[k, 0] + 1) * + (query_boxes[k, 3] - query_boxes[k, 1] + 1) + ) + for n in range(N): + iw = ( + min(boxes[n, 2], query_boxes[k, 2]) - + max(boxes[n, 0], query_boxes[k, 0]) + 1 + ) + if iw > 0: + ih = ( + min(boxes[n, 3], query_boxes[k, 3]) - + max(boxes[n, 1], query_boxes[k, 1]) + 1 + ) + if ih > 0: + ua = float( + (boxes[n, 2] - boxes[n, 0] + 1) * + (boxes[n, 3] - boxes[n, 1] + 1) + + box_area - iw * ih + ) + overlaps[n, k] = iw * ih / ua + return overlaps \ No newline at end of file diff --git a/feeder/trackers/bytetrack/basetrack.py b/feeder/trackers/bytetrack/basetrack.py new file mode 100644 index 0000000..4fe2233 --- /dev/null +++ b/feeder/trackers/bytetrack/basetrack.py @@ -0,0 +1,52 @@ +import numpy as np +from collections import OrderedDict + + +class TrackState(object): + New = 0 + Tracked = 1 + Lost = 2 + Removed = 3 + + +class BaseTrack(object): + _count = 0 + + track_id = 0 + is_activated = False + state = TrackState.New + + history = OrderedDict() + features = [] + curr_feature = None + score = 0 + start_frame = 0 + frame_id = 0 + time_since_update = 0 + + # multi-camera + location = (np.inf, np.inf) + + @property + def end_frame(self): + return self.frame_id + + @staticmethod + def next_id(): + BaseTrack._count += 1 + return BaseTrack._count + + def activate(self, *args): + raise NotImplementedError + + def predict(self): + raise NotImplementedError + + def update(self, *args, **kwargs): + raise NotImplementedError + + def mark_lost(self): + self.state = TrackState.Lost + + def mark_removed(self): + self.state = TrackState.Removed diff --git a/feeder/trackers/bytetrack/byte_tracker.py b/feeder/trackers/bytetrack/byte_tracker.py new file mode 100644 index 0000000..e74afe4 --- /dev/null +++ b/feeder/trackers/bytetrack/byte_tracker.py @@ -0,0 +1,348 @@ +import numpy as np + +from ultralytics.yolo.utils.ops import xywh2xyxy, xyxy2xywh + + +from trackers.bytetrack.kalman_filter import KalmanFilter +from trackers.bytetrack import matching +from trackers.bytetrack.basetrack import BaseTrack, TrackState + +class STrack(BaseTrack): + shared_kalman = KalmanFilter() + def __init__(self, tlwh, score, cls): + + # wait activate + self._tlwh = np.asarray(tlwh, dtype=np.float32) + self.kalman_filter = None + self.mean, self.covariance = None, None + self.is_activated = False + + self.score = score + self.tracklet_len = 0 + self.cls = cls + + def predict(self): + mean_state = self.mean.copy() + if self.state != TrackState.Tracked: + mean_state[7] = 0 + self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) + + @staticmethod + def multi_predict(stracks): + if len(stracks) > 0: + multi_mean = np.asarray([st.mean.copy() for st in stracks]) + multi_covariance = np.asarray([st.covariance for st in stracks]) + for i, st in enumerate(stracks): + if st.state != TrackState.Tracked: + multi_mean[i][7] = 0 + 
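+            # Batched prediction: the shared Kalman filter advances every track in one
+            # vectorized call (multi_mean is Nx8, multi_covariance is Nx8x8) instead of
+            # calling predict() once per track.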
multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance) + for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): + stracks[i].mean = mean + stracks[i].covariance = cov + + def activate(self, kalman_filter, frame_id): + """Start a new tracklet""" + self.kalman_filter = kalman_filter + self.track_id = self.next_id() + self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh)) + + self.tracklet_len = 0 + self.state = TrackState.Tracked + if frame_id == 1: + self.is_activated = True + # self.is_activated = True + self.frame_id = frame_id + self.start_frame = frame_id + + def re_activate(self, new_track, frame_id, new_id=False): + self.mean, self.covariance = self.kalman_filter.update( + self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh) + ) + self.tracklet_len = 0 + self.state = TrackState.Tracked + self.is_activated = True + self.frame_id = frame_id + if new_id: + self.track_id = self.next_id() + self.score = new_track.score + self.cls = new_track.cls + + def update(self, new_track, frame_id): + """ + Update a matched track + :type new_track: STrack + :type frame_id: int + :type update_feature: bool + :return: + """ + self.frame_id = frame_id + self.tracklet_len += 1 + # self.cls = cls + + new_tlwh = new_track.tlwh + self.mean, self.covariance = self.kalman_filter.update( + self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh)) + self.state = TrackState.Tracked + self.is_activated = True + + self.score = new_track.score + + @property + # @jit(nopython=True) + def tlwh(self): + """Get current position in bounding box format `(top left x, top left y, + width, height)`. + """ + if self.mean is None: + return self._tlwh.copy() + ret = self.mean[:4].copy() + ret[2] *= ret[3] + ret[:2] -= ret[2:] / 2 + return ret + + @property + # @jit(nopython=True) + def tlbr(self): + """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., + `(top left, bottom right)`. + """ + ret = self.tlwh.copy() + ret[2:] += ret[:2] + return ret + + @staticmethod + # @jit(nopython=True) + def tlwh_to_xyah(tlwh): + """Convert bounding box to format `(center x, center y, aspect ratio, + height)`, where the aspect ratio is `width / height`. 
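+
+        Illustrative example (hypothetical box): a tlwh box of
+        [10., 20., 50., 100.] maps to [35., 70., 0.5, 100.] in xyah form,
+        since the center is (10 + 50/2, 20 + 100/2) and the aspect ratio
+        is 50 / 100.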
+ """ + ret = np.asarray(tlwh).copy() + ret[:2] += ret[2:] / 2 + ret[2] /= ret[3] + return ret + + def to_xyah(self): + return self.tlwh_to_xyah(self.tlwh) + + @staticmethod + # @jit(nopython=True) + def tlbr_to_tlwh(tlbr): + ret = np.asarray(tlbr).copy() + ret[2:] -= ret[:2] + return ret + + @staticmethod + # @jit(nopython=True) + def tlwh_to_tlbr(tlwh): + ret = np.asarray(tlwh).copy() + ret[2:] += ret[:2] + return ret + + def __repr__(self): + return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame) + + +class BYTETracker(object): + def __init__(self, track_thresh=0.45, match_thresh=0.8, track_buffer=25, frame_rate=30): + self.tracked_stracks = [] # type: list[STrack] + self.lost_stracks = [] # type: list[STrack] + self.removed_stracks = [] # type: list[STrack] + + self.frame_id = 0 + self.track_buffer=track_buffer + + self.track_thresh = track_thresh + self.match_thresh = match_thresh + self.det_thresh = track_thresh + 0.1 + self.buffer_size = int(frame_rate / 30.0 * track_buffer) + self.max_time_lost = self.buffer_size + self.kalman_filter = KalmanFilter() + + def update(self, dets, _): + self.frame_id += 1 + activated_starcks = [] + refind_stracks = [] + lost_stracks = [] + removed_stracks = [] + + xyxys = dets[:, 0:4] + xywh = xyxy2xywh(xyxys.numpy()) + confs = dets[:, 4] + clss = dets[:, 5] + + classes = clss.numpy() + xyxys = xyxys.numpy() + confs = confs.numpy() + + remain_inds = confs > self.track_thresh + inds_low = confs > 0.1 + inds_high = confs < self.track_thresh + + inds_second = np.logical_and(inds_low, inds_high) + + dets_second = xywh[inds_second] + dets = xywh[remain_inds] + + scores_keep = confs[remain_inds] + scores_second = confs[inds_second] + + clss_keep = classes[remain_inds] + clss_second = classes[inds_second] + + + if len(dets) > 0: + '''Detections''' + detections = [STrack(xyxy, s, c) for + (xyxy, s, c) in zip(dets, scores_keep, clss_keep)] + else: + detections = [] + + ''' Add newly detected tracklets to tracked_stracks''' + unconfirmed = [] + tracked_stracks = [] # type: list[STrack] + for track in self.tracked_stracks: + if not track.is_activated: + unconfirmed.append(track) + else: + tracked_stracks.append(track) + + ''' Step 2: First association, with high score detection boxes''' + strack_pool = joint_stracks(tracked_stracks, self.lost_stracks) + # Predict the current location with KF + STrack.multi_predict(strack_pool) + dists = matching.iou_distance(strack_pool, detections) + #if not self.args.mot20: + dists = matching.fuse_score(dists, detections) + matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.match_thresh) + + for itracked, idet in matches: + track = strack_pool[itracked] + det = detections[idet] + if track.state == TrackState.Tracked: + track.update(detections[idet], self.frame_id) + activated_starcks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + + ''' Step 3: Second association, with low score detection boxes''' + # association the untrack to the low score detections + if len(dets_second) > 0: + '''Detections''' + detections_second = [STrack(xywh, s, c) for (xywh, s, c) in zip(dets_second, scores_second, clss_second)] + else: + detections_second = [] + r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked] + dists = matching.iou_distance(r_tracked_stracks, detections_second) + matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5) + for itracked, idet in 
matches: + track = r_tracked_stracks[itracked] + det = detections_second[idet] + if track.state == TrackState.Tracked: + track.update(det, self.frame_id) + activated_starcks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + + for it in u_track: + track = r_tracked_stracks[it] + if not track.state == TrackState.Lost: + track.mark_lost() + lost_stracks.append(track) + + '''Deal with unconfirmed tracks, usually tracks with only one beginning frame''' + detections = [detections[i] for i in u_detection] + dists = matching.iou_distance(unconfirmed, detections) + #if not self.args.mot20: + dists = matching.fuse_score(dists, detections) + matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7) + for itracked, idet in matches: + unconfirmed[itracked].update(detections[idet], self.frame_id) + activated_starcks.append(unconfirmed[itracked]) + for it in u_unconfirmed: + track = unconfirmed[it] + track.mark_removed() + removed_stracks.append(track) + + """ Step 4: Init new stracks""" + for inew in u_detection: + track = detections[inew] + if track.score < self.det_thresh: + continue + track.activate(self.kalman_filter, self.frame_id) + activated_starcks.append(track) + """ Step 5: Update state""" + for track in self.lost_stracks: + if self.frame_id - track.end_frame > self.max_time_lost: + track.mark_removed() + removed_stracks.append(track) + + # print('Ramained match {} s'.format(t4-t3)) + + self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked] + self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks) + self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks) + self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks) + self.lost_stracks.extend(lost_stracks) + self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks) + self.removed_stracks.extend(removed_stracks) + self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks) + # get scores of lost tracks + output_stracks = [track for track in self.tracked_stracks if track.is_activated] + outputs = [] + for t in output_stracks: + output= [] + tlwh = t.tlwh + tid = t.track_id + tlwh = np.expand_dims(tlwh, axis=0) + xyxy = xywh2xyxy(tlwh) + xyxy = np.squeeze(xyxy, axis=0) + output.extend(xyxy) + output.append(tid) + output.append(t.cls) + output.append(t.score) + outputs.append(output) + + return outputs +#track_id, class_id, conf + +def joint_stracks(tlista, tlistb): + exists = {} + res = [] + for t in tlista: + exists[t.track_id] = 1 + res.append(t) + for t in tlistb: + tid = t.track_id + if not exists.get(tid, 0): + exists[tid] = 1 + res.append(t) + return res + + +def sub_stracks(tlista, tlistb): + stracks = {} + for t in tlista: + stracks[t.track_id] = t + for t in tlistb: + tid = t.track_id + if stracks.get(tid, 0): + del stracks[tid] + return list(stracks.values()) + + +def remove_duplicate_stracks(stracksa, stracksb): + pdist = matching.iou_distance(stracksa, stracksb) + pairs = np.where(pdist < 0.15) + dupa, dupb = list(), list() + for p, q in zip(*pairs): + timep = stracksa[p].frame_id - stracksa[p].start_frame + timeq = stracksb[q].frame_id - stracksb[q].start_frame + if timep > timeq: + dupb.append(q) + else: + dupa.append(p) + resa = [t for i, t in enumerate(stracksa) if not i in dupa] + resb = [t for i, t in enumerate(stracksb) if not i in dupb] + return resa, resb diff --git 
a/feeder/trackers/bytetrack/configs/bytetrack.yaml b/feeder/trackers/bytetrack/configs/bytetrack.yaml new file mode 100644 index 0000000..e81dd78 --- /dev/null +++ b/feeder/trackers/bytetrack/configs/bytetrack.yaml @@ -0,0 +1,7 @@ +bytetrack: + track_thresh: 0.6 # tracking confidence threshold + track_buffer: 30 # the frames for keep lost tracks + match_thresh: 0.8 # matching threshold for tracking + frame_rate: 30 # FPS + conf_thres: 0.5122620708221085 + diff --git a/feeder/trackers/bytetrack/kalman_filter.py b/feeder/trackers/bytetrack/kalman_filter.py new file mode 100644 index 0000000..deda8a2 --- /dev/null +++ b/feeder/trackers/bytetrack/kalman_filter.py @@ -0,0 +1,270 @@ +# vim: expandtab:ts=4:sw=4 +import numpy as np +import scipy.linalg + + +""" +Table for the 0.95 quantile of the chi-square distribution with N degrees of +freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv +function and used as Mahalanobis gating threshold. +""" +chi2inv95 = { + 1: 3.8415, + 2: 5.9915, + 3: 7.8147, + 4: 9.4877, + 5: 11.070, + 6: 12.592, + 7: 14.067, + 8: 15.507, + 9: 16.919} + + +class KalmanFilter(object): + """ + A simple Kalman filter for tracking bounding boxes in image space. + + The 8-dimensional state space + + x, y, a, h, vx, vy, va, vh + + contains the bounding box center position (x, y), aspect ratio a, height h, + and their respective velocities. + + Object motion follows a constant velocity model. The bounding box location + (x, y, a, h) is taken as direct observation of the state space (linear + observation model). + + """ + + def __init__(self): + ndim, dt = 4, 1. + + # Create Kalman filter model matrices. + self._motion_mat = np.eye(2 * ndim, 2 * ndim) + for i in range(ndim): + self._motion_mat[i, ndim + i] = dt + self._update_mat = np.eye(ndim, 2 * ndim) + + # Motion and observation uncertainty are chosen relative to the current + # state estimate. These weights control the amount of uncertainty in + # the model. This is a bit hacky. + self._std_weight_position = 1. / 20 + self._std_weight_velocity = 1. / 160 + + def initiate(self, measurement): + """Create track from unassociated measurement. + + Parameters + ---------- + measurement : ndarray + Bounding box coordinates (x, y, a, h) with center position (x, y), + aspect ratio a, and height h. + + Returns + ------- + (ndarray, ndarray) + Returns the mean vector (8 dimensional) and covariance matrix (8x8 + dimensional) of the new track. Unobserved velocities are initialized + to 0 mean. + + """ + mean_pos = measurement + mean_vel = np.zeros_like(mean_pos) + mean = np.r_[mean_pos, mean_vel] + + std = [ + 2 * self._std_weight_position * measurement[3], + 2 * self._std_weight_position * measurement[3], + 1e-2, + 2 * self._std_weight_position * measurement[3], + 10 * self._std_weight_velocity * measurement[3], + 10 * self._std_weight_velocity * measurement[3], + 1e-5, + 10 * self._std_weight_velocity * measurement[3]] + covariance = np.diag(np.square(std)) + return mean, covariance + + def predict(self, mean, covariance): + """Run Kalman filter prediction step. + + Parameters + ---------- + mean : ndarray + The 8 dimensional mean vector of the object state at the previous + time step. + covariance : ndarray + The 8x8 dimensional covariance matrix of the object state at the + previous time step. + + Returns + ------- + (ndarray, ndarray) + Returns the mean vector and covariance matrix of the predicted + state. Unobserved velocities are initialized to 0 mean. 
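+
+        Concretely, the constant-velocity motion matrix adds each velocity
+        component onto its paired position/shape component (dt = 1), and the
+        process noise injected for this step is scaled by the current box
+        height mean[3].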
+ + """ + std_pos = [ + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-2, + self._std_weight_position * mean[3]] + std_vel = [ + self._std_weight_velocity * mean[3], + self._std_weight_velocity * mean[3], + 1e-5, + self._std_weight_velocity * mean[3]] + motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) + + #mean = np.dot(self._motion_mat, mean) + mean = np.dot(mean, self._motion_mat.T) + covariance = np.linalg.multi_dot(( + self._motion_mat, covariance, self._motion_mat.T)) + motion_cov + + return mean, covariance + + def project(self, mean, covariance): + """Project state distribution to measurement space. + + Parameters + ---------- + mean : ndarray + The state's mean vector (8 dimensional array). + covariance : ndarray + The state's covariance matrix (8x8 dimensional). + + Returns + ------- + (ndarray, ndarray) + Returns the projected mean and covariance matrix of the given state + estimate. + + """ + std = [ + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-1, + self._std_weight_position * mean[3]] + innovation_cov = np.diag(np.square(std)) + + mean = np.dot(self._update_mat, mean) + covariance = np.linalg.multi_dot(( + self._update_mat, covariance, self._update_mat.T)) + return mean, covariance + innovation_cov + + def multi_predict(self, mean, covariance): + """Run Kalman filter prediction step (Vectorized version). + Parameters + ---------- + mean : ndarray + The Nx8 dimensional mean matrix of the object states at the previous + time step. + covariance : ndarray + The Nx8x8 dimensional covariance matrics of the object states at the + previous time step. + Returns + ------- + (ndarray, ndarray) + Returns the mean vector and covariance matrix of the predicted + state. Unobserved velocities are initialized to 0 mean. + """ + std_pos = [ + self._std_weight_position * mean[:, 3], + self._std_weight_position * mean[:, 3], + 1e-2 * np.ones_like(mean[:, 3]), + self._std_weight_position * mean[:, 3]] + std_vel = [ + self._std_weight_velocity * mean[:, 3], + self._std_weight_velocity * mean[:, 3], + 1e-5 * np.ones_like(mean[:, 3]), + self._std_weight_velocity * mean[:, 3]] + sqr = np.square(np.r_[std_pos, std_vel]).T + + motion_cov = [] + for i in range(len(mean)): + motion_cov.append(np.diag(sqr[i])) + motion_cov = np.asarray(motion_cov) + + mean = np.dot(mean, self._motion_mat.T) + left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2)) + covariance = np.dot(left, self._motion_mat.T) + motion_cov + + return mean, covariance + + def update(self, mean, covariance, measurement): + """Run Kalman filter correction step. + + Parameters + ---------- + mean : ndarray + The predicted state's mean vector (8 dimensional). + covariance : ndarray + The state's covariance matrix (8x8 dimensional). + measurement : ndarray + The 4 dimensional measurement vector (x, y, a, h), where (x, y) + is the center position, a the aspect ratio, and h the height of the + bounding box. + + Returns + ------- + (ndarray, ndarray) + Returns the measurement-corrected state distribution. 
+ + """ + projected_mean, projected_cov = self.project(mean, covariance) + + chol_factor, lower = scipy.linalg.cho_factor( + projected_cov, lower=True, check_finite=False) + kalman_gain = scipy.linalg.cho_solve( + (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, + check_finite=False).T + innovation = measurement - projected_mean + + new_mean = mean + np.dot(innovation, kalman_gain.T) + new_covariance = covariance - np.linalg.multi_dot(( + kalman_gain, projected_cov, kalman_gain.T)) + return new_mean, new_covariance + + def gating_distance(self, mean, covariance, measurements, + only_position=False, metric='maha'): + """Compute gating distance between state distribution and measurements. + A suitable distance threshold can be obtained from `chi2inv95`. If + `only_position` is False, the chi-square distribution has 4 degrees of + freedom, otherwise 2. + Parameters + ---------- + mean : ndarray + Mean vector over the state distribution (8 dimensional). + covariance : ndarray + Covariance of the state distribution (8x8 dimensional). + measurements : ndarray + An Nx4 dimensional matrix of N measurements, each in + format (x, y, a, h) where (x, y) is the bounding box center + position, a the aspect ratio, and h the height. + only_position : Optional[bool] + If True, distance computation is done with respect to the bounding + box center position only. + Returns + ------- + ndarray + Returns an array of length N, where the i-th element contains the + squared Mahalanobis distance between (mean, covariance) and + `measurements[i]`. + """ + mean, covariance = self.project(mean, covariance) + if only_position: + mean, covariance = mean[:2], covariance[:2, :2] + measurements = measurements[:, :2] + + d = measurements - mean + if metric == 'gaussian': + return np.sum(d * d, axis=1) + elif metric == 'maha': + cholesky_factor = np.linalg.cholesky(covariance) + z = scipy.linalg.solve_triangular( + cholesky_factor, d.T, lower=True, check_finite=False, + overwrite_b=True) + squared_maha = np.sum(z * z, axis=0) + return squared_maha + else: + raise ValueError('invalid distance metric') \ No newline at end of file diff --git a/feeder/trackers/bytetrack/matching.py b/feeder/trackers/bytetrack/matching.py new file mode 100644 index 0000000..17d7498 --- /dev/null +++ b/feeder/trackers/bytetrack/matching.py @@ -0,0 +1,219 @@ +import cv2 +import numpy as np +import scipy +import lap +from scipy.spatial.distance import cdist + +from trackers.bytetrack import kalman_filter +import time + +def merge_matches(m1, m2, shape): + O,P,Q = shape + m1 = np.asarray(m1) + m2 = np.asarray(m2) + + M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P)) + M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q)) + + mask = M1*M2 + match = mask.nonzero() + match = list(zip(match[0], match[1])) + unmatched_O = tuple(set(range(O)) - set([i for i, j in match])) + unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match])) + + return match, unmatched_O, unmatched_Q + + +def _indices_to_matches(cost_matrix, indices, thresh): + matched_cost = cost_matrix[tuple(zip(*indices))] + matched_mask = (matched_cost <= thresh) + + matches = indices[matched_mask] + unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0])) + unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1])) + + return matches, unmatched_a, unmatched_b + + +def linear_assignment(cost_matrix, thresh): + if cost_matrix.size == 0: + return np.empty((0, 2), dtype=int), 
tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) + matches, unmatched_a, unmatched_b = [], [], [] + cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh) + for ix, mx in enumerate(x): + if mx >= 0: + matches.append([ix, mx]) + unmatched_a = np.where(x < 0)[0] + unmatched_b = np.where(y < 0)[0] + matches = np.asarray(matches) + return matches, unmatched_a, unmatched_b + + +def ious(atlbrs, btlbrs): + """ + Compute cost based on IoU + :type atlbrs: list[tlbr] | np.ndarray + :type atlbrs: list[tlbr] | np.ndarray + + :rtype ious np.ndarray + """ + ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) + if ious.size == 0: + return ious + + ious = bbox_ious( + np.ascontiguousarray(atlbrs, dtype=np.float32), + np.ascontiguousarray(btlbrs, dtype=np.float32) + ) + + return ious + + +def iou_distance(atracks, btracks): + """ + Compute cost based on IoU + :type atracks: list[STrack] + :type btracks: list[STrack] + + :rtype cost_matrix np.ndarray + """ + + if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)): + atlbrs = atracks + btlbrs = btracks + else: + atlbrs = [track.tlbr for track in atracks] + btlbrs = [track.tlbr for track in btracks] + _ious = ious(atlbrs, btlbrs) + cost_matrix = 1 - _ious + + return cost_matrix + +def v_iou_distance(atracks, btracks): + """ + Compute cost based on IoU + :type atracks: list[STrack] + :type btracks: list[STrack] + + :rtype cost_matrix np.ndarray + """ + + if (len(atracks)>0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)): + atlbrs = atracks + btlbrs = btracks + else: + atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks] + btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks] + _ious = ious(atlbrs, btlbrs) + cost_matrix = 1 - _ious + + return cost_matrix + +def embedding_distance(tracks, detections, metric='cosine'): + """ + :param tracks: list[STrack] + :param detections: list[BaseTrack] + :param metric: + :return: cost_matrix np.ndarray + """ + + cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32) + if cost_matrix.size == 0: + return cost_matrix + det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32) + #for i, track in enumerate(tracks): + #cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric)) + track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32) + cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Nomalized features + return cost_matrix + + +def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False): + if cost_matrix.size == 0: + return cost_matrix + gating_dim = 2 if only_position else 4 + gating_threshold = kalman_filter.chi2inv95[gating_dim] + measurements = np.asarray([det.to_xyah() for det in detections]) + for row, track in enumerate(tracks): + gating_distance = kf.gating_distance( + track.mean, track.covariance, measurements, only_position) + cost_matrix[row, gating_distance > gating_threshold] = np.inf + return cost_matrix + + +def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98): + if cost_matrix.size == 0: + return cost_matrix + gating_dim = 2 if only_position else 4 + gating_threshold = kalman_filter.chi2inv95[gating_dim] + measurements = np.asarray([det.to_xyah() for det in detections]) + for row, track in enumerate(tracks): + 
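+        # Motion-gated fusion: detections falling outside the chi-square gate for this
+        # track receive infinite cost, and the remaining entries of the input cost row
+        # (typically an embedding distance) are blended with the Mahalanobis distance
+        # as lambda_ * cost + (1 - lambda_) * gating_distance.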
gating_distance = kf.gating_distance( + track.mean, track.covariance, measurements, only_position, metric='maha') + cost_matrix[row, gating_distance > gating_threshold] = np.inf + cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance + return cost_matrix + + +def fuse_iou(cost_matrix, tracks, detections): + if cost_matrix.size == 0: + return cost_matrix + reid_sim = 1 - cost_matrix + iou_dist = iou_distance(tracks, detections) + iou_sim = 1 - iou_dist + fuse_sim = reid_sim * (1 + iou_sim) / 2 + det_scores = np.array([det.score for det in detections]) + det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) + #fuse_sim = fuse_sim * (1 + det_scores) / 2 + fuse_cost = 1 - fuse_sim + return fuse_cost + + +def fuse_score(cost_matrix, detections): + if cost_matrix.size == 0: + return cost_matrix + iou_sim = 1 - cost_matrix + det_scores = np.array([det.score for det in detections]) + det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) + fuse_sim = iou_sim * det_scores + fuse_cost = 1 - fuse_sim + return fuse_cost + + +def bbox_ious(boxes, query_boxes): + """ + Parameters + ---------- + boxes: (N, 4) ndarray of float + query_boxes: (K, 4) ndarray of float + Returns + ------- + overlaps: (N, K) ndarray of overlap between boxes and query_boxes + """ + N = boxes.shape[0] + K = query_boxes.shape[0] + overlaps = np.zeros((N, K), dtype=np.float32) + + for k in range(K): + box_area = ( + (query_boxes[k, 2] - query_boxes[k, 0] + 1) * + (query_boxes[k, 3] - query_boxes[k, 1] + 1) + ) + for n in range(N): + iw = ( + min(boxes[n, 2], query_boxes[k, 2]) - + max(boxes[n, 0], query_boxes[k, 0]) + 1 + ) + if iw > 0: + ih = ( + min(boxes[n, 3], query_boxes[k, 3]) - + max(boxes[n, 1], query_boxes[k, 1]) + 1 + ) + if ih > 0: + ua = float( + (boxes[n, 2] - boxes[n, 0] + 1) * + (boxes[n, 3] - boxes[n, 1] + 1) + + box_area - iw * ih + ) + overlaps[n, k] = iw * ih / ua + return overlaps \ No newline at end of file diff --git a/feeder/trackers/deepocsort/__init__.py b/feeder/trackers/deepocsort/__init__.py new file mode 100644 index 0000000..0c53de6 --- /dev/null +++ b/feeder/trackers/deepocsort/__init__.py @@ -0,0 +1,2 @@ +from . import args +from . 
import ocsort diff --git a/feeder/trackers/deepocsort/args.py b/feeder/trackers/deepocsort/args.py new file mode 100644 index 0000000..cfd34cc --- /dev/null +++ b/feeder/trackers/deepocsort/args.py @@ -0,0 +1,110 @@ +import argparse + + +def make_parser(): + parser = argparse.ArgumentParser("OC-SORT parameters") + + # distributed + parser.add_argument("-b", "--batch-size", type=int, default=1, help="batch size") + parser.add_argument("-d", "--devices", default=None, type=int, help="device for training") + + parser.add_argument("--local_rank", default=0, type=int, help="local rank for dist training") + parser.add_argument("--num_machines", default=1, type=int, help="num of node for training") + parser.add_argument("--machine_rank", default=0, type=int, help="node rank for multi-node training") + + parser.add_argument( + "-f", + "--exp_file", + default=None, + type=str, + help="pls input your expriment description file", + ) + parser.add_argument( + "--test", + dest="test", + default=False, + action="store_true", + help="Evaluating on test-dev set.", + ) + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + + # det args + parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval") + parser.add_argument("--conf", default=0.1, type=float, help="test conf") + parser.add_argument("--nms", default=0.7, type=float, help="test nms threshold") + parser.add_argument("--tsize", default=[800, 1440], nargs="+", type=int, help="test img size") + parser.add_argument("--seed", default=None, type=int, help="eval seed") + + # tracking args + parser.add_argument("--track_thresh", type=float, default=0.6, help="detection confidence threshold") + parser.add_argument( + "--iou_thresh", + type=float, + default=0.3, + help="the iou threshold in Sort for matching", + ) + parser.add_argument("--min_hits", type=int, default=3, help="min hits to create track in SORT") + parser.add_argument( + "--inertia", + type=float, + default=0.2, + help="the weight of VDC term in cost matrix", + ) + parser.add_argument( + "--deltat", + type=int, + default=3, + help="time step difference to estimate direction", + ) + parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks") + parser.add_argument( + "--match_thresh", + type=float, + default=0.9, + help="matching threshold for tracking", + ) + parser.add_argument( + "--gt-type", + type=str, + default="_val_half", + help="suffix to find the gt annotation", + ) + parser.add_argument("--public", action="store_true", help="use public detection") + parser.add_argument("--asso", default="iou", help="similarity function: iou/giou/diou/ciou/ctdis") + + # for kitti/bdd100k inference with public detections + parser.add_argument( + "--raw_results_path", + type=str, + default="exps/permatrack_kitti_test/", + help="path to the raw tracking results from other tracks", + ) + parser.add_argument("--out_path", type=str, help="path to save output results") + parser.add_argument( + "--hp", + action="store_true", + help="use head padding to add the missing objects during \ + initializing the tracks (offline).", + ) + + # for demo video + parser.add_argument("--demo_type", default="image", help="demo type, eg. 
image, video and webcam") + parser.add_argument("--path", default="./videos/demo.mp4", help="path to images or video") + parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id") + parser.add_argument( + "--save_result", + action="store_true", + help="whether to save the inference result of image/video", + ) + parser.add_argument( + "--device", + default="gpu", + type=str, + help="device to run our model, can either be cpu or gpu", + ) + return parser diff --git a/feeder/trackers/deepocsort/association.py b/feeder/trackers/deepocsort/association.py new file mode 100644 index 0000000..a84c296 --- /dev/null +++ b/feeder/trackers/deepocsort/association.py @@ -0,0 +1,445 @@ +import os +import pdb + +import numpy as np +from scipy.special import softmax + + +def iou_batch(bboxes1, bboxes2): + """ + From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2] + """ + bboxes2 = np.expand_dims(bboxes2, 0) + bboxes1 = np.expand_dims(bboxes1, 1) + + xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0]) + yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1]) + xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2]) + yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3]) + w = np.maximum(0.0, xx2 - xx1) + h = np.maximum(0.0, yy2 - yy1) + wh = w * h + o = wh / ( + (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1]) + + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) + - wh + ) + return o + + +def giou_batch(bboxes1, bboxes2): + """ + :param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2) + :param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2) + :return: + """ + # for details should go to https://arxiv.org/pdf/1902.09630.pdf + # ensure predict's bbox form + bboxes2 = np.expand_dims(bboxes2, 0) + bboxes1 = np.expand_dims(bboxes1, 1) + + xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0]) + yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1]) + xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2]) + yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3]) + w = np.maximum(0.0, xx2 - xx1) + h = np.maximum(0.0, yy2 - yy1) + wh = w * h + iou = wh / ( + (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1]) + + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) + - wh + ) + + xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0]) + yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1]) + xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2]) + yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3]) + wc = xxc2 - xxc1 + hc = yyc2 - yyc1 + assert (wc > 0).all() and (hc > 0).all() + area_enclose = wc * hc + giou = iou - (area_enclose - wh) / area_enclose + giou = (giou + 1.0) / 2.0 # resize from (-1,1) to (0,1) + return giou + + +def diou_batch(bboxes1, bboxes2): + """ + :param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2) + :param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2) + :return: + """ + # for details should go to https://arxiv.org/pdf/1902.09630.pdf + # ensure predict's bbox form + bboxes2 = np.expand_dims(bboxes2, 0) + bboxes1 = np.expand_dims(bboxes1, 1) + + # calculate the intersection box + xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0]) + yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1]) + xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2]) + yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3]) + w = np.maximum(0.0, xx2 - xx1) + h = np.maximum(0.0, yy2 - yy1) + wh = w * h + iou = wh / ( + (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1]) + + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - 
bboxes2[..., 1]) + - wh + ) + + centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0 + centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0 + centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0 + centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0 + + inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2 + + xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0]) + yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1]) + xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2]) + yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3]) + + outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2 + diou = iou - inner_diag / outer_diag + + return (diou + 1) / 2.0 # resize from (-1,1) to (0,1) + + +def ciou_batch(bboxes1, bboxes2): + """ + :param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2) + :param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2) + :return: + """ + # for details should go to https://arxiv.org/pdf/1902.09630.pdf + # ensure predict's bbox form + bboxes2 = np.expand_dims(bboxes2, 0) + bboxes1 = np.expand_dims(bboxes1, 1) + + # calculate the intersection box + xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0]) + yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1]) + xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2]) + yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3]) + w = np.maximum(0.0, xx2 - xx1) + h = np.maximum(0.0, yy2 - yy1) + wh = w * h + iou = wh / ( + (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1]) + + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) + - wh + ) + + centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0 + centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0 + centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0 + centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0 + + inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2 + + xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0]) + yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1]) + xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2]) + yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3]) + + outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2 + + w1 = bboxes1[..., 2] - bboxes1[..., 0] + h1 = bboxes1[..., 3] - bboxes1[..., 1] + w2 = bboxes2[..., 2] - bboxes2[..., 0] + h2 = bboxes2[..., 3] - bboxes2[..., 1] + + # prevent dividing over zero. add one pixel shift + h2 = h2 + 1.0 + h1 = h1 + 1.0 + arctan = np.arctan(w2 / h2) - np.arctan(w1 / h1) + v = (4 / (np.pi**2)) * (arctan**2) + S = 1 - iou + alpha = v / (S + v) + ciou = iou - inner_diag / outer_diag - alpha * v + + return (ciou + 1) / 2.0 # resize from (-1,1) to (0,1) + + +def ct_dist(bboxes1, bboxes2): + """ + Measure the center distance between two sets of bounding boxes, + this is a coarse implementation, we don't recommend using it only + for association, which can be unstable and sensitive to frame rate + and object speed. 
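    A toy example (boxes made up purely for illustration) of the N x K
    broadcasting convention shared by iou_batch / giou_batch / diou_batch /
    ciou_batch / ct_dist, which all compare N detections against K trackers:

        dets = np.array([[0., 0., 10., 10.],
                         [40., 40., 60., 60.]])   # N = 2 detections
        trks = np.array([[1., 1., 11., 11.]])     # K = 1 tracker
        iou_batch(dets, trks).shape               # (2, 1), similarity in [0, 1]
        ct_dist(dets, trks).shape                 # (2, 1), rescaled to (0, 1)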
+ """ + bboxes2 = np.expand_dims(bboxes2, 0) + bboxes1 = np.expand_dims(bboxes1, 1) + + centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0 + centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0 + centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0 + centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0 + + ct_dist2 = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2 + + ct_dist = np.sqrt(ct_dist2) + + # The linear rescaling is a naive version and needs more study + ct_dist = ct_dist / ct_dist.max() + return ct_dist.max() - ct_dist # resize to (0,1) + + +def speed_direction_batch(dets, tracks): + tracks = tracks[..., np.newaxis] + CX1, CY1 = (dets[:, 0] + dets[:, 2]) / 2.0, (dets[:, 1] + dets[:, 3]) / 2.0 + CX2, CY2 = (tracks[:, 0] + tracks[:, 2]) / 2.0, (tracks[:, 1] + tracks[:, 3]) / 2.0 + dx = CX1 - CX2 + dy = CY1 - CY2 + norm = np.sqrt(dx**2 + dy**2) + 1e-6 + dx = dx / norm + dy = dy / norm + return dy, dx # size: num_track x num_det + + +def linear_assignment(cost_matrix): + try: + import lap + + _, x, y = lap.lapjv(cost_matrix, extend_cost=True) + return np.array([[y[i], i] for i in x if i >= 0]) # + except ImportError: + from scipy.optimize import linear_sum_assignment + + x, y = linear_sum_assignment(cost_matrix) + return np.array(list(zip(x, y))) + + +def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3): + """ + Assigns detections to tracked object (both represented as bounding boxes) + Returns 3 lists of matches, unmatched_detections and unmatched_trackers + """ + if len(trackers) == 0: + return ( + np.empty((0, 2), dtype=int), + np.arange(len(detections)), + np.empty((0, 5), dtype=int), + ) + + iou_matrix = iou_batch(detections, trackers) + + if min(iou_matrix.shape) > 0: + a = (iou_matrix > iou_threshold).astype(np.int32) + if a.sum(1).max() == 1 and a.sum(0).max() == 1: + matched_indices = np.stack(np.where(a), axis=1) + else: + matched_indices = linear_assignment(-iou_matrix) + else: + matched_indices = np.empty(shape=(0, 2)) + + unmatched_detections = [] + for d, det in enumerate(detections): + if d not in matched_indices[:, 0]: + unmatched_detections.append(d) + unmatched_trackers = [] + for t, trk in enumerate(trackers): + if t not in matched_indices[:, 1]: + unmatched_trackers.append(t) + + # filter out matched with low IOU + matches = [] + for m in matched_indices: + if iou_matrix[m[0], m[1]] < iou_threshold: + unmatched_detections.append(m[0]) + unmatched_trackers.append(m[1]) + else: + matches.append(m.reshape(1, 2)) + if len(matches) == 0: + matches = np.empty((0, 2), dtype=int) + else: + matches = np.concatenate(matches, axis=0) + + return matches, np.array(unmatched_detections), np.array(unmatched_trackers) + + +def compute_aw_max_metric(emb_cost, w_association_emb, bottom=0.5): + w_emb = np.full_like(emb_cost, w_association_emb) + + for idx in range(emb_cost.shape[0]): + inds = np.argsort(-emb_cost[idx]) + # If there's less than two matches, just keep original weight + if len(inds) < 2: + continue + if emb_cost[idx, inds[0]] == 0: + row_weight = 0 + else: + row_weight = 1 - max((emb_cost[idx, inds[1]] / emb_cost[idx, inds[0]]) - bottom, 0) / (1 - bottom) + w_emb[idx] *= row_weight + + for idj in range(emb_cost.shape[1]): + inds = np.argsort(-emb_cost[:, idj]) + # If there's less than two matches, just keep original weight + if len(inds) < 2: + continue + if emb_cost[inds[0], idj] == 0: + col_weight = 0 + else: + col_weight = 1 - max((emb_cost[inds[1], idj] / emb_cost[inds[0], idj]) - bottom, 0) / (1 - bottom) + w_emb[:, idj] *= 
col_weight + + return w_emb * emb_cost + + +def associate( + detections, trackers, iou_threshold, velocities, previous_obs, vdc_weight, emb_cost, w_assoc_emb, aw_off, aw_param +): + if len(trackers) == 0: + return ( + np.empty((0, 2), dtype=int), + np.arange(len(detections)), + np.empty((0, 5), dtype=int), + ) + + Y, X = speed_direction_batch(detections, previous_obs) + inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1] + inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1) + inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1) + diff_angle_cos = inertia_X * X + inertia_Y * Y + diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1) + diff_angle = np.arccos(diff_angle_cos) + diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi + + valid_mask = np.ones(previous_obs.shape[0]) + valid_mask[np.where(previous_obs[:, 4] < 0)] = 0 + + iou_matrix = iou_batch(detections, trackers) + scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1) + # iou_matrix = iou_matrix * scores # a trick sometiems works, we don't encourage this + valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1) + + angle_diff_cost = (valid_mask * diff_angle) * vdc_weight + angle_diff_cost = angle_diff_cost.T + angle_diff_cost = angle_diff_cost * scores + + if min(iou_matrix.shape) > 0: + a = (iou_matrix > iou_threshold).astype(np.int32) + if a.sum(1).max() == 1 and a.sum(0).max() == 1: + matched_indices = np.stack(np.where(a), axis=1) + else: + if emb_cost is None: + emb_cost = 0 + else: + emb_cost = emb_cost.cpu().numpy() + emb_cost[iou_matrix <= 0] = 0 + if not aw_off: + emb_cost = compute_aw_max_metric(emb_cost, w_assoc_emb, bottom=aw_param) + else: + emb_cost *= w_assoc_emb + + final_cost = -(iou_matrix + angle_diff_cost + emb_cost) + matched_indices = linear_assignment(final_cost) + else: + matched_indices = np.empty(shape=(0, 2)) + + unmatched_detections = [] + for d, det in enumerate(detections): + if d not in matched_indices[:, 0]: + unmatched_detections.append(d) + unmatched_trackers = [] + for t, trk in enumerate(trackers): + if t not in matched_indices[:, 1]: + unmatched_trackers.append(t) + + # filter out matched with low IOU + matches = [] + for m in matched_indices: + if iou_matrix[m[0], m[1]] < iou_threshold: + unmatched_detections.append(m[0]) + unmatched_trackers.append(m[1]) + else: + matches.append(m.reshape(1, 2)) + if len(matches) == 0: + matches = np.empty((0, 2), dtype=int) + else: + matches = np.concatenate(matches, axis=0) + + return matches, np.array(unmatched_detections), np.array(unmatched_trackers) + + +def associate_kitti(detections, trackers, det_cates, iou_threshold, velocities, previous_obs, vdc_weight): + if len(trackers) == 0: + return ( + np.empty((0, 2), dtype=int), + np.arange(len(detections)), + np.empty((0, 5), dtype=int), + ) + + """ + Cost from the velocity direction consistency + """ + Y, X = speed_direction_batch(detections, previous_obs) + inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1] + inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1) + inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1) + diff_angle_cos = inertia_X * X + inertia_Y * Y + diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1) + diff_angle = np.arccos(diff_angle_cos) + diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi + + valid_mask = np.ones(previous_obs.shape[0]) + valid_mask[np.where(previous_obs[:, 4] < 0)] = 0 + valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], 
axis=1) + + scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1) + angle_diff_cost = (valid_mask * diff_angle) * vdc_weight + angle_diff_cost = angle_diff_cost.T + angle_diff_cost = angle_diff_cost * scores + + """ + Cost from IoU + """ + iou_matrix = iou_batch(detections, trackers) + + """ + With multiple categories, generate the cost for catgory mismatch + """ + num_dets = detections.shape[0] + num_trk = trackers.shape[0] + cate_matrix = np.zeros((num_dets, num_trk)) + for i in range(num_dets): + for j in range(num_trk): + if det_cates[i] != trackers[j, 4]: + cate_matrix[i][j] = -1e6 + + cost_matrix = -iou_matrix - angle_diff_cost - cate_matrix + + if min(iou_matrix.shape) > 0: + a = (iou_matrix > iou_threshold).astype(np.int32) + if a.sum(1).max() == 1 and a.sum(0).max() == 1: + matched_indices = np.stack(np.where(a), axis=1) + else: + matched_indices = linear_assignment(cost_matrix) + else: + matched_indices = np.empty(shape=(0, 2)) + + unmatched_detections = [] + for d, det in enumerate(detections): + if d not in matched_indices[:, 0]: + unmatched_detections.append(d) + unmatched_trackers = [] + for t, trk in enumerate(trackers): + if t not in matched_indices[:, 1]: + unmatched_trackers.append(t) + + # filter out matched with low IOU + matches = [] + for m in matched_indices: + if iou_matrix[m[0], m[1]] < iou_threshold: + unmatched_detections.append(m[0]) + unmatched_trackers.append(m[1]) + else: + matches.append(m.reshape(1, 2)) + if len(matches) == 0: + matches = np.empty((0, 2), dtype=int) + else: + matches = np.concatenate(matches, axis=0) + + return matches, np.array(unmatched_detections), np.array(unmatched_trackers) diff --git a/feeder/trackers/deepocsort/cmc.py b/feeder/trackers/deepocsort/cmc.py new file mode 100644 index 0000000..13d771f --- /dev/null +++ b/feeder/trackers/deepocsort/cmc.py @@ -0,0 +1,170 @@ +import pdb +import pickle +import os + +import cv2 +import numpy as np + + +class CMCComputer: + def __init__(self, minimum_features=10, method="sparse"): + assert method in ["file", "sparse", "sift"] + + os.makedirs("./cache", exist_ok=True) + self.cache_path = "./cache/affine_ocsort.pkl" + self.cache = {} + if os.path.exists(self.cache_path): + with open(self.cache_path, "rb") as fp: + self.cache = pickle.load(fp) + self.minimum_features = minimum_features + self.prev_img = None + self.prev_desc = None + self.sparse_flow_param = dict( + maxCorners=3000, + qualityLevel=0.01, + minDistance=1, + blockSize=3, + useHarrisDetector=False, + k=0.04, + ) + self.file_computed = {} + + self.comp_function = None + if method == "sparse": + self.comp_function = self._affine_sparse_flow + elif method == "sift": + self.comp_function = self._affine_sift + # Same BoT-SORT CMC arrays + elif method == "file": + self.comp_function = self._affine_file + self.file_affines = {} + # Maps from tag name to file name + self.file_names = {} + + # All the ablation file names + for f_name in os.listdir("./cache/cmc_files/MOT17_ablation/"): + # The tag that'll be passed into compute_affine based on image name + tag = f_name.replace("GMC-", "").replace(".txt", "") + "-FRCNN" + f_name = os.path.join("./cache/cmc_files/MOT17_ablation/", f_name) + self.file_names[tag] = f_name + for f_name in os.listdir("./cache/cmc_files/MOT20_ablation/"): + tag = f_name.replace("GMC-", "").replace(".txt", "") + f_name = os.path.join("./cache/cmc_files/MOT20_ablation/", f_name) + self.file_names[tag] = f_name + + # All the test file names + for f_name in 
os.listdir("./cache/cmc_files/MOTChallenge/"): + tag = f_name.replace("GMC-", "").replace(".txt", "") + if "MOT17" in tag: + tag = tag + "-FRCNN" + # If it's an ablation one (not test) don't overwrite it + if tag in self.file_names: + continue + f_name = os.path.join("./cache/cmc_files/MOTChallenge/", f_name) + self.file_names[tag] = f_name + + def compute_affine(self, img, bbox, tag): + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + if tag in self.cache: + A = self.cache[tag] + return A + mask = np.ones_like(img, dtype=np.uint8) + if bbox.shape[0] > 0: + bbox = np.round(bbox).astype(np.int32) + bbox[bbox < 0] = 0 + for bb in bbox: + mask[bb[1] : bb[3], bb[0] : bb[2]] = 0 + + A = self.comp_function(img, mask, tag) + self.cache[tag] = A + + return A + + def _load_file(self, name): + affines = [] + with open(self.file_names[name], "r") as fp: + for line in fp: + tokens = [float(f) for f in line.split("\t")[1:7]] + A = np.eye(2, 3) + A[0, 0] = tokens[0] + A[0, 1] = tokens[1] + A[0, 2] = tokens[2] + A[1, 0] = tokens[3] + A[1, 1] = tokens[4] + A[1, 2] = tokens[5] + affines.append(A) + self.file_affines[name] = affines + + def _affine_file(self, frame, mask, tag): + name, num = tag.split(":") + if name not in self.file_affines: + self._load_file(name) + if name not in self.file_affines: + raise RuntimeError("Error loading file affines for CMC.") + + return self.file_affines[name][int(num) - 1] + + def _affine_sift(self, frame, mask, tag): + A = np.eye(2, 3) + detector = cv2.SIFT_create() + kp, desc = detector.detectAndCompute(frame, mask) + if self.prev_desc is None: + self.prev_desc = [kp, desc] + return A + if desc.shape[0] < self.minimum_features or self.prev_desc[1].shape[0] < self.minimum_features: + return A + + bf = cv2.BFMatcher(cv2.NORM_L2) + matches = bf.knnMatch(self.prev_desc[1], desc, k=2) + good = [] + for m, n in matches: + if m.distance < 0.7 * n.distance: + good.append(m) + + if len(good) > self.minimum_features: + src_pts = np.float32([self.prev_desc[0][m.queryIdx].pt for m in good]).reshape(-1, 1, 2) + dst_pts = np.float32([kp[m.trainIdx].pt for m in good]).reshape(-1, 1, 2) + A, _ = cv2.estimateAffinePartial2D(src_pts, dst_pts, method=cv2.RANSAC) + else: + print("Warning: not enough matching points") + if A is None: + A = np.eye(2, 3) + + self.prev_desc = [kp, desc] + return A + + def _affine_sparse_flow(self, frame, mask, tag): + # Initialize + A = np.eye(2, 3) + + # find the keypoints + keypoints = cv2.goodFeaturesToTrack(frame, mask=mask, **self.sparse_flow_param) + + # Handle first frame + if self.prev_img is None: + self.prev_img = frame + self.prev_desc = keypoints + return A + + matched_kp, status, err = cv2.calcOpticalFlowPyrLK(self.prev_img, frame, self.prev_desc, None) + matched_kp = matched_kp.reshape(-1, 2) + status = status.reshape(-1) + prev_points = self.prev_desc.reshape(-1, 2) + prev_points = prev_points[status] + curr_points = matched_kp[status] + + # Find rigid matrix + if prev_points.shape[0] > self.minimum_features: + A, _ = cv2.estimateAffinePartial2D(prev_points, curr_points, method=cv2.RANSAC) + else: + print("Warning: not enough matching points") + if A is None: + A = np.eye(2, 3) + + self.prev_img = frame + self.prev_desc = keypoints + return A + + def dump_cache(self): + with open(self.cache_path, "wb") as fp: + pickle.dump(self.cache, fp) diff --git a/feeder/trackers/deepocsort/configs/deepocsort.yaml b/feeder/trackers/deepocsort/configs/deepocsort.yaml new file mode 100644 index 0000000..dfa34fa --- /dev/null +++ 
b/feeder/trackers/deepocsort/configs/deepocsort.yaml @@ -0,0 +1,12 @@ +# Trial number: 137 +# HOTA, MOTA, IDF1: [55.567] +deepocsort: + asso_func: giou + conf_thres: 0.5122620708221085 + delta_t: 1 + det_thresh: 0 + inertia: 0.3941737016672115 + iou_thresh: 0.22136877277096445 + max_age: 50 + min_hits: 1 + use_byte: false diff --git a/feeder/trackers/deepocsort/embedding.py b/feeder/trackers/deepocsort/embedding.py new file mode 100644 index 0000000..bbef156 --- /dev/null +++ b/feeder/trackers/deepocsort/embedding.py @@ -0,0 +1,116 @@ +import pdb +from collections import OrderedDict +import os +import pickle + +import torch +import cv2 +import torchvision +import numpy as np + + + +class EmbeddingComputer: + def __init__(self, dataset): + self.model = None + self.dataset = dataset + self.crop_size = (128, 384) + os.makedirs("./cache/embeddings/", exist_ok=True) + self.cache_path = "./cache/embeddings/{}_embedding.pkl" + self.cache = {} + self.cache_name = "" + + def load_cache(self, path): + self.cache_name = path + cache_path = self.cache_path.format(path) + if os.path.exists(cache_path): + with open(cache_path, "rb") as fp: + self.cache = pickle.load(fp) + + def compute_embedding(self, img, bbox, tag, is_numpy=True): + if self.cache_name != tag.split(":")[0]: + self.load_cache(tag.split(":")[0]) + + if tag in self.cache: + embs = self.cache[tag] + if embs.shape[0] != bbox.shape[0]: + raise RuntimeError( + "ERROR: The number of cached embeddings don't match the " + "number of detections.\nWas the detector model changed? Delete cache if so." + ) + return embs + + if self.model is None: + self.initialize_model() + + # Make sure bbox is within image frame + if is_numpy: + h, w = img.shape[:2] + else: + h, w = img.shape[2:] + results = np.round(bbox).astype(np.int32) + results[:, 0] = results[:, 0].clip(0, w) + results[:, 1] = results[:, 1].clip(0, h) + results[:, 2] = results[:, 2].clip(0, w) + results[:, 3] = results[:, 3].clip(0, h) + + # Generate all the crops + crops = [] + for p in results: + if is_numpy: + crop = img[p[1] : p[3], p[0] : p[2]] + crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB) + crop = cv2.resize(crop, self.crop_size, interpolation=cv2.INTER_LINEAR) + crop = torch.as_tensor(crop.astype("float32").transpose(2, 0, 1)) + crop = crop.unsqueeze(0) + else: + crop = img[:, :, p[1] : p[3], p[0] : p[2]] + crop = torchvision.transforms.functional.resize(crop, self.crop_size) + + crops.append(crop) + + crops = torch.cat(crops, dim=0) + + # Create embeddings and l2 normalize them + with torch.no_grad(): + crops = crops.cuda() + crops = crops.half() + embs = self.model(crops) + embs = torch.nn.functional.normalize(embs) + embs = embs.cpu().numpy() + + self.cache[tag] = embs + return embs + + def initialize_model(self): + """ + model = torchreid.models.build_model(name="osnet_ain_x1_0", num_classes=2510, loss="softmax", pretrained=False) + sd = torch.load("external/weights/osnet_ain_ms_d_c.pth.tar")["state_dict"] + new_state_dict = OrderedDict() + for k, v in sd.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + # load params + model.load_state_dict(new_state_dict) + model.eval() + model.cuda() + """ + if self.dataset == "mot17": + path = "external/weights/mot17_sbs_S50.pth" + elif self.dataset == "mot20": + path = "external/weights/mot20_sbs_S50.pth" + elif self.dataset == "dance": + path = None + else: + raise RuntimeError("Need the path for a new ReID model.") + + model = FastReID(path) + model.eval() + model.cuda() + model.half() + self.model = model + + def 
dump_cache(self): + if self.cache_name: + with open(self.cache_path.format(self.cache_name), "wb") as fp: + pickle.dump(self.cache, fp) diff --git a/feeder/trackers/deepocsort/kalmanfilter.py b/feeder/trackers/deepocsort/kalmanfilter.py new file mode 100644 index 0000000..19e0427 --- /dev/null +++ b/feeder/trackers/deepocsort/kalmanfilter.py @@ -0,0 +1,1636 @@ +# -*- coding: utf-8 -*- +# pylint: disable=invalid-name, too-many-arguments, too-many-branches, +# pylint: disable=too-many-locals, too-many-instance-attributes, too-many-lines + +""" +This module implements the linear Kalman filter in both an object +oriented and procedural form. The KalmanFilter class implements +the filter by storing the various matrices in instance variables, +minimizing the amount of bookkeeping you have to do. +All Kalman filters operate with a predict->update cycle. The +predict step, implemented with the method or function predict(), +uses the state transition matrix F to predict the state in the next +time period (epoch). The state is stored as a gaussian (x, P), where +x is the state (column) vector, and P is its covariance. Covariance +matrix Q specifies the process covariance. In Bayesian terms, this +prediction is called the *prior*, which you can think of colloquially +as the estimate prior to incorporating the measurement. +The update step, implemented with the method or function `update()`, +incorporates the measurement z with covariance R, into the state +estimate (x, P). The class stores the system uncertainty in S, +the innovation (residual between prediction and measurement in +measurement space) in y, and the Kalman gain in k. The procedural +form returns these variables to you. In Bayesian terms this computes +the *posterior* - the estimate after the information from the +measurement is incorporated. +Whether you use the OO form or procedural form is up to you. If +matrices such as H, R, and F are changing each epoch, you'll probably +opt to use the procedural form. If they are unchanging, the OO +form is perhaps easier to use since you won't need to keep track +of these matrices. This is especially useful if you are implementing +banks of filters or comparing various KF designs for performance; +a trivial coding bug could lead to using the wrong sets of matrices. +This module also offers an implementation of the RTS smoother, and +other helper functions, such as log likelihood computations. +The Saver class allows you to easily save the state of the +KalmanFilter class after every update +This module expects NumPy arrays for all values that expect +arrays, although in a few cases, particularly method parameters, +it will accept types that convert to NumPy arrays, such as lists +of lists. These exceptions are documented in the method or function. +Examples +-------- +The following example constructs a constant velocity kinematic +filter, filters noisy data, and plots the results. It also demonstrates +using the Saver class to save the state of the filter at each epoch. +.. 
code-block:: Python + import matplotlib.pyplot as plt + import numpy as np + from filterpy.kalman import KalmanFilter + from filterpy.common import Q_discrete_white_noise, Saver + r_std, q_std = 2., 0.003 + cv = KalmanFilter(dim_x=2, dim_z=1) + cv.x = np.array([[0., 1.]]) # position, velocity + cv.F = np.array([[1, dt],[ [0, 1]]) + cv.R = np.array([[r_std^^2]]) + f.H = np.array([[1., 0.]]) + f.P = np.diag([.1^^2, .03^^2) + f.Q = Q_discrete_white_noise(2, dt, q_std**2) + saver = Saver(cv) + for z in range(100): + cv.predict() + cv.update([z + randn() * r_std]) + saver.save() # save the filter's state + saver.to_array() + plt.plot(saver.x[:, 0]) + # plot all of the priors + plt.plot(saver.x_prior[:, 0]) + # plot mahalanobis distance + plt.figure() + plt.plot(saver.mahalanobis) +This code implements the same filter using the procedural form + x = np.array([[0., 1.]]) # position, velocity + F = np.array([[1, dt],[ [0, 1]]) + R = np.array([[r_std^^2]]) + H = np.array([[1., 0.]]) + P = np.diag([.1^^2, .03^^2) + Q = Q_discrete_white_noise(2, dt, q_std**2) + for z in range(100): + x, P = predict(x, P, F=F, Q=Q) + x, P = update(x, P, z=[z + randn() * r_std], R=R, H=H) + xs.append(x[0, 0]) + plt.plot(xs) +For more examples see the test subdirectory, or refer to the +book cited below. In it I both teach Kalman filtering from basic +principles, and teach the use of this library in great detail. +FilterPy library. +http://github.com/rlabbe/filterpy +Documentation at: +https://filterpy.readthedocs.org +Supporting book at: +https://github.com/rlabbe/Kalman-and-Bayesian-Filters-in-Python +This is licensed under an MIT license. See the readme.MD file +for more information. +Copyright 2014-2018 Roger R Labbe Jr. +""" + +from __future__ import absolute_import, division + +import pdb +from copy import deepcopy +from math import log, exp, sqrt +import sys +import numpy as np +from numpy import dot, zeros, eye, isscalar, shape +import numpy.linalg as linalg +from filterpy.stats import logpdf +from filterpy.common import pretty_str, reshape_z + + +class KalmanFilterNew(object): + """Implements a Kalman filter. You are responsible for setting the + various state variables to reasonable values; the defaults will + not give you a functional filter. + For now the best documentation is my free book Kalman and Bayesian + Filters in Python [2]_. The test files in this directory also give you a + basic idea of use, albeit without much description. + In brief, you will first construct this object, specifying the size of + the state vector with dim_x and the size of the measurement vector that + you will be using with dim_z. These are mostly used to perform size checks + when you assign values to the various matrices. For example, if you + specified dim_z=2 and then try to assign a 3x3 matrix to R (the + measurement noise matrix you will get an assert exception because R + should be 2x2. (If for whatever reason you need to alter the size of + things midstream just use the underscore version of the matrices to + assign directly: your_filter._R = a_3x3_matrix.) + After construction the filter will have default matrices created for you, + but you must specify the values for each. It’s usually easiest to just + overwrite them rather than assign to each element yourself. This will be + clearer in the example below. All are of type numpy.array. + Examples + -------- + Here is a filter that tracks position and velocity using a sensor that only + reads position. + First construct the object with the required dimensionality. 
Here the state + (`dim_x`) has 2 coefficients (position and velocity), and the measurement + (`dim_z`) has one. In FilterPy `x` is the state, `z` is the measurement. + .. code:: + from filterpy.kalman import KalmanFilter + f = KalmanFilter (dim_x=2, dim_z=1) + Assign the initial value for the state (position and velocity). You can do this + with a two dimensional array like so: + .. code:: + f.x = np.array([[2.], # position + [0.]]) # velocity + or just use a one dimensional array, which I prefer doing. + .. code:: + f.x = np.array([2., 0.]) + Define the state transition matrix: + .. code:: + f.F = np.array([[1.,1.], + [0.,1.]]) + Define the measurement function. Here we need to convert a position-velocity + vector into just a position vector, so we use: + .. code:: + f.H = np.array([[1., 0.]]) + Define the state's covariance matrix P. + .. code:: + f.P = np.array([[1000., 0.], + [ 0., 1000.] ]) + Now assign the measurement noise. Here the dimension is 1x1, so I can + use a scalar + .. code:: + f.R = 5 + I could have done this instead: + .. code:: + f.R = np.array([[5.]]) + Note that this must be a 2 dimensional array. + Finally, I will assign the process noise. Here I will take advantage of + another FilterPy library function: + .. code:: + from filterpy.common import Q_discrete_white_noise + f.Q = Q_discrete_white_noise(dim=2, dt=0.1, var=0.13) + Now just perform the standard predict/update loop: + .. code:: + while some_condition_is_true: + z = get_sensor_reading() + f.predict() + f.update(z) + do_something_with_estimate (f.x) + **Procedural Form** + This module also contains stand alone functions to perform Kalman filtering. + Use these if you are not a fan of objects. + **Example** + .. code:: + while True: + z, R = read_sensor() + x, P = predict(x, P, F, Q) + x, P = update(x, P, z, R, H) + See my book Kalman and Bayesian Filters in Python [2]_. + You will have to set the following attributes after constructing this + object for the filter to perform properly. Please note that there are + various checks in place to ensure that you have made everything the + 'correct' size. However, it is possible to provide incorrectly sized + arrays such that the linear algebra can not perform an operation. + It can also fail silently - you can end up with matrices of a size that + allows the linear algebra to work, but are the wrong shape for the problem + you are trying to solve. + Parameters + ---------- + dim_x : int + Number of state variables for the Kalman filter. For example, if + you are tracking the position and velocity of an object in two + dimensions, dim_x would be 4. + This is used to set the default size of P, Q, and u + dim_z : int + Number of of measurement inputs. For example, if the sensor + provides you with position in (x,y), dim_z would be 2. + dim_u : int (optional) + size of the control input, if it is being used. + Default value of 0 indicates it is not used. + compute_log_likelihood : bool (default = True) + Computes log likelihood by default, but this can be a slow + computation, so if you never use it you can turn this computation + off. + Attributes + ---------- + x : numpy.array(dim_x, 1) + Current state estimate. Any call to update() or predict() updates + this variable. + P : numpy.array(dim_x, dim_x) + Current state covariance matrix. Any call to update() or predict() + updates this variable. + x_prior : numpy.array(dim_x, 1) + Prior (predicted) state estimate. 
The *_prior and *_post attributes + are for convenience; they store the prior and posterior of the + current epoch. Read Only. + P_prior : numpy.array(dim_x, dim_x) + Prior (predicted) state covariance matrix. Read Only. + x_post : numpy.array(dim_x, 1) + Posterior (updated) state estimate. Read Only. + P_post : numpy.array(dim_x, dim_x) + Posterior (updated) state covariance matrix. Read Only. + z : numpy.array + Last measurement used in update(). Read only. + R : numpy.array(dim_z, dim_z) + Measurement noise covariance matrix. Also known as the + observation covariance. + Q : numpy.array(dim_x, dim_x) + Process noise covariance matrix. Also known as the transition + covariance. + F : numpy.array() + State Transition matrix. Also known as `A` in some formulation. + H : numpy.array(dim_z, dim_x) + Measurement function. Also known as the observation matrix, or as `C`. + y : numpy.array + Residual of the update step. Read only. + K : numpy.array(dim_x, dim_z) + Kalman gain of the update step. Read only. + S : numpy.array + System uncertainty (P projected to measurement space). Read only. + SI : numpy.array + Inverse system uncertainty. Read only. + log_likelihood : float + log-likelihood of the last measurement. Read only. + likelihood : float + likelihood of last measurement. Read only. + Computed from the log-likelihood. The log-likelihood can be very + small, meaning a large negative value such as -28000. Taking the + exp() of that results in 0.0, which can break typical algorithms + which multiply by this value, so by default we always return a + number >= sys.float_info.min. + mahalanobis : float + mahalanobis distance of the innovation. Read only. + inv : function, default numpy.linalg.inv + If you prefer another inverse function, such as the Moore-Penrose + pseudo inverse, set it to that instead: kf.inv = np.linalg.pinv + This is only used to invert self.S. If you know it is diagonal, you + might choose to set it to filterpy.common.inv_diagonal, which is + several times faster than numpy.linalg.inv for diagonal matrices. + alpha : float + Fading memory setting. 1.0 gives the normal Kalman filter, and + values slightly larger than 1.0 (such as 1.02) give a fading + memory effect - previous measurements have less influence on the + filter's estimates. This formulation of the Fading memory filter + (there are many) is due to Dan Simon [1]_. + References + ---------- + .. [1] Dan Simon. "Optimal State Estimation." John Wiley & Sons. + p. 208-212. (2006) + .. [2] Roger Labbe. "Kalman and Bayesian Filters in Python" + https://github.com/rlabbe/Kalman-and-Bayesian-Filters-in-Python + """ + + def __init__(self, dim_x, dim_z, dim_u=0): + if dim_x < 1: + raise ValueError("dim_x must be 1 or greater") + if dim_z < 1: + raise ValueError("dim_z must be 1 or greater") + if dim_u < 0: + raise ValueError("dim_u must be 0 or greater") + + self.dim_x = dim_x + self.dim_z = dim_z + self.dim_u = dim_u + + self.x = zeros((dim_x, 1)) # state + self.P = eye(dim_x) # uncertainty covariance + self.Q = eye(dim_x) # process uncertainty + self.B = None # control transition matrix + self.F = eye(dim_x) # state transition matrix + self.H = zeros((dim_z, dim_x)) # measurement function + self.R = eye(dim_z) # measurement uncertainty + self._alpha_sq = 1.0 # fading memory control + self.M = np.zeros((dim_x, dim_z)) # process-measurement cross correlation + self.z = np.array([[None] * self.dim_z]).T + + # gain and residual are computed during the innovation step. 
We + # save them so that in case you want to inspect them for various + # purposes + self.K = np.zeros((dim_x, dim_z)) # kalman gain + self.y = zeros((dim_z, 1)) + self.S = np.zeros((dim_z, dim_z)) # system uncertainty + self.SI = np.zeros((dim_z, dim_z)) # inverse system uncertainty + + # identity matrix. Do not alter this. + self._I = np.eye(dim_x) + + # these will always be a copy of x,P after predict() is called + self.x_prior = self.x.copy() + self.P_prior = self.P.copy() + + # these will always be a copy of x,P after update() is called + self.x_post = self.x.copy() + self.P_post = self.P.copy() + + # Only computed only if requested via property + self._log_likelihood = log(sys.float_info.min) + self._likelihood = sys.float_info.min + self._mahalanobis = None + + # keep all observations + self.history_obs = [] + + self.inv = np.linalg.inv + + self.attr_saved = None + self.observed = False + self.last_measurement = None + + def predict(self, u=None, B=None, F=None, Q=None): + """ + Predict next state (prior) using the Kalman filter state propagation + equations. + Parameters + ---------- + u : np.array, default 0 + Optional control vector. + B : np.array(dim_x, dim_u), or None + Optional control transition matrix; a value of None + will cause the filter to use `self.B`. + F : np.array(dim_x, dim_x), or None + Optional state transition matrix; a value of None + will cause the filter to use `self.F`. + Q : np.array(dim_x, dim_x), scalar, or None + Optional process noise matrix; a value of None will cause the + filter to use `self.Q`. + """ + + if B is None: + B = self.B + if F is None: + F = self.F + if Q is None: + Q = self.Q + elif isscalar(Q): + Q = eye(self.dim_x) * Q + + # x = Fx + Bu + if B is not None and u is not None: + self.x = dot(F, self.x) + dot(B, u) + else: + self.x = dot(F, self.x) + + # P = FPF' + Q + self.P = self._alpha_sq * dot(dot(F, self.P), F.T) + Q + + # save prior + self.x_prior = self.x.copy() + self.P_prior = self.P.copy() + + def freeze(self): + """ + Save the parameters before non-observation forward + """ + self.attr_saved = deepcopy(self.__dict__) + + def apply_affine_correction(self, m, t, new_kf): + """ + Apply to both last state and last observation for OOS smoothing. + + Messy due to internal logic for kalman filter being messy. 
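        Sketch of a typical call (variable names assumed for illustration; the
        real caller lives in the tracker code, which is not part of this file):
        the 2x3 affine A estimated by the CMC module is split into its 2x2
        linear part m and 2x1 translation t before being applied here.

            A = cmc.compute_affine(frame, dets[:, :4], tag)   # 2x3, e.g. from CMCComputer
            m, t = A[:, :2], A[:, 2].reshape(2, 1)
            kf.apply_affine_correction(m, t, new_kf=False)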
+ """ + if new_kf: + big_m = np.kron(np.eye(4, dtype=float), m) + self.x = big_m @ self.x + self.x[:2] += t + self.P = big_m @ self.P @ big_m.T + + # If frozen, also need to update the frozen state for OOS + if not self.observed and self.attr_saved is not None: + self.attr_saved["x"] = big_m @ self.attr_saved["x"] + self.attr_saved["x"][:2] += t + self.attr_saved["P"] = big_m @ self.attr_saved["P"] @ big_m.T + self.attr_saved["last_measurement"][:2] = m @ self.attr_saved["last_measurement"][:2] + t + self.attr_saved["last_measurement"][2:] = m @ self.attr_saved["last_measurement"][2:] + else: + scale = np.linalg.norm(m[:, 0]) + self.x[:2] = m @ self.x[:2] + t + self.x[4:6] = m @ self.x[4:6] + # self.x[2] *= scale + # self.x[6] *= scale + + self.P[:2, :2] = m @ self.P[:2, :2] @ m.T + self.P[4:6, 4:6] = m @ self.P[4:6, 4:6] @ m.T + # self.P[2, 2] *= 2 * scale + # self.P[6, 6] *= 2 * scale + + # If frozen, also need to update the frozen state for OOS + if not self.observed and self.attr_saved is not None: + self.attr_saved["x"][:2] = m @ self.attr_saved["x"][:2] + t + self.attr_saved["x"][4:6] = m @ self.attr_saved["x"][4:6] + # self.attr_saved["x"][2] *= scale + # self.attr_saved["x"][6] *= scale + + self.attr_saved["P"][:2, :2] = m @ self.attr_saved["P"][:2, :2] @ m.T + self.attr_saved["P"][4:6, 4:6] = m @ self.attr_saved["P"][4:6, 4:6] @ m.T + # self.attr_saved["P"][2, 2] *= 2 * scale + # self.attr_saved["P"][6, 6] *= 2 * scale + + self.attr_saved["last_measurement"][:2] = m @ self.attr_saved["last_measurement"][:2] + t + # self.attr_saved["last_measurement"][2] *= scale + + def unfreeze(self): + if self.attr_saved is not None: + new_history = deepcopy(self.history_obs) + self.__dict__ = self.attr_saved + # self.history_obs = new_history + self.history_obs = self.history_obs[:-1] + occur = [int(d is None) for d in new_history] + indices = np.where(np.array(occur) == 0)[0] + index1 = indices[-2] + index2 = indices[-1] + # box1 = new_history[index1] + box1 = self.last_measurement + x1, y1, s1, r1 = box1 + w1 = np.sqrt(s1 * r1) + h1 = np.sqrt(s1 / r1) + box2 = new_history[index2] + x2, y2, s2, r2 = box2 + w2 = np.sqrt(s2 * r2) + h2 = np.sqrt(s2 / r2) + time_gap = index2 - index1 + dx = (x2 - x1) / time_gap + dy = (y2 - y1) / time_gap + dw = (w2 - w1) / time_gap + dh = (h2 - h1) / time_gap + for i in range(index2 - index1): + """ + The default virtual trajectory generation is by linear + motion (constant speed hypothesis), you could modify this + part to implement your own. + """ + x = x1 + (i + 1) * dx + y = y1 + (i + 1) * dy + w = w1 + (i + 1) * dw + h = h1 + (i + 1) * dh + s = w * h + r = w / float(h) + new_box = np.array([x, y, s, r]).reshape((4, 1)) + """ + I still use predict-update loop here to refresh the parameters, + but this can be faster by directly modifying the internal parameters + as suggested in the paper. I keep this naive but slow way for + easy read and understanding + """ + self.update(new_box) + if not i == (index2 - index1 - 1): + self.predict() + + def update(self, z, R=None, H=None): + """ + Add a new measurement (z) to the Kalman filter. + If z is None, nothing is computed. However, x_post and P_post are + updated with the prior (x_prior, P_prior), and self.z is set to None. + Parameters + ---------- + z : (dim_z, 1): array_like + measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be convertible to a column vector. + If you pass in a value of H, z must be a column vector the + of the correct size. 
+ R : np.array, scalar, or None + Optionally provide R to override the measurement noise for this + one call, otherwise self.R will be used. + H : np.array, or None + Optionally provide H to override the measurement function for this + one call, otherwise self.H will be used. + """ + # set to None to force recompute + self._log_likelihood = None + self._likelihood = None + self._mahalanobis = None + + # append the observation + self.history_obs.append(z) + + if z is None: + if self.observed: + """ + Got no observation so freeze the current parameters for future + potential online smoothing. + """ + self.last_measurement = self.history_obs[-2] + self.freeze() + self.observed = False + self.z = np.array([[None] * self.dim_z]).T + self.x_post = self.x.copy() + self.P_post = self.P.copy() + self.y = zeros((self.dim_z, 1)) + return + + # self.observed = True + if not self.observed: + """ + Get observation, use online smoothing to re-update parameters + """ + self.unfreeze() + self.observed = True + + if R is None: + R = self.R + elif isscalar(R): + R = eye(self.dim_z) * R + + if H is None: + z = reshape_z(z, self.dim_z, self.x.ndim) + H = self.H + + # y = z - Hx + # error (residual) between measurement and prediction + self.y = z - dot(H, self.x) + + # common subexpression for speed + PHT = dot(self.P, H.T) + + # S = HPH' + R + # project system uncertainty into measurement space + self.S = dot(H, PHT) + R + self.SI = self.inv(self.S) + # K = PH'inv(S) + # map system uncertainty into kalman gain + self.K = dot(PHT, self.SI) + + # x = x + Ky + # predict new x with residual scaled by the kalman gain + self.x = self.x + dot(self.K, self.y) + + # P = (I-KH)P(I-KH)' + KRK' + # This is more numerically stable + # and works for non-optimal K vs the equation + # P = (I-KH)P usually seen in the literature. + + I_KH = self._I - dot(self.K, H) + self.P = dot(dot(I_KH, self.P), I_KH.T) + dot(dot(self.K, R), self.K.T) + + # save measurement and posterior state + self.z = deepcopy(z) + self.x_post = self.x.copy() + self.P_post = self.P.copy() + + def md_for_measurement(self, z): + """Mahalanobis distance for any measurement. + + Should be run after a prediction() call. + """ + z = reshape_z(z, self.dim_z, self.x.ndim) + H = self.H + y = z - dot(H, self.x) + md = sqrt(float(dot(dot(y.T, self.SI), y))) + return md + + def predict_steadystate(self, u=0, B=None): + """ + Predict state (prior) using the Kalman filter state propagation + equations. Only x is updated, P is left unchanged. See + update_steadstate() for a longer explanation of when to use this + method. + Parameters + ---------- + u : np.array + Optional control vector. If non-zero, it is multiplied by B + to create the control input into the system. + B : np.array(dim_x, dim_u), or None + Optional control transition matrix; a value of None + will cause the filter to use `self.B`. + """ + + if B is None: + B = self.B + + # x = Fx + Bu + if B is not None: + self.x = dot(self.F, self.x) + dot(B, u) + else: + self.x = dot(self.F, self.x) + + # save prior + self.x_prior = self.x.copy() + self.P_prior = self.P.copy() + + def update_steadystate(self, z): + """ + Add a new measurement (z) to the Kalman filter without recomputing + the Kalman gain K, the state covariance P, or the system + uncertainty S. + You can use this for LTI systems since the Kalman gain and covariance + converge to a fixed value. Precompute these and assign them explicitly, + or run the Kalman filter using the normal predict()/update(0 cycle + until they converge. 
+ The main advantage of this call is speed. We do significantly less + computation, notably avoiding a costly matrix inversion. + Use in conjunction with predict_steadystate(), otherwise P will grow + without bound. + Parameters + ---------- + z : (dim_z, 1): array_like + measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be convertible to a column vector. + Examples + -------- + >>> cv = kinematic_kf(dim=3, order=2) # 3D const velocity filter + >>> # let filter converge on representative data, then save k and P + >>> for i in range(100): + >>> cv.predict() + >>> cv.update([i, i, i]) + >>> saved_k = np.copy(cv.K) + >>> saved_P = np.copy(cv.P) + later on: + >>> cv = kinematic_kf(dim=3, order=2) # 3D const velocity filter + >>> cv.K = np.copy(saved_K) + >>> cv.P = np.copy(saved_P) + >>> for i in range(100): + >>> cv.predict_steadystate() + >>> cv.update_steadystate([i, i, i]) + """ + + # set to None to force recompute + self._log_likelihood = None + self._likelihood = None + self._mahalanobis = None + + if z is None: + self.z = np.array([[None] * self.dim_z]).T + self.x_post = self.x.copy() + self.P_post = self.P.copy() + self.y = zeros((self.dim_z, 1)) + return + + z = reshape_z(z, self.dim_z, self.x.ndim) + + # y = z - Hx + # error (residual) between measurement and prediction + self.y = z - dot(self.H, self.x) + + # x = x + Ky + # predict new x with residual scaled by the kalman gain + self.x = self.x + dot(self.K, self.y) + + self.z = deepcopy(z) + self.x_post = self.x.copy() + self.P_post = self.P.copy() + + # set to None to force recompute + self._log_likelihood = None + self._likelihood = None + self._mahalanobis = None + + def update_correlated(self, z, R=None, H=None): + """Add a new measurement (z) to the Kalman filter assuming that + process noise and measurement noise are correlated as defined in + the `self.M` matrix. + A partial derivation can be found in [1] + If z is None, nothing is changed. + Parameters + ---------- + z : (dim_z, 1): array_like + measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be convertible to a column vector. + R : np.array, scalar, or None + Optionally provide R to override the measurement noise for this + one call, otherwise self.R will be used. + H : np.array, or None + Optionally provide H to override the measurement function for this + one call, otherwise self.H will be used. + References + ---------- + .. [1] Bulut, Y. (2011). Applied Kalman filter theory (Doctoral dissertation, Northeastern University). + http://people.duke.edu/~hpgavin/SystemID/References/Balut-KalmanFilter-PhD-NEU-2011.pdf + """ + + # set to None to force recompute + self._log_likelihood = None + self._likelihood = None + self._mahalanobis = None + + if z is None: + self.z = np.array([[None] * self.dim_z]).T + self.x_post = self.x.copy() + self.P_post = self.P.copy() + self.y = zeros((self.dim_z, 1)) + return + + if R is None: + R = self.R + elif isscalar(R): + R = eye(self.dim_z) * R + + # rename for readability and a tiny extra bit of speed + if H is None: + z = reshape_z(z, self.dim_z, self.x.ndim) + H = self.H + + # handle special case: if z is in form [[z]] but x is not a column + # vector dimensions will not match + if self.x.ndim == 1 and shape(z) == (1, 1): + z = z[0] + + if shape(z) == (): # is it scalar, e.g. 
z=3 or z=np.array(3) + z = np.asarray([z]) + + # y = z - Hx + # error (residual) between measurement and prediction + self.y = z - dot(H, self.x) + + # common subexpression for speed + PHT = dot(self.P, H.T) + + # project system uncertainty into measurement space + self.S = dot(H, PHT) + dot(H, self.M) + dot(self.M.T, H.T) + R + self.SI = self.inv(self.S) + + # K = PH'inv(S) + # map system uncertainty into kalman gain + self.K = dot(PHT + self.M, self.SI) + + # x = x + Ky + # predict new x with residual scaled by the kalman gain + self.x = self.x + dot(self.K, self.y) + self.P = self.P - dot(self.K, dot(H, self.P) + self.M.T) + + self.z = deepcopy(z) + self.x_post = self.x.copy() + self.P_post = self.P.copy() + + def batch_filter( + self, + zs, + Fs=None, + Qs=None, + Hs=None, + Rs=None, + Bs=None, + us=None, + update_first=False, + saver=None, + ): + """Batch processes a sequences of measurements. + Parameters + ---------- + zs : list-like + list of measurements at each time step `self.dt`. Missing + measurements must be represented by `None`. + Fs : None, list-like, default=None + optional value or list of values to use for the state transition + matrix F. + If Fs is None then self.F is used for all epochs. + Otherwise it must contain a list-like list of F's, one for + each epoch. This allows you to have varying F per epoch. + Qs : None, np.array or list-like, default=None + optional value or list of values to use for the process error + covariance Q. + If Qs is None then self.Q is used for all epochs. + Otherwise it must contain a list-like list of Q's, one for + each epoch. This allows you to have varying Q per epoch. + Hs : None, np.array or list-like, default=None + optional list of values to use for the measurement matrix H. + If Hs is None then self.H is used for all epochs. + If Hs contains a single matrix, then it is used as H for all + epochs. + Otherwise it must contain a list-like list of H's, one for + each epoch. This allows you to have varying H per epoch. + Rs : None, np.array or list-like, default=None + optional list of values to use for the measurement error + covariance R. + If Rs is None then self.R is used for all epochs. + Otherwise it must contain a list-like list of R's, one for + each epoch. This allows you to have varying R per epoch. + Bs : None, np.array or list-like, default=None + optional list of values to use for the control transition matrix B. + If Bs is None then self.B is used for all epochs. + Otherwise it must contain a list-like list of B's, one for + each epoch. This allows you to have varying B per epoch. + us : None, np.array or list-like, default=None + optional list of values to use for the control input vector; + If us is None then None is used for all epochs (equivalent to 0, + or no control input). + Otherwise it must contain a list-like list of u's, one for + each epoch. + update_first : bool, optional, default=False + controls whether the order of operations is update followed by + predict, or predict followed by update. Default is predict->update. + saver : filterpy.common.Saver, optional + filterpy.common.Saver object. If provided, saver.save() will be + called after every epoch + Returns + ------- + means : np.array((n,dim_x,1)) + array of the state for each time step after the update. Each entry + is an np.array. In other words `means[k,:]` is the state at step + `k`. + covariance : np.array((n,dim_x,dim_x)) + array of the covariances for each time step after the update. 
+ In other words `covariance[k,:,:]` is the covariance at step `k`. + means_predictions : np.array((n,dim_x,1)) + array of the state for each time step after the predictions. Each + entry is an np.array. In other words `means[k,:]` is the state at + step `k`. + covariance_predictions : np.array((n,dim_x,dim_x)) + array of the covariances for each time step after the prediction. + In other words `covariance[k,:,:]` is the covariance at step `k`. + Examples + -------- + .. code-block:: Python + # this example demonstrates tracking a measurement where the time + # between measurement varies, as stored in dts. This requires + # that F be recomputed for each epoch. The output is then smoothed + # with an RTS smoother. + zs = [t + random.randn()*4 for t in range (40)] + Fs = [np.array([[1., dt], [0, 1]] for dt in dts] + (mu, cov, _, _) = kf.batch_filter(zs, Fs=Fs) + (xs, Ps, Ks, Pps) = kf.rts_smoother(mu, cov, Fs=Fs) + """ + + # pylint: disable=too-many-statements + n = np.size(zs, 0) + if Fs is None: + Fs = [self.F] * n + if Qs is None: + Qs = [self.Q] * n + if Hs is None: + Hs = [self.H] * n + if Rs is None: + Rs = [self.R] * n + if Bs is None: + Bs = [self.B] * n + if us is None: + us = [0] * n + + # mean estimates from Kalman Filter + if self.x.ndim == 1: + means = zeros((n, self.dim_x)) + means_p = zeros((n, self.dim_x)) + else: + means = zeros((n, self.dim_x, 1)) + means_p = zeros((n, self.dim_x, 1)) + + # state covariances from Kalman Filter + covariances = zeros((n, self.dim_x, self.dim_x)) + covariances_p = zeros((n, self.dim_x, self.dim_x)) + + if update_first: + for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)): + + self.update(z, R=R, H=H) + means[i, :] = self.x + covariances[i, :, :] = self.P + + self.predict(u=u, B=B, F=F, Q=Q) + means_p[i, :] = self.x + covariances_p[i, :, :] = self.P + + if saver is not None: + saver.save() + else: + for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)): + + self.predict(u=u, B=B, F=F, Q=Q) + means_p[i, :] = self.x + covariances_p[i, :, :] = self.P + + self.update(z, R=R, H=H) + means[i, :] = self.x + covariances[i, :, :] = self.P + + if saver is not None: + saver.save() + + return (means, covariances, means_p, covariances_p) + + def rts_smoother(self, Xs, Ps, Fs=None, Qs=None, inv=np.linalg.inv): + """ + Runs the Rauch-Tung-Striebel Kalman smoother on a set of + means and covariances computed by a Kalman filter. The usual input + would come from the output of `KalmanFilter.batch_filter()`. + Parameters + ---------- + Xs : numpy.array + array of the means (state variable x) of the output of a Kalman + filter. + Ps : numpy.array + array of the covariances of the output of a kalman filter. + Fs : list-like collection of numpy.array, optional + State transition matrix of the Kalman filter at each time step. + Optional, if not provided the filter's self.F will be used + Qs : list-like collection of numpy.array, optional + Process noise of the Kalman filter at each time step. Optional, + if not provided the filter's self.Q will be used + inv : function, default numpy.linalg.inv + If you prefer another inverse function, such as the Moore-Penrose + pseudo inverse, set it to that instead: kf.inv = np.linalg.pinv + Returns + ------- + x : numpy.ndarray + smoothed means + P : numpy.ndarray + smoothed state covariances + K : numpy.ndarray + smoother gain at each step + Pp : numpy.ndarray + Predicted state covariances + Examples + -------- + .. 
code-block:: Python + zs = [t + random.randn()*4 for t in range (40)] + (mu, cov, _, _) = kalman.batch_filter(zs) + (x, P, K, Pp) = rts_smoother(mu, cov, kf.F, kf.Q) + """ + + if len(Xs) != len(Ps): + raise ValueError("length of Xs and Ps must be the same") + + n = Xs.shape[0] + dim_x = Xs.shape[1] + + if Fs is None: + Fs = [self.F] * n + if Qs is None: + Qs = [self.Q] * n + + # smoother gain + K = zeros((n, dim_x, dim_x)) + + x, P, Pp = Xs.copy(), Ps.copy(), Ps.copy() + for k in range(n - 2, -1, -1): + Pp[k] = dot(dot(Fs[k + 1], P[k]), Fs[k + 1].T) + Qs[k + 1] + + # pylint: disable=bad-whitespace + K[k] = dot(dot(P[k], Fs[k + 1].T), inv(Pp[k])) + x[k] += dot(K[k], x[k + 1] - dot(Fs[k + 1], x[k])) + P[k] += dot(dot(K[k], P[k + 1] - Pp[k]), K[k].T) + + return (x, P, K, Pp) + + def get_prediction(self, u=None, B=None, F=None, Q=None): + """ + Predict next state (prior) using the Kalman filter state propagation + equations and returns it without modifying the object. + Parameters + ---------- + u : np.array, default 0 + Optional control vector. + B : np.array(dim_x, dim_u), or None + Optional control transition matrix; a value of None + will cause the filter to use `self.B`. + F : np.array(dim_x, dim_x), or None + Optional state transition matrix; a value of None + will cause the filter to use `self.F`. + Q : np.array(dim_x, dim_x), scalar, or None + Optional process noise matrix; a value of None will cause the + filter to use `self.Q`. + Returns + ------- + (x, P) : tuple + State vector and covariance array of the prediction. + """ + + if B is None: + B = self.B + if F is None: + F = self.F + if Q is None: + Q = self.Q + elif isscalar(Q): + Q = eye(self.dim_x) * Q + + # x = Fx + Bu + if B is not None and u is not None: + x = dot(F, self.x) + dot(B, u) + else: + x = dot(F, self.x) + + # P = FPF' + Q + P = self._alpha_sq * dot(dot(F, self.P), F.T) + Q + + return x, P + + def get_update(self, z=None): + """ + Computes the new estimate based on measurement `z` and returns it + without altering the state of the filter. + Parameters + ---------- + z : (dim_z, 1): array_like + measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be convertible to a column vector. + Returns + ------- + (x, P) : tuple + State vector and covariance array of the update. + """ + + if z is None: + return self.x, self.P + z = reshape_z(z, self.dim_z, self.x.ndim) + + R = self.R + H = self.H + P = self.P + x = self.x + + # error (residual) between measurement and prediction + y = z - dot(H, x) + + # common subexpression for speed + PHT = dot(P, H.T) + + # project system uncertainty into measurement space + S = dot(H, PHT) + R + + # map system uncertainty into kalman gain + K = dot(PHT, self.inv(S)) + + # predict new x with residual scaled by the kalman gain + x = x + dot(K, y) + + # P = (I-KH)P(I-KH)' + KRK' + I_KH = self._I - dot(K, H) + P = dot(dot(I_KH, P), I_KH.T) + dot(dot(K, R), K.T) + + return x, P + + def residual_of(self, z): + """ + Returns the residual for the given measurement (z). Does not alter + the state of the filter. + """ + z = reshape_z(z, self.dim_z, self.x.ndim) + return z - dot(self.H, self.x_prior) + + def measurement_of_state(self, x): + """ + Helper function that converts a state into a measurement. + Parameters + ---------- + x : np.array + kalman state vector + Returns + ------- + z : (dim_z, 1): array_like + measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be convertible to a column vector. 
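# Illustrative sketch (assumes the filterpy-style interface above; shown through
# filterpy.kalman.KalmanFilter, which exposes the same get_prediction()/get_update()):
# both calls return the would-be prior/posterior without mutating the filter,
# which is handy for trying a measurement before committing to it.
import numpy as np
from filterpy.kalman import KalmanFilter

kf = KalmanFilter(dim_x=2, dim_z=1)
kf.x = np.array([[0.], [1.]])                 # position, velocity
kf.F = np.array([[1., 1.], [0., 1.]])         # constant-velocity transition
kf.H = np.array([[1., 0.]])                   # only position is measured
kf.P *= 10.
kf.R *= 2.

x_prior, P_prior = kf.get_prediction()        # what predict() would give
x_post, P_post = kf.get_update(z=1.2)         # what update(1.2) would give
assert np.allclose(kf.x, [[0.], [1.]])        # filter state is untouched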
+ """ + + return dot(self.H, x) + + @property + def log_likelihood(self): + """ + log-likelihood of the last measurement. + """ + if self._log_likelihood is None: + self._log_likelihood = logpdf(x=self.y, cov=self.S) + return self._log_likelihood + + @property + def likelihood(self): + """ + Computed from the log-likelihood. The log-likelihood can be very + small, meaning a large negative value such as -28000. Taking the + exp() of that results in 0.0, which can break typical algorithms + which multiply by this value, so by default we always return a + number >= sys.float_info.min. + """ + if self._likelihood is None: + self._likelihood = exp(self.log_likelihood) + if self._likelihood == 0: + self._likelihood = sys.float_info.min + return self._likelihood + + @property + def mahalanobis(self): + """ " + Mahalanobis distance of measurement. E.g. 3 means measurement + was 3 standard deviations away from the predicted value. + Returns + ------- + mahalanobis : float + """ + if self._mahalanobis is None: + self._mahalanobis = sqrt(float(dot(dot(self.y.T, self.SI), self.y))) + return self._mahalanobis + + @property + def alpha(self): + """ + Fading memory setting. 1.0 gives the normal Kalman filter, and + values slightly larger than 1.0 (such as 1.02) give a fading + memory effect - previous measurements have less influence on the + filter's estimates. This formulation of the Fading memory filter + (there are many) is due to Dan Simon [1]_. + """ + return self._alpha_sq**0.5 + + def log_likelihood_of(self, z): + """ + log likelihood of the measurement `z`. This should only be called + after a call to update(). Calling after predict() will yield an + incorrect result.""" + + if z is None: + return log(sys.float_info.min) + return logpdf(z, dot(self.H, self.x), self.S) + + @alpha.setter + def alpha(self, value): + if not np.isscalar(value) or value < 1: + raise ValueError("alpha must be a float greater than 1") + + self._alpha_sq = value**2 + + def __repr__(self): + return "\n".join( + [ + "KalmanFilter object", + pretty_str("dim_x", self.dim_x), + pretty_str("dim_z", self.dim_z), + pretty_str("dim_u", self.dim_u), + pretty_str("x", self.x), + pretty_str("P", self.P), + pretty_str("x_prior", self.x_prior), + pretty_str("P_prior", self.P_prior), + pretty_str("x_post", self.x_post), + pretty_str("P_post", self.P_post), + pretty_str("F", self.F), + pretty_str("Q", self.Q), + pretty_str("R", self.R), + pretty_str("H", self.H), + pretty_str("K", self.K), + pretty_str("y", self.y), + pretty_str("S", self.S), + pretty_str("SI", self.SI), + pretty_str("M", self.M), + pretty_str("B", self.B), + pretty_str("z", self.z), + pretty_str("log-likelihood", self.log_likelihood), + pretty_str("likelihood", self.likelihood), + pretty_str("mahalanobis", self.mahalanobis), + pretty_str("alpha", self.alpha), + pretty_str("inv", self.inv), + ] + ) + + def test_matrix_dimensions(self, z=None, H=None, R=None, F=None, Q=None): + """ + Performs a series of asserts to check that the size of everything + is what it should be. This can help you debug problems in your design. + If you pass in H, R, F, Q those will be used instead of this object's + value for those matrices. + Testing `z` (the measurement) is problamatic. x is a vector, and can be + implemented as either a 1D array or as a nx1 column vector. Thus Hx + can be of different shapes. Then, if Hx is a single value, it can + be either a 1D array or 2D vector. 
If either is true, z can reasonably + be a scalar (either '3' or np.array('3') are scalars under this + definition), a 1D, 1 element array, or a 2D, 1 element array. You are + allowed to pass in any combination that works. + """ + + if H is None: + H = self.H + if R is None: + R = self.R + if F is None: + F = self.F + if Q is None: + Q = self.Q + x = self.x + P = self.P + + assert x.ndim == 1 or x.ndim == 2, "x must have one or two dimensions, but has {}".format(x.ndim) + + if x.ndim == 1: + assert x.shape[0] == self.dim_x, "Shape of x must be ({},{}), but is {}".format(self.dim_x, 1, x.shape) + else: + assert x.shape == ( + self.dim_x, + 1, + ), "Shape of x must be ({},{}), but is {}".format(self.dim_x, 1, x.shape) + + assert P.shape == ( + self.dim_x, + self.dim_x, + ), "Shape of P must be ({},{}), but is {}".format(self.dim_x, self.dim_x, P.shape) + + assert Q.shape == ( + self.dim_x, + self.dim_x, + ), "Shape of Q must be ({},{}), but is {}".format(self.dim_x, self.dim_x, P.shape) + + assert F.shape == ( + self.dim_x, + self.dim_x, + ), "Shape of F must be ({},{}), but is {}".format(self.dim_x, self.dim_x, F.shape) + + assert np.ndim(H) == 2, "Shape of H must be (dim_z, {}), but is {}".format(P.shape[0], shape(H)) + + assert H.shape[1] == P.shape[0], "Shape of H must be (dim_z, {}), but is {}".format(P.shape[0], H.shape) + + # shape of R must be the same as HPH' + hph_shape = (H.shape[0], H.shape[0]) + r_shape = shape(R) + + if H.shape[0] == 1: + # r can be scalar, 1D, or 2D in this case + assert r_shape in [ + (), + (1,), + (1, 1), + ], "R must be scalar or one element array, but is shaped {}".format(r_shape) + else: + assert r_shape == hph_shape, "shape of R should be {} but it is {}".format(hph_shape, r_shape) + + if z is not None: + z_shape = shape(z) + else: + z_shape = (self.dim_z, 1) + + # H@x must have shape of z + Hx = dot(H, x) + + if z_shape == (): # scalar or np.array(scalar) + assert Hx.ndim == 1 or shape(Hx) == ( + 1, + 1, + ), "shape of z should be {}, not {} for the given H".format(shape(Hx), z_shape) + + elif shape(Hx) == (1,): + assert z_shape[0] == 1, "Shape of z must be {} for the given H".format(shape(Hx)) + + else: + assert z_shape == shape(Hx) or ( + len(z_shape) == 1 and shape(Hx) == (z_shape[0], 1) + ), "shape of z should be {}, not {} for the given H".format(shape(Hx), z_shape) + + if np.ndim(Hx) > 1 and shape(Hx) != (1, 1): + assert shape(Hx) == z_shape, "shape of z should be {} for the given H, but it is {}".format( + shape(Hx), z_shape + ) + + +def update(x, P, z, R, H=None, return_all=False): + """ + Add a new measurement (z) to the Kalman filter. If z is None, nothing + is changed. + This can handle either the multidimensional or unidimensional case. If + all parameters are floats instead of arrays the filter will still work, + and return floats for x, P as the result. + update(1, 2, 1, 1, 1) # univariate + update(x, P, 1 + Parameters + ---------- + x : numpy.array(dim_x, 1), or float + State estimate vector + P : numpy.array(dim_x, dim_x), or float + Covariance matrix + z : (dim_z, 1): array_like + measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be convertible to a column vector. + R : numpy.array(dim_z, dim_z), or float + Measurement noise matrix + H : numpy.array(dim_x, dim_x), or float, optional + Measurement function. If not provided, a value of 1 is assumed. + return_all : bool, default False + If true, y, K, S, and log_likelihood are returned, otherwise + only x and P are returned. 
+ Returns + ------- + x : numpy.array + Posterior state estimate vector + P : numpy.array + Posterior covariance matrix + y : numpy.array or scalar + Residua. Difference between measurement and state in measurement space + K : numpy.array + Kalman gain + S : numpy.array + System uncertainty in measurement space + log_likelihood : float + log likelihood of the measurement + """ + + # pylint: disable=bare-except + + if z is None: + if return_all: + return x, P, None, None, None, None + return x, P + + if H is None: + H = np.array([1]) + + if np.isscalar(H): + H = np.array([H]) + + Hx = np.atleast_1d(dot(H, x)) + z = reshape_z(z, Hx.shape[0], x.ndim) + + # error (residual) between measurement and prediction + y = z - Hx + + # project system uncertainty into measurement space + S = dot(dot(H, P), H.T) + R + + # map system uncertainty into kalman gain + try: + K = dot(dot(P, H.T), linalg.inv(S)) + except: + # can't invert a 1D array, annoyingly + K = dot(dot(P, H.T), 1.0 / S) + + # predict new x with residual scaled by the kalman gain + x = x + dot(K, y) + + # P = (I-KH)P(I-KH)' + KRK' + KH = dot(K, H) + + try: + I_KH = np.eye(KH.shape[0]) - KH + except: + I_KH = np.array([1 - KH]) + P = dot(dot(I_KH, P), I_KH.T) + dot(dot(K, R), K.T) + + if return_all: + # compute log likelihood + log_likelihood = logpdf(z, dot(H, x), S) + return x, P, y, K, S, log_likelihood + return x, P + + +def update_steadystate(x, z, K, H=None): + """ + Add a new measurement (z) to the Kalman filter. If z is None, nothing + is changed. + Parameters + ---------- + x : numpy.array(dim_x, 1), or float + State estimate vector + z : (dim_z, 1): array_like + measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be convertible to a column vector. + K : numpy.array, or float + Kalman gain matrix + H : numpy.array(dim_x, dim_x), or float, optional + Measurement function. If not provided, a value of 1 is assumed. + Returns + ------- + x : numpy.array + Posterior state estimate vector + Examples + -------- + This can handle either the multidimensional or unidimensional case. If + all parameters are floats instead of arrays the filter will still work, + and return floats for x, P as the result. + >>> update_steadystate(1, 2, 1) # univariate + >>> update_steadystate(x, P, z, H) + """ + + if z is None: + return x + + if H is None: + H = np.array([1]) + + if np.isscalar(H): + H = np.array([H]) + + Hx = np.atleast_1d(dot(H, x)) + z = reshape_z(z, Hx.shape[0], x.ndim) + + # error (residual) between measurement and prediction + y = z - Hx + + # estimate new x with residual scaled by the kalman gain + return x + dot(K, y) + + +def predict(x, P, F=1, Q=0, u=0, B=1, alpha=1.0): + """ + Predict next state (prior) using the Kalman filter state propagation + equations. + Parameters + ---------- + x : numpy.array + State estimate vector + P : numpy.array + Covariance matrix + F : numpy.array() + State Transition matrix + Q : numpy.array, Optional + Process noise matrix + u : numpy.array, Optional, default 0. + Control vector. If non-zero, it is multiplied by B + to create the control input into the system. + B : numpy.array, optional, default 0. + Control transition matrix. + alpha : float, Optional, default=1.0 + Fading memory setting. 1.0 gives the normal Kalman filter, and + values slightly larger than 1.0 (such as 1.02) give a fading + memory effect - previous measurements have less influence on the + filter's estimates. 
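# Illustrative sketch of the procedural helpers defined at module level in this
# file (update()/predict()): a univariate filter where x and P are plain scalars
# and H defaults to 1. Values are made up for demonstration.
x, P = 0.0, 500.0                             # state and (deliberately large) variance
for z in [1.0, 1.2, 0.9, 1.1]:
    x, P = predict(x, P, F=1.0, Q=0.1)        # constant-position model
    x, P = update(x, P, z, R=1.0)             # P shrinks as measurements arrive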
This formulation of the Fading memory filter + (there are many) is due to Dan Simon + Returns + ------- + x : numpy.array + Prior state estimate vector + P : numpy.array + Prior covariance matrix + """ + + if np.isscalar(F): + F = np.array(F) + x = dot(F, x) + dot(B, u) + P = (alpha * alpha) * dot(dot(F, P), F.T) + Q + + return x, P + + +def predict_steadystate(x, F=1, u=0, B=1): + """ + Predict next state (prior) using the Kalman filter state propagation + equations. This steady state form only computes x, assuming that the + covariance is constant. + Parameters + ---------- + x : numpy.array + State estimate vector + P : numpy.array + Covariance matrix + F : numpy.array() + State Transition matrix + u : numpy.array, Optional, default 0. + Control vector. If non-zero, it is multiplied by B + to create the control input into the system. + B : numpy.array, optional, default 0. + Control transition matrix. + Returns + ------- + x : numpy.array + Prior state estimate vector + """ + + if np.isscalar(F): + F = np.array(F) + x = dot(F, x) + dot(B, u) + + return x + + +def batch_filter(x, P, zs, Fs, Qs, Hs, Rs, Bs=None, us=None, update_first=False, saver=None): + """ + Batch processes a sequences of measurements. + Parameters + ---------- + zs : list-like + list of measurements at each time step. Missing measurements must be + represented by None. + Fs : list-like + list of values to use for the state transition matrix matrix. + Qs : list-like + list of values to use for the process error + covariance. + Hs : list-like + list of values to use for the measurement matrix. + Rs : list-like + list of values to use for the measurement error + covariance. + Bs : list-like, optional + list of values to use for the control transition matrix; + a value of None in any position will cause the filter + to use `self.B` for that time step. + us : list-like, optional + list of values to use for the control input vector; + a value of None in any position will cause the filter to use + 0 for that time step. + update_first : bool, optional + controls whether the order of operations is update followed by + predict, or predict followed by update. Default is predict->update. + saver : filterpy.common.Saver, optional + filterpy.common.Saver object. If provided, saver.save() will be + called after every epoch + Returns + ------- + means : np.array((n,dim_x,1)) + array of the state for each time step after the update. Each entry + is an np.array. In other words `means[k,:]` is the state at step + `k`. + covariance : np.array((n,dim_x,dim_x)) + array of the covariances for each time step after the update. + In other words `covariance[k,:,:]` is the covariance at step `k`. + means_predictions : np.array((n,dim_x,1)) + array of the state for each time step after the predictions. Each + entry is an np.array. In other words `means[k,:]` is the state at + step `k`. + covariance_predictions : np.array((n,dim_x,dim_x)) + array of the covariances for each time step after the prediction. + In other words `covariance[k,:,:]` is the covariance at step `k`. + Examples + -------- + .. 
code-block:: Python + zs = [t + random.randn()*4 for t in range (40)] + Fs = [kf.F for t in range (40)] + Hs = [kf.H for t in range (40)] + (mu, cov, _, _) = kf.batch_filter(zs, Rs=R_list, Fs=Fs, Hs=Hs, Qs=None, + Bs=None, us=None, update_first=False) + (xs, Ps, Ks, Pps) = kf.rts_smoother(mu, cov, Fs=Fs, Qs=None) + """ + + n = np.size(zs, 0) + dim_x = x.shape[0] + + # mean estimates from Kalman Filter + if x.ndim == 1: + means = zeros((n, dim_x)) + means_p = zeros((n, dim_x)) + else: + means = zeros((n, dim_x, 1)) + means_p = zeros((n, dim_x, 1)) + + # state covariances from Kalman Filter + covariances = zeros((n, dim_x, dim_x)) + covariances_p = zeros((n, dim_x, dim_x)) + + if us is None: + us = [0.0] * n + Bs = [0.0] * n + + if update_first: + for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)): + + x, P = update(x, P, z, R=R, H=H) + means[i, :] = x + covariances[i, :, :] = P + + x, P = predict(x, P, u=u, B=B, F=F, Q=Q) + means_p[i, :] = x + covariances_p[i, :, :] = P + if saver is not None: + saver.save() + else: + for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)): + + x, P = predict(x, P, u=u, B=B, F=F, Q=Q) + means_p[i, :] = x + covariances_p[i, :, :] = P + + x, P = update(x, P, z, R=R, H=H) + means[i, :] = x + covariances[i, :, :] = P + if saver is not None: + saver.save() + + return (means, covariances, means_p, covariances_p) + + +def rts_smoother(Xs, Ps, Fs, Qs): + """ + Runs the Rauch-Tung-Striebel Kalman smoother on a set of + means and covariances computed by a Kalman filter. The usual input + would come from the output of `KalmanFilter.batch_filter()`. + Parameters + ---------- + Xs : numpy.array + array of the means (state variable x) of the output of a Kalman + filter. + Ps : numpy.array + array of the covariances of the output of a kalman filter. + Fs : list-like collection of numpy.array + State transition matrix of the Kalman filter at each time step. + Qs : list-like collection of numpy.array, optional + Process noise of the Kalman filter at each time step. + Returns + ------- + x : numpy.ndarray + smoothed means + P : numpy.ndarray + smoothed state covariances + K : numpy.ndarray + smoother gain at each step + pP : numpy.ndarray + predicted state covariances + Examples + -------- + .. 
code-block:: Python + zs = [t + random.randn()*4 for t in range (40)] + (mu, cov, _, _) = kalman.batch_filter(zs) + (x, P, K, pP) = rts_smoother(mu, cov, kf.F, kf.Q) + """ + + if len(Xs) != len(Ps): + raise ValueError("length of Xs and Ps must be the same") + + n = Xs.shape[0] + dim_x = Xs.shape[1] + + # smoother gain + K = zeros((n, dim_x, dim_x)) + x, P, pP = Xs.copy(), Ps.copy(), Ps.copy() + + for k in range(n - 2, -1, -1): + pP[k] = dot(dot(Fs[k], P[k]), Fs[k].T) + Qs[k] + + # pylint: disable=bad-whitespace + K[k] = dot(dot(P[k], Fs[k].T), linalg.inv(pP[k])) + x[k] += dot(K[k], x[k + 1] - dot(Fs[k], x[k])) + P[k] += dot(dot(K[k], P[k + 1] - pP[k]), K[k].T) + + return (x, P, K, pP) diff --git a/feeder/trackers/deepocsort/ocsort.py b/feeder/trackers/deepocsort/ocsort.py new file mode 100644 index 0000000..a20f34a --- /dev/null +++ b/feeder/trackers/deepocsort/ocsort.py @@ -0,0 +1,670 @@ +""" + This script is adopted from the SORT script by Alex Bewley alex@bewley.ai +""" +from __future__ import print_function + +import pdb +import pickle + +import cv2 +import torch +import torchvision + +import numpy as np +from .association import * +from .embedding import EmbeddingComputer +from .cmc import CMCComputer +from reid_multibackend import ReIDDetectMultiBackend + + +def k_previous_obs(observations, cur_age, k): + if len(observations) == 0: + return [-1, -1, -1, -1, -1] + for i in range(k): + dt = k - i + if cur_age - dt in observations: + return observations[cur_age - dt] + max_age = max(observations.keys()) + return observations[max_age] + + +def convert_bbox_to_z(bbox): + """ + Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form + [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is + the aspect ratio + """ + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + x = bbox[0] + w / 2.0 + y = bbox[1] + h / 2.0 + s = w * h # scale is just area + r = w / float(h + 1e-6) + return np.array([x, y, s, r]).reshape((4, 1)) + + +def convert_bbox_to_z_new(bbox): + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + x = bbox[0] + w / 2.0 + y = bbox[1] + h / 2.0 + return np.array([x, y, w, h]).reshape((4, 1)) + + +def convert_x_to_bbox_new(x): + x, y, w, h = x.reshape(-1)[:4] + return np.array([x - w / 2, y - h / 2, x + w / 2, y + h / 2]).reshape(1, 4) + + +def convert_x_to_bbox(x, score=None): + """ + Takes a bounding box in the centre form [x,y,s,r] and returns it in the form + [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right + """ + w = np.sqrt(x[2] * x[3]) + h = x[2] / w + if score == None: + return np.array([x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0]).reshape((1, 4)) + else: + return np.array([x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0, score]).reshape((1, 5)) + + +def speed_direction(bbox1, bbox2): + cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0 + cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0 + speed = np.array([cy2 - cy1, cx2 - cx1]) + norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6 + return speed / norm + + +def new_kf_process_noise(w, h, p=1 / 20, v=1 / 160): + Q = np.diag( + ((p * w) ** 2, (p * h) ** 2, (p * w) ** 2, (p * h) ** 2, (v * w) ** 2, (v * h) ** 2, (v * w) ** 2, (v * h) ** 2) + ) + return Q + + +def new_kf_measurement_noise(w, h, m=1 / 20): + w_var = (m * w) ** 2 + h_var = (m * h) ** 2 + R = np.diag((w_var, h_var, w_var, h_var)) + return R + + +class KalmanBoxTracker(object): + """ + This class represents the internal state of 
individual tracked objects observed as bbox. + """ + + count = 0 + + def __init__(self, bbox, cls, delta_t=3, orig=False, emb=None, alpha=0, new_kf=False): + """ + Initialises a tracker using initial bounding box. + + """ + # define constant velocity model + if not orig: + from .kalmanfilter import KalmanFilterNew as KalmanFilter + else: + from filterpy.kalman import KalmanFilter + self.cls = cls + self.conf = bbox[-1] + self.new_kf = new_kf + if new_kf: + self.kf = KalmanFilter(dim_x=8, dim_z=4) + self.kf.F = np.array( + [ + # x y w h x' y' w' h' + [1, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 1, 0, 0], + [0, 0, 1, 0, 0, 0, 1, 0], + [0, 0, 0, 1, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + ] + ) + self.kf.H = np.array( + [ + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + ] + ) + _, _, w, h = convert_bbox_to_z_new(bbox).reshape(-1) + self.kf.P = new_kf_process_noise(w, h) + self.kf.P[:4, :4] *= 4 + self.kf.P[4:, 4:] *= 100 + # Process and measurement uncertainty happen in functions + self.bbox_to_z_func = convert_bbox_to_z_new + self.x_to_bbox_func = convert_x_to_bbox_new + else: + self.kf = KalmanFilter(dim_x=7, dim_z=4) + self.kf.F = np.array( + [ + # x y s r x' y' s' + [1, 0, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 0, 1], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 1], + ] + ) + self.kf.H = np.array( + [ + [1, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + ] + ) + self.kf.R[2:, 2:] *= 10.0 + self.kf.P[4:, 4:] *= 1000.0 # give high uncertainty to the unobservable initial velocities + self.kf.P *= 10.0 + self.kf.Q[-1, -1] *= 0.01 + self.kf.Q[4:, 4:] *= 0.01 + self.bbox_to_z_func = convert_bbox_to_z + self.x_to_bbox_func = convert_x_to_bbox + + self.kf.x[:4] = self.bbox_to_z_func(bbox) + + self.time_since_update = 0 + self.id = KalmanBoxTracker.count + KalmanBoxTracker.count += 1 + self.history = [] + self.hits = 0 + self.hit_streak = 0 + self.age = 0 + """ + NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of + function k_previous_obs. It is ugly and I do not like it. But to support generate observation array in a + fast and unified way, which you would see below k_observations = np.array([k_previous_obs(...]]), let's bear it for now. + """ + # Used for OCR + self.last_observation = np.array([-1, -1, -1, -1, -1]) # placeholder + # Used to output track after min_hits reached + self.history_observations = [] + # Used for velocity + self.observations = dict() + self.velocity = None + self.delta_t = delta_t + + self.emb = emb + + self.frozen = False + + def update(self, bbox, cls): + """ + Updates the state vector with observed bbox. + """ + if bbox is not None: + self.frozen = False + self.cls = cls + if self.last_observation.sum() >= 0: # no previous observation + previous_box = None + for dt in range(self.delta_t, 0, -1): + if self.age - dt in self.observations: + previous_box = self.observations[self.age - dt] + break + if previous_box is None: + previous_box = self.last_observation + """ + Estimate the track speed direction with observations \Delta t steps away + """ + self.velocity = speed_direction(previous_box, bbox) + """ + Insert new observations. This is a ugly way to maintain both self.observations + and self.history_observations. Bear it for the moment. 
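# Illustrative round trip through the two box parameterisations used by this
# tracker (functions defined at module level above); the box is arbitrary.
import numpy as np

bbox = np.array([10., 20., 50., 100.])        # x1, y1, x2, y2  (w=40, h=80)

z = convert_bbox_to_z(bbox)                   # [[30], [60], [3200], [0.5]] -> cx, cy, area, aspect
back = convert_x_to_bbox(z.reshape(-1))       # ~[[10, 20, 50, 100]] (tiny drift from the 1e-6 guard)

z_new = convert_bbox_to_z_new(bbox)           # [[30], [60], [40], [80]]    -> cx, cy, w, h
back_new = convert_x_to_bbox_new(z_new)       # exactly [[10, 20, 50, 100]]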
+ """ + self.last_observation = bbox + self.observations[self.age] = bbox + self.history_observations.append(bbox) + + self.time_since_update = 0 + self.history = [] + self.hits += 1 + self.hit_streak += 1 + if self.new_kf: + R = new_kf_measurement_noise(self.kf.x[2, 0], self.kf.x[3, 0]) + self.kf.update(self.bbox_to_z_func(bbox), R=R) + else: + self.kf.update(self.bbox_to_z_func(bbox)) + else: + self.kf.update(bbox) + self.frozen = True + + def update_emb(self, emb, alpha=0.9): + self.emb = alpha * self.emb + (1 - alpha) * emb + self.emb /= np.linalg.norm(self.emb) + + def get_emb(self): + return self.emb.cpu() + + def apply_affine_correction(self, affine): + m = affine[:, :2] + t = affine[:, 2].reshape(2, 1) + # For OCR + if self.last_observation.sum() > 0: + ps = self.last_observation[:4].reshape(2, 2).T + ps = m @ ps + t + self.last_observation[:4] = ps.T.reshape(-1) + + # Apply to each box in the range of velocity computation + for dt in range(self.delta_t, -1, -1): + if self.age - dt in self.observations: + ps = self.observations[self.age - dt][:4].reshape(2, 2).T + ps = m @ ps + t + self.observations[self.age - dt][:4] = ps.T.reshape(-1) + + # Also need to change kf state, but might be frozen + self.kf.apply_affine_correction(m, t, self.new_kf) + + def predict(self): + """ + Advances the state vector and returns the predicted bounding box estimate. + """ + # Don't allow negative bounding boxes + if self.new_kf: + if self.kf.x[2] + self.kf.x[6] <= 0: + self.kf.x[6] = 0 + if self.kf.x[3] + self.kf.x[7] <= 0: + self.kf.x[7] = 0 + + # Stop velocity, will update in kf during OOS + if self.frozen: + self.kf.x[6] = self.kf.x[7] = 0 + Q = new_kf_process_noise(self.kf.x[2, 0], self.kf.x[3, 0]) + else: + if (self.kf.x[6] + self.kf.x[2]) <= 0: + self.kf.x[6] *= 0.0 + Q = None + + self.kf.predict(Q=Q) + self.age += 1 + if self.time_since_update > 0: + self.hit_streak = 0 + self.time_since_update += 1 + self.history.append(self.x_to_bbox_func(self.kf.x)) + return self.history[-1] + + def get_state(self): + """ + Returns the current bounding box estimate. + """ + return self.x_to_bbox_func(self.kf.x) + + def mahalanobis(self, bbox): + """Should be run after a predict() call for accuracy.""" + return self.kf.md_for_measurement(self.bbox_to_z_func(bbox)) + + +""" + We support multiple ways for association cost calculation, by default + we use IoU. GIoU may have better performance in some situations. We note + that we hardly normalize the cost by all methods to (0,1) which may not be + the best practice. 
+""" +ASSO_FUNCS = { + "iou": iou_batch, + "giou": giou_batch, + "ciou": ciou_batch, + "diou": diou_batch, + "ct_dist": ct_dist, +} + + +class OCSort(object): + def __init__( + self, + model_weights, + device, + fp16, + det_thresh, + max_age=30, + min_hits=3, + iou_threshold=0.3, + delta_t=3, + asso_func="iou", + inertia=0.2, + w_association_emb=0.75, + alpha_fixed_emb=0.95, + aw_param=0.5, + embedding_off=False, + cmc_off=False, + aw_off=False, + new_kf_off=False, + **kwargs + ): + """ + Sets key parameters for SORT + """ + self.max_age = max_age + self.min_hits = min_hits + self.iou_threshold = iou_threshold + self.trackers = [] + self.frame_count = 0 + self.det_thresh = det_thresh + self.delta_t = delta_t + self.asso_func = ASSO_FUNCS[asso_func] + self.inertia = inertia + self.w_association_emb = w_association_emb + self.alpha_fixed_emb = alpha_fixed_emb + self.aw_param = aw_param + KalmanBoxTracker.count = 0 + + self.embedder = ReIDDetectMultiBackend(weights=model_weights, device=device, fp16=fp16) + self.cmc = CMCComputer() + self.embedding_off = embedding_off + self.cmc_off = cmc_off + self.aw_off = aw_off + self.new_kf_off = new_kf_off + + def update(self, dets, img_numpy, tag='blub'): + """ + Params: + dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...] + Requires: this method must be called once for each frame even with empty detections (use np.empty((0, 5)) for frames without detections). + Returns the a similar array, where the last column is the object ID. + NOTE: The number of objects returned may differ from the number of detections provided. + """ + xyxys = dets[:, 0:4] + scores = dets[:, 4] + clss = dets[:, 5] + + classes = clss.numpy() + xyxys = xyxys.numpy() + scores = scores.numpy() + + dets = dets[:, 0:6].numpy() + remain_inds = scores > self.det_thresh + dets = dets[remain_inds] + self.height, self.width = img_numpy.shape[:2] + + # Rescale + #scale = min(img_tensor.shape[2] / img_numpy.shape[0], img_tensor.shape[3] / img_numpy.shape[1]) + #dets[:, :4] /= scale + + # Embedding + if self.embedding_off or dets.shape[0] == 0: + dets_embs = np.ones((dets.shape[0], 1)) + else: + # (Ndets x X) [512, 1024, 2048] + #dets_embs = self.embedder.compute_embedding(img_numpy, dets[:, :4], tag) + dets_embs = self._get_features(dets[:, :4], img_numpy) + + # CMC + if not self.cmc_off: + transform = self.cmc.compute_affine(img_numpy, dets[:, :4], tag) + for trk in self.trackers: + trk.apply_affine_correction(transform) + + trust = (dets[:, 4] - self.det_thresh) / (1 - self.det_thresh) + af = self.alpha_fixed_emb + # From [self.alpha_fixed_emb, 1], goes to 1 as detector is less confident + dets_alpha = af + (1 - af) * (1 - trust) + + # get predicted locations from existing trackers. 
+ trks = np.zeros((len(self.trackers), 5)) + trk_embs = [] + to_del = [] + ret = [] + for t, trk in enumerate(trks): + pos = self.trackers[t].predict()[0] + trk[:] = [pos[0], pos[1], pos[2], pos[3], 0] + if np.any(np.isnan(pos)): + to_del.append(t) + else: + trk_embs.append(self.trackers[t].get_emb()) + trks = np.ma.compress_rows(np.ma.masked_invalid(trks)) + + if len(trk_embs) > 0: + trk_embs = np.vstack(trk_embs) + else: + trk_embs = np.array(trk_embs) + + for t in reversed(to_del): + self.trackers.pop(t) + + velocities = np.array([trk.velocity if trk.velocity is not None else np.array((0, 0)) for trk in self.trackers]) + last_boxes = np.array([trk.last_observation for trk in self.trackers]) + k_observations = np.array([k_previous_obs(trk.observations, trk.age, self.delta_t) for trk in self.trackers]) + + """ + First round of association + """ + # (M detections X N tracks, final score) + if self.embedding_off or dets.shape[0] == 0 or trk_embs.shape[0] == 0: + stage1_emb_cost = None + else: + stage1_emb_cost = dets_embs @ trk_embs.T + matched, unmatched_dets, unmatched_trks = associate( + dets, + trks, + self.iou_threshold, + velocities, + k_observations, + self.inertia, + stage1_emb_cost, + self.w_association_emb, + self.aw_off, + self.aw_param, + ) + for m in matched: + self.trackers[m[1]].update(dets[m[0], :5], dets[m[0], 5]) + self.trackers[m[1]].update_emb(dets_embs[m[0]], alpha=dets_alpha[m[0]]) + + """ + Second round of associaton by OCR + """ + if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0: + left_dets = dets[unmatched_dets] + left_dets_embs = dets_embs[unmatched_dets] + left_trks = last_boxes[unmatched_trks] + left_trks_embs = trk_embs[unmatched_trks] + + iou_left = self.asso_func(left_dets, left_trks) + # TODO: is better without this + emb_cost_left = left_dets_embs @ left_trks_embs.T + if self.embedding_off: + emb_cost_left = np.zeros_like(emb_cost_left) + iou_left = np.array(iou_left) + if iou_left.max() > self.iou_threshold: + """ + NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may + get a higher performance especially on MOT17/MOT20 datasets. 
But we keep it + uniform here for simplicity + """ + rematched_indices = linear_assignment(-iou_left) + to_remove_det_indices = [] + to_remove_trk_indices = [] + for m in rematched_indices: + det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]] + if iou_left[m[0], m[1]] < self.iou_threshold: + continue + self.trackers[trk_ind].update(dets[det_ind, :5], dets[det_ind, 5]) + self.trackers[trk_ind].update_emb(dets_embs[det_ind], alpha=dets_alpha[det_ind]) + to_remove_det_indices.append(det_ind) + to_remove_trk_indices.append(trk_ind) + unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices)) + unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices)) + + for m in unmatched_trks: + self.trackers[m].update(None, None) + + # create and initialise new trackers for unmatched detections + for i in unmatched_dets: + trk = KalmanBoxTracker( + dets[i, :5], dets[i, 5], delta_t=self.delta_t, emb=dets_embs[i], alpha=dets_alpha[i], new_kf=not self.new_kf_off + ) + self.trackers.append(trk) + i = len(self.trackers) + for trk in reversed(self.trackers): + if trk.last_observation.sum() < 0: + d = trk.get_state()[0] + else: + """ + this is optional to use the recent observation or the kalman filter prediction, + we didn't notice significant difference here + """ + d = trk.last_observation[:4] + if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits): + # +1 as MOT benchmark requires positive + ret.append(np.concatenate((d, [trk.id + 1], [trk.cls], [trk.conf])).reshape(1, -1)) + i -= 1 + # remove dead tracklet + if trk.time_since_update > self.max_age: + self.trackers.pop(i) + if len(ret) > 0: + return np.concatenate(ret) + return np.empty((0, 5)) + + def _xywh_to_xyxy(self, bbox_xywh): + x, y, w, h = bbox_xywh + x1 = max(int(x - w / 2), 0) + x2 = min(int(x + w / 2), self.width - 1) + y1 = max(int(y - h / 2), 0) + y2 = min(int(y + h / 2), self.height - 1) + return x1, y1, x2, y2 + + def _get_features(self, bbox_xywh, ori_img): + im_crops = [] + for box in bbox_xywh: + x1, y1, x2, y2 = self._xywh_to_xyxy(box) + im = ori_img[y1:y2, x1:x2] + im_crops.append(im) + if im_crops: + features = self.embedder(im_crops).cpu() + else: + features = np.array([]) + + return features + + def update_public(self, dets, cates, scores): + self.frame_count += 1 + + det_scores = np.ones((dets.shape[0], 1)) + dets = np.concatenate((dets, det_scores), axis=1) + + remain_inds = scores > self.det_thresh + + cates = cates[remain_inds] + dets = dets[remain_inds] + + trks = np.zeros((len(self.trackers), 5)) + to_del = [] + ret = [] + for t, trk in enumerate(trks): + pos = self.trackers[t].predict()[0] + cat = self.trackers[t].cate + trk[:] = [pos[0], pos[1], pos[2], pos[3], cat] + if np.any(np.isnan(pos)): + to_del.append(t) + trks = np.ma.compress_rows(np.ma.masked_invalid(trks)) + for t in reversed(to_del): + self.trackers.pop(t) + + velocities = np.array([trk.velocity if trk.velocity is not None else np.array((0, 0)) for trk in self.trackers]) + last_boxes = np.array([trk.last_observation for trk in self.trackers]) + k_observations = np.array([k_previous_obs(trk.observations, trk.age, self.delta_t) for trk in self.trackers]) + + matched, unmatched_dets, unmatched_trks = associate_kitti( + dets, + trks, + cates, + self.iou_threshold, + velocities, + k_observations, + self.inertia, + ) + + for m in matched: + self.trackers[m[1]].update(dets[m[0], :]) + + if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0: + """ + The 
re-association stage by OCR. + NOTE: at this stage, adding other strategy might be able to continue improve + the performance, such as BYTE association by ByteTrack. + """ + left_dets = dets[unmatched_dets] + left_trks = last_boxes[unmatched_trks] + left_dets_c = left_dets.copy() + left_trks_c = left_trks.copy() + + iou_left = self.asso_func(left_dets_c, left_trks_c) + iou_left = np.array(iou_left) + det_cates_left = cates[unmatched_dets] + trk_cates_left = trks[unmatched_trks][:, 4] + num_dets = unmatched_dets.shape[0] + num_trks = unmatched_trks.shape[0] + cate_matrix = np.zeros((num_dets, num_trks)) + for i in range(num_dets): + for j in range(num_trks): + if det_cates_left[i] != trk_cates_left[j]: + """ + For some datasets, such as KITTI, there are different categories, + we have to avoid associate them together. + """ + cate_matrix[i][j] = -1e6 + iou_left = iou_left + cate_matrix + if iou_left.max() > self.iou_threshold - 0.1: + rematched_indices = linear_assignment(-iou_left) + to_remove_det_indices = [] + to_remove_trk_indices = [] + for m in rematched_indices: + det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]] + if iou_left[m[0], m[1]] < self.iou_threshold - 0.1: + continue + self.trackers[trk_ind].update(dets[det_ind, :]) + to_remove_det_indices.append(det_ind) + to_remove_trk_indices.append(trk_ind) + unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices)) + unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices)) + + for i in unmatched_dets: + trk = KalmanBoxTracker(dets[i, :]) + trk.cate = cates[i] + self.trackers.append(trk) + i = len(self.trackers) + + for trk in reversed(self.trackers): + if trk.last_observation.sum() > 0: + d = trk.last_observation[:4] + else: + d = trk.get_state()[0] + if trk.time_since_update < 1: + if (self.frame_count <= self.min_hits) or (trk.hit_streak >= self.min_hits): + # id+1 as MOT benchmark requires positive + ret.append(np.concatenate((d, [trk.id + 1], [trk.cls], [trk.conf])).reshape(1, -1)) + if trk.hit_streak == self.min_hits: + # Head Padding (HP): recover the lost steps during initializing the track + for prev_i in range(self.min_hits - 1): + prev_observation = trk.history_observations[-(prev_i + 2)] + ret.append( + ( + np.concatenate( + ( + prev_observation[:4], + [trk.id + 1], + [trk.cls], + [trk.conf], + ) + ) + ).reshape(1, -1) + ) + i -= 1 + if trk.time_since_update > self.max_age: + self.trackers.pop(i) + + if len(ret) > 0: + return np.concatenate(ret) + return np.empty((0, 7)) + + def dump_cache(self): + self.cmc.dump_cache() + self.embedder.dump_cache() diff --git a/feeder/trackers/deepocsort/reid_multibackend.py b/feeder/trackers/deepocsort/reid_multibackend.py new file mode 100644 index 0000000..6578177 --- /dev/null +++ b/feeder/trackers/deepocsort/reid_multibackend.py @@ -0,0 +1,237 @@ +import torch.nn as nn +import torch +from pathlib import Path +import numpy as np +from itertools import islice +import torchvision.transforms as transforms +import cv2 +import sys +import torchvision.transforms as T +from collections import OrderedDict, namedtuple +import gdown +from os.path import exists as file_exists + + +from yolov8.ultralytics.yolo.utils.checks import check_requirements, check_version +from yolov8.ultralytics.yolo.utils import LOGGER +from trackers.strongsort.deep.reid_model_factory import (show_downloadeable_models, get_model_url, get_model_name, + download_url, load_pretrained_weights) +from trackers.strongsort.deep.models import build_model + + +def 
check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''): + # Check file(s) for acceptable suffix + if file and suffix: + if isinstance(suffix, str): + suffix = [suffix] + for f in file if isinstance(file, (list, tuple)) else [file]: + s = Path(f).suffix.lower() # file suffix + if len(s): + assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}" + + +class ReIDDetectMultiBackend(nn.Module): + # ReID models MultiBackend class for python inference on various backends + def __init__(self, weights='osnet_x0_25_msmt17.pt', device=torch.device('cpu'), fp16=False): + super().__init__() + + w = weights[0] if isinstance(weights, list) else weights + self.pt, self.jit, self.onnx, self.xml, self.engine, self.tflite = self.model_type(w) # get backend + self.fp16 = fp16 + self.fp16 &= self.pt or self.jit or self.engine # FP16 + + # Build transform functions + self.device = device + self.image_size=(256, 128) + self.pixel_mean=[0.485, 0.456, 0.406] + self.pixel_std=[0.229, 0.224, 0.225] + self.transforms = [] + self.transforms += [T.Resize(self.image_size)] + self.transforms += [T.ToTensor()] + self.transforms += [T.Normalize(mean=self.pixel_mean, std=self.pixel_std)] + self.preprocess = T.Compose(self.transforms) + self.to_pil = T.ToPILImage() + + model_name = get_model_name(w) + + if w.suffix == '.pt': + model_url = get_model_url(w) + if not file_exists(w) and model_url is not None: + gdown.download(model_url, str(w), quiet=False) + elif file_exists(w): + pass + else: + print(f'No URL associated to the chosen StrongSORT weights ({w}). Choose between:') + show_downloadeable_models() + exit() + + # Build model + self.model = build_model( + model_name, + num_classes=1, + pretrained=not (w and w.is_file()), + use_gpu=device + ) + + if self.pt: # PyTorch + # populate model arch with weights + if w and w.is_file() and w.suffix == '.pt': + load_pretrained_weights(self.model, w) + + self.model.to(device).eval() + self.model.half() if self.fp16 else self.model.float() + elif self.jit: + LOGGER.info(f'Loading {w} for TorchScript inference...') + self.model = torch.jit.load(w) + self.model.half() if self.fp16 else self.model.float() + elif self.onnx: # ONNX Runtime + LOGGER.info(f'Loading {w} for ONNX Runtime inference...') + cuda = torch.cuda.is_available() and device.type != 'cpu' + #check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime')) + import onnxruntime + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider'] + self.session = onnxruntime.InferenceSession(str(w), providers=providers) + elif self.engine: # TensorRT + LOGGER.info(f'Loading {w} for TensorRT inference...') + import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download + check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0 + if device.type == 'cpu': + device = torch.device('cuda:0') + Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) + logger = trt.Logger(trt.Logger.INFO) + with open(w, 'rb') as f, trt.Runtime(logger) as runtime: + self.model_ = runtime.deserialize_cuda_engine(f.read()) + self.context = self.model_.create_execution_context() + self.bindings = OrderedDict() + self.fp16 = False # default updated below + dynamic = False + for index in range(self.model_.num_bindings): + name = self.model_.get_binding_name(index) + dtype = trt.nptype(self.model_.get_binding_dtype(index)) + if self.model_.binding_is_input(index): + if -1 in tuple(self.model_.get_binding_shape(index)): # dynamic + dynamic = True + 
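# dynamic input binding: size it to the max shape of optimisation profile 0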
self.context.set_binding_shape(index, tuple(self.model_.get_profile_shape(0, index)[2])) + if dtype == np.float16: + self.fp16 = True + shape = tuple(self.context.get_binding_shape(index)) + im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) + self.bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) + self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items()) + batch_size = self.bindings['images'].shape[0] # if dynamic, this is instead max batch size + elif self.xml: # OpenVINO + LOGGER.info(f'Loading {w} for OpenVINO inference...') + check_requirements(('openvino',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/ + from openvino.runtime import Core, Layout, get_batch + ie = Core() + if not Path(w).is_file(): # if not *.xml + w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir + network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin')) + if network.get_parameters()[0].get_layout().empty: + network.get_parameters()[0].set_layout(Layout("NCWH")) + batch_dim = get_batch(network) + if batch_dim.is_static: + batch_size = batch_dim.get_length() + self.executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2 + self.output_layer = next(iter(self.executable_network.outputs)) + + elif self.tflite: + LOGGER.info(f'Loading {w} for TensorFlow Lite inference...') + try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu + from tflite_runtime.interpreter import Interpreter, load_delegate + except ImportError: + import tensorflow as tf + Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate, + self.interpreter = tf.lite.Interpreter(model_path=w) + self.interpreter.allocate_tensors() + # Get input and output tensors. + self.input_details = self.interpreter.get_input_details() + self.output_details = self.interpreter.get_output_details() + + # Test model on random input data. + input_data = np.array(np.random.random_sample((1,256,128,3)), dtype=np.float32) + self.interpreter.set_tensor(self.input_details[0]['index'], input_data) + + self.interpreter.invoke() + + # The function `get_tensor()` returns a copy of the tensor data. + output_data = self.interpreter.get_tensor(self.output_details[0]['index']) + else: + print('This model framework is not supported yet!') + exit() + + + @staticmethod + def model_type(p='path/to/model.pt'): + # Return model type from model path, i.e. 
path='path/to/model.onnx' -> type=onnx + from trackers.reid_export import export_formats + sf = list(export_formats().Suffix) # export suffixes + check_suffix(p, sf) # checks + types = [s in Path(p).name for s in sf] + return types + + def _preprocess(self, im_batch): + + images = [] + for element in im_batch: + image = self.to_pil(element) + image = self.preprocess(image) + images.append(image) + + images = torch.stack(images, dim=0) + images = images.to(self.device) + + return images + + + def forward(self, im_batch): + + # preprocess batch + im_batch = self._preprocess(im_batch) + + # batch to half + if self.fp16 and im_batch.dtype != torch.float16: + im_batch = im_batch.half() + + # batch processing + features = [] + if self.pt: + features = self.model(im_batch) + elif self.jit: # TorchScript + features = self.model(im_batch) + elif self.onnx: # ONNX Runtime + im_batch = im_batch.cpu().numpy() # torch to numpy + features = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im_batch})[0] + elif self.engine: # TensorRT + if True and im_batch.shape != self.bindings['images'].shape: + i_in, i_out = (self.model_.get_binding_index(x) for x in ('images', 'output')) + self.context.set_binding_shape(i_in, im_batch.shape) # reshape if dynamic + self.bindings['images'] = self.bindings['images']._replace(shape=im_batch.shape) + self.bindings['output'].data.resize_(tuple(self.context.get_binding_shape(i_out))) + s = self.bindings['images'].shape + assert im_batch.shape == s, f"input size {im_batch.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}" + self.binding_addrs['images'] = int(im_batch.data_ptr()) + self.context.execute_v2(list(self.binding_addrs.values())) + features = self.bindings['output'].data + elif self.xml: # OpenVINO + im_batch = im_batch.cpu().numpy() # FP32 + features = self.executable_network([im_batch])[self.output_layer] + else: + print('Framework not supported at the moment, we are working on it...') + exit() + + if isinstance(features, (list, tuple)): + return self.from_numpy(features[0]) if len(features) == 1 else [self.from_numpy(x) for x in features] + else: + return self.from_numpy(features) + + def from_numpy(self, x): + return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x + + def warmup(self, imgsz=[(256, 128, 3)]): + # Warmup model by running inference once + warmup_types = self.pt, self.jit, self.onnx, self.engine, self.tflite + if any(warmup_types) and self.device.type != 'cpu': + im = [np.empty(*imgsz).astype(np.uint8)] # input + for _ in range(2 if self.jit else 1): # + self.forward(im) # warmup \ No newline at end of file diff --git a/feeder/trackers/multi_tracker_zoo.py b/feeder/trackers/multi_tracker_zoo.py new file mode 100644 index 0000000..5dedf3e --- /dev/null +++ b/feeder/trackers/multi_tracker_zoo.py @@ -0,0 +1,84 @@ +from trackers.strongsort.utils.parser import get_config + +def create_tracker(tracker_type, tracker_config, reid_weights, device, half): + + cfg = get_config() + cfg.merge_from_file(tracker_config) + + if tracker_type == 'strongsort': + from trackers.strongsort.strong_sort import StrongSORT + strongsort = StrongSORT( + reid_weights, + device, + half, + max_dist=cfg.strongsort.max_dist, + max_iou_dist=cfg.strongsort.max_iou_dist, + max_age=cfg.strongsort.max_age, + max_unmatched_preds=cfg.strongsort.max_unmatched_preds, + n_init=cfg.strongsort.n_init, + nn_budget=cfg.strongsort.nn_budget, + mc_lambda=cfg.strongsort.mc_lambda, + 
ema_alpha=cfg.strongsort.ema_alpha, + + ) + return strongsort + + elif tracker_type == 'ocsort': + from trackers.ocsort.ocsort import OCSort + ocsort = OCSort( + det_thresh=cfg.ocsort.det_thresh, + max_age=cfg.ocsort.max_age, + min_hits=cfg.ocsort.min_hits, + iou_threshold=cfg.ocsort.iou_thresh, + delta_t=cfg.ocsort.delta_t, + asso_func=cfg.ocsort.asso_func, + inertia=cfg.ocsort.inertia, + use_byte=cfg.ocsort.use_byte, + ) + return ocsort + + elif tracker_type == 'bytetrack': + from trackers.bytetrack.byte_tracker import BYTETracker + bytetracker = BYTETracker( + track_thresh=cfg.bytetrack.track_thresh, + match_thresh=cfg.bytetrack.match_thresh, + track_buffer=cfg.bytetrack.track_buffer, + frame_rate=cfg.bytetrack.frame_rate + ) + return bytetracker + + elif tracker_type == 'botsort': + from trackers.botsort.bot_sort import BoTSORT + botsort = BoTSORT( + reid_weights, + device, + half, + track_high_thresh=cfg.botsort.track_high_thresh, + new_track_thresh=cfg.botsort.new_track_thresh, + track_buffer =cfg.botsort.track_buffer, + match_thresh=cfg.botsort.match_thresh, + proximity_thresh=cfg.botsort.proximity_thresh, + appearance_thresh=cfg.botsort.appearance_thresh, + cmc_method =cfg.botsort.cmc_method, + frame_rate=cfg.botsort.frame_rate, + lambda_=cfg.botsort.lambda_ + ) + return botsort + elif tracker_type == 'deepocsort': + from trackers.deepocsort.ocsort import OCSort + botsort = OCSort( + reid_weights, + device, + half, + det_thresh=cfg.deepocsort.det_thresh, + max_age=cfg.deepocsort.max_age, + min_hits=cfg.deepocsort.min_hits, + iou_threshold=cfg.deepocsort.iou_thresh, + delta_t=cfg.deepocsort.delta_t, + asso_func=cfg.deepocsort.asso_func, + inertia=cfg.deepocsort.inertia, + ) + return botsort + else: + print('No such tracker') + exit() \ No newline at end of file diff --git a/feeder/trackers/ocsort/association.py b/feeder/trackers/ocsort/association.py new file mode 100644 index 0000000..64c2a3e --- /dev/null +++ b/feeder/trackers/ocsort/association.py @@ -0,0 +1,377 @@ +import os +import numpy as np + + +def iou_batch(bboxes1, bboxes2): + """ + From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2] + """ + bboxes2 = np.expand_dims(bboxes2, 0) + bboxes1 = np.expand_dims(bboxes1, 1) + + xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0]) + yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1]) + xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2]) + yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3]) + w = np.maximum(0., xx2 - xx1) + h = np.maximum(0., yy2 - yy1) + wh = w * h + o = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1]) + + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh) + return(o) + + +def giou_batch(bboxes1, bboxes2): + """ + :param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2) + :param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2) + :return: + """ + # for details should go to https://arxiv.org/pdf/1902.09630.pdf + # ensure predict's bbox form + bboxes2 = np.expand_dims(bboxes2, 0) + bboxes1 = np.expand_dims(bboxes1, 1) + + xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0]) + yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1]) + xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2]) + yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3]) + w = np.maximum(0., xx2 - xx1) + h = np.maximum(0., yy2 - yy1) + wh = w * h + iou = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1]) + + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh) + + xxc1 = np.minimum(bboxes1[..., 
0], bboxes2[..., 0]) + yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1]) + xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2]) + yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3]) + wc = xxc2 - xxc1 + hc = yyc2 - yyc1 + assert((wc > 0).all() and (hc > 0).all()) + area_enclose = wc * hc + giou = iou - (area_enclose - wh) / area_enclose + giou = (giou + 1.)/2.0 # resize from (-1,1) to (0,1) + return giou + + +def diou_batch(bboxes1, bboxes2): + """ + :param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2) + :param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2) + :return: + """ + # for details should go to https://arxiv.org/pdf/1902.09630.pdf + # ensure predict's bbox form + bboxes2 = np.expand_dims(bboxes2, 0) + bboxes1 = np.expand_dims(bboxes1, 1) + + # calculate the intersection box + xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0]) + yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1]) + xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2]) + yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3]) + w = np.maximum(0., xx2 - xx1) + h = np.maximum(0., yy2 - yy1) + wh = w * h + iou = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1]) + + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh) + + centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0 + centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0 + centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0 + centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0 + + inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2 + + xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0]) + yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1]) + xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2]) + yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3]) + + outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2 + diou = iou - inner_diag / outer_diag + + return (diou + 1) / 2.0 # resize from (-1,1) to (0,1) + +def ciou_batch(bboxes1, bboxes2): + """ + :param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2) + :param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2) + :return: + """ + # for details should go to https://arxiv.org/pdf/1902.09630.pdf + # ensure predict's bbox form + bboxes2 = np.expand_dims(bboxes2, 0) + bboxes1 = np.expand_dims(bboxes1, 1) + + # calculate the intersection box + xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0]) + yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1]) + xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2]) + yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3]) + w = np.maximum(0., xx2 - xx1) + h = np.maximum(0., yy2 - yy1) + wh = w * h + iou = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1]) + + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh) + + centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0 + centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0 + centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0 + centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0 + + inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2 + + xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0]) + yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1]) + xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2]) + yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3]) + + outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2 + + w1 = bboxes1[..., 2] - bboxes1[..., 0] + h1 = bboxes1[..., 3] - bboxes1[..., 1] + w2 = bboxes2[..., 2] - bboxes2[..., 0] + h2 = bboxes2[..., 3] - bboxes2[..., 1] + + # prevent dividing over zero. 
add one pixel shift + h2 = h2 + 1. + h1 = h1 + 1. + arctan = np.arctan(w2/h2) - np.arctan(w1/h1) + v = (4 / (np.pi ** 2)) * (arctan ** 2) + S = 1 - iou + alpha = v / (S+v) + ciou = iou - inner_diag / outer_diag - alpha * v + + return (ciou + 1) / 2.0 # resize from (-1,1) to (0,1) + + +def ct_dist(bboxes1, bboxes2): + """ + Measure the center distance between two sets of bounding boxes, + this is a coarse implementation, we don't recommend using it only + for association, which can be unstable and sensitive to frame rate + and object speed. + """ + bboxes2 = np.expand_dims(bboxes2, 0) + bboxes1 = np.expand_dims(bboxes1, 1) + + centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0 + centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0 + centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0 + centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0 + + ct_dist2 = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2 + + ct_dist = np.sqrt(ct_dist2) + + # The linear rescaling is a naive version and needs more study + ct_dist = ct_dist / ct_dist.max() + return ct_dist.max() - ct_dist # resize to (0,1) + + + +def speed_direction_batch(dets, tracks): + tracks = tracks[..., np.newaxis] + CX1, CY1 = (dets[:,0] + dets[:,2])/2.0, (dets[:,1]+dets[:,3])/2.0 + CX2, CY2 = (tracks[:,0] + tracks[:,2]) /2.0, (tracks[:,1]+tracks[:,3])/2.0 + dx = CX1 - CX2 + dy = CY1 - CY2 + norm = np.sqrt(dx**2 + dy**2) + 1e-6 + dx = dx / norm + dy = dy / norm + return dy, dx # size: num_track x num_det + + +def linear_assignment(cost_matrix): + try: + import lap + _, x, y = lap.lapjv(cost_matrix, extend_cost=True) + return np.array([[y[i],i] for i in x if i >= 0]) # + except ImportError: + from scipy.optimize import linear_sum_assignment + x, y = linear_sum_assignment(cost_matrix) + return np.array(list(zip(x, y))) + + +def associate_detections_to_trackers(detections,trackers, iou_threshold = 0.3): + """ + Assigns detections to tracked object (both represented as bounding boxes) + Returns 3 lists of matches, unmatched_detections and unmatched_trackers + """ + if(len(trackers)==0): + return np.empty((0,2),dtype=int), np.arange(len(detections)), np.empty((0,5),dtype=int) + + iou_matrix = iou_batch(detections, trackers) + + if min(iou_matrix.shape) > 0: + a = (iou_matrix > iou_threshold).astype(np.int32) + if a.sum(1).max() == 1 and a.sum(0).max() == 1: + matched_indices = np.stack(np.where(a), axis=1) + else: + matched_indices = linear_assignment(-iou_matrix) + else: + matched_indices = np.empty(shape=(0,2)) + + unmatched_detections = [] + for d, det in enumerate(detections): + if(d not in matched_indices[:,0]): + unmatched_detections.append(d) + unmatched_trackers = [] + for t, trk in enumerate(trackers): + if(t not in matched_indices[:,1]): + unmatched_trackers.append(t) + + #filter out matched with low IOU + matches = [] + for m in matched_indices: + if(iou_matrix[m[0], m[1]] 0: + a = (iou_matrix > iou_threshold).astype(np.int32) + if a.sum(1).max() == 1 and a.sum(0).max() == 1: + matched_indices = np.stack(np.where(a), axis=1) + else: + matched_indices = linear_assignment(-(iou_matrix+angle_diff_cost)) + else: + matched_indices = np.empty(shape=(0,2)) + + unmatched_detections = [] + for d, det in enumerate(detections): + if(d not in matched_indices[:,0]): + unmatched_detections.append(d) + unmatched_trackers = [] + for t, trk in enumerate(trackers): + if(t not in matched_indices[:,1]): + unmatched_trackers.append(t) + + # filter out matched with low IOU + matches = [] + for m in matched_indices: + if(iou_matrix[m[0], m[1]] 0: 
+ a = (iou_matrix > iou_threshold).astype(np.int32) + if a.sum(1).max() == 1 and a.sum(0).max() == 1: + matched_indices = np.stack(np.where(a), axis=1) + else: + matched_indices = linear_assignment(cost_matrix) + else: + matched_indices = np.empty(shape=(0,2)) + + unmatched_detections = [] + for d, det in enumerate(detections): + if(d not in matched_indices[:,0]): + unmatched_detections.append(d) + unmatched_trackers = [] + for t, trk in enumerate(trackers): + if(t not in matched_indices[:,1]): + unmatched_trackers.append(t) + + #filter out matched with low IOU + matches = [] + for m in matched_indices: + if(iou_matrix[m[0], m[1]]update cycle. The +predict step, implemented with the method or function predict(), +uses the state transition matrix F to predict the state in the next +time period (epoch). The state is stored as a gaussian (x, P), where +x is the state (column) vector, and P is its covariance. Covariance +matrix Q specifies the process covariance. In Bayesian terms, this +prediction is called the *prior*, which you can think of colloquially +as the estimate prior to incorporating the measurement. +The update step, implemented with the method or function `update()`, +incorporates the measurement z with covariance R, into the state +estimate (x, P). The class stores the system uncertainty in S, +the innovation (residual between prediction and measurement in +measurement space) in y, and the Kalman gain in k. The procedural +form returns these variables to you. In Bayesian terms this computes +the *posterior* - the estimate after the information from the +measurement is incorporated. +Whether you use the OO form or procedural form is up to you. If +matrices such as H, R, and F are changing each epoch, you'll probably +opt to use the procedural form. If they are unchanging, the OO +form is perhaps easier to use since you won't need to keep track +of these matrices. This is especially useful if you are implementing +banks of filters or comparing various KF designs for performance; +a trivial coding bug could lead to using the wrong sets of matrices. +This module also offers an implementation of the RTS smoother, and +other helper functions, such as log likelihood computations. +The Saver class allows you to easily save the state of the +KalmanFilter class after every update +This module expects NumPy arrays for all values that expect +arrays, although in a few cases, particularly method parameters, +it will accept types that convert to NumPy arrays, such as lists +of lists. These exceptions are documented in the method or function. +Examples +-------- +The following example constructs a constant velocity kinematic +filter, filters noisy data, and plots the results. It also demonstrates +using the Saver class to save the state of the filter at each epoch. +.. 
code-block:: Python
+    import matplotlib.pyplot as plt
+    import numpy as np
+    from numpy.random import randn
+    from filterpy.kalman import KalmanFilter
+    from filterpy.common import Q_discrete_white_noise, Saver
+    dt = 1.  # time step for this example
+    r_std, q_std = 2., 0.003
+    cv = KalmanFilter(dim_x=2, dim_z=1)
+    cv.x = np.array([[0., 1.]]) # position, velocity
+    cv.F = np.array([[1, dt], [0, 1]])
+    cv.R = np.array([[r_std**2]])
+    cv.H = np.array([[1., 0.]])
+    cv.P = np.diag([.1**2, .03**2])
+    cv.Q = Q_discrete_white_noise(2, dt, q_std**2)
+    saver = Saver(cv)
+    for z in range(100):
+        cv.predict()
+        cv.update([z + randn() * r_std])
+        saver.save() # save the filter's state
+    saver.to_array()
+    plt.plot(saver.x[:, 0])
+    # plot all of the priors
+    plt.plot(saver.x_prior[:, 0])
+    # plot mahalanobis distance
+    plt.figure()
+    plt.plot(saver.mahalanobis)
+This code implements the same filter using the procedural form
+    x = np.array([[0., 1.]]) # position, velocity
+    F = np.array([[1, dt], [0, 1]])
+    R = np.array([[r_std**2]])
+    H = np.array([[1., 0.]])
+    P = np.diag([.1**2, .03**2])
+    Q = Q_discrete_white_noise(2, dt, q_std**2)
+    xs = []
+    for z in range(100):
+        x, P = predict(x, P, F=F, Q=Q)
+        x, P = update(x, P, z=[z + randn() * r_std], R=R, H=H)
+        xs.append(x[0, 0])
+    plt.plot(xs)
+For more examples see the test subdirectory, or refer to the
+book cited below. In it I both teach Kalman filtering from basic
+principles, and teach the use of this library in great detail.
+FilterPy library.
+http://github.com/rlabbe/filterpy
+Documentation at:
+https://filterpy.readthedocs.org
+Supporting book at:
+https://github.com/rlabbe/Kalman-and-Bayesian-Filters-in-Python
+This is licensed under an MIT license. See the readme.MD file
+for more information.
+Copyright 2014-2018 Roger R Labbe Jr.
+"""
+
+from __future__ import absolute_import, division
+
+from copy import deepcopy
+from math import log, exp, sqrt
+import sys
+import numpy as np
+from numpy import dot, zeros, eye, isscalar, shape
+import numpy.linalg as linalg
+from filterpy.stats import logpdf
+from filterpy.common import pretty_str, reshape_z
+
+
+class KalmanFilterNew(object):
+    """ Implements a Kalman filter. You are responsible for setting the
+    various state variables to reasonable values; the defaults will
+    not give you a functional filter.
+    For now the best documentation is my free book Kalman and Bayesian
+    Filters in Python [2]_. The test files in this directory also give you a
+    basic idea of use, albeit without much description.
+    In brief, you will first construct this object, specifying the size of
+    the state vector with dim_x and the size of the measurement vector that
+    you will be using with dim_z. These are mostly used to perform size checks
+    when you assign values to the various matrices. For example, if you
+    specified dim_z=2 and then try to assign a 3x3 matrix to R (the
+    measurement noise matrix) you will get an assert exception because R
+    should be 2x2. (If for whatever reason you need to alter the size of
+    things midstream just use the underscore version of the matrices to
+    assign directly: your_filter._R = a_3x3_matrix.)
+    After construction the filter will have default matrices created for you,
+    but you must specify the values for each. It’s usually easiest to just
+    overwrite them rather than assign to each element yourself. This will be
+    clearer in the example below. All are of type numpy.array.
+    Examples
+    --------
+    Here is a filter that tracks position and velocity using a sensor that only
+    reads position.
+    First construct the object with the required dimensionality.
Here the state + (`dim_x`) has 2 coefficients (position and velocity), and the measurement + (`dim_z`) has one. In FilterPy `x` is the state, `z` is the measurement. + .. code:: + from filterpy.kalman import KalmanFilter + f = KalmanFilter (dim_x=2, dim_z=1) + Assign the initial value for the state (position and velocity). You can do this + with a two dimensional array like so: + .. code:: + f.x = np.array([[2.], # position + [0.]]) # velocity + or just use a one dimensional array, which I prefer doing. + .. code:: + f.x = np.array([2., 0.]) + Define the state transition matrix: + .. code:: + f.F = np.array([[1.,1.], + [0.,1.]]) + Define the measurement function. Here we need to convert a position-velocity + vector into just a position vector, so we use: + .. code:: + f.H = np.array([[1., 0.]]) + Define the state's covariance matrix P. + .. code:: + f.P = np.array([[1000., 0.], + [ 0., 1000.] ]) + Now assign the measurement noise. Here the dimension is 1x1, so I can + use a scalar + .. code:: + f.R = 5 + I could have done this instead: + .. code:: + f.R = np.array([[5.]]) + Note that this must be a 2 dimensional array. + Finally, I will assign the process noise. Here I will take advantage of + another FilterPy library function: + .. code:: + from filterpy.common import Q_discrete_white_noise + f.Q = Q_discrete_white_noise(dim=2, dt=0.1, var=0.13) + Now just perform the standard predict/update loop: + .. code:: + while some_condition_is_true: + z = get_sensor_reading() + f.predict() + f.update(z) + do_something_with_estimate (f.x) + **Procedural Form** + This module also contains stand alone functions to perform Kalman filtering. + Use these if you are not a fan of objects. + **Example** + .. code:: + while True: + z, R = read_sensor() + x, P = predict(x, P, F, Q) + x, P = update(x, P, z, R, H) + See my book Kalman and Bayesian Filters in Python [2]_. + You will have to set the following attributes after constructing this + object for the filter to perform properly. Please note that there are + various checks in place to ensure that you have made everything the + 'correct' size. However, it is possible to provide incorrectly sized + arrays such that the linear algebra can not perform an operation. + It can also fail silently - you can end up with matrices of a size that + allows the linear algebra to work, but are the wrong shape for the problem + you are trying to solve. + Parameters + ---------- + dim_x : int + Number of state variables for the Kalman filter. For example, if + you are tracking the position and velocity of an object in two + dimensions, dim_x would be 4. + This is used to set the default size of P, Q, and u + dim_z : int + Number of of measurement inputs. For example, if the sensor + provides you with position in (x,y), dim_z would be 2. + dim_u : int (optional) + size of the control input, if it is being used. + Default value of 0 indicates it is not used. + compute_log_likelihood : bool (default = True) + Computes log likelihood by default, but this can be a slow + computation, so if you never use it you can turn this computation + off. + Attributes + ---------- + x : numpy.array(dim_x, 1) + Current state estimate. Any call to update() or predict() updates + this variable. + P : numpy.array(dim_x, dim_x) + Current state covariance matrix. Any call to update() or predict() + updates this variable. + x_prior : numpy.array(dim_x, 1) + Prior (predicted) state estimate. 
The *_prior and *_post attributes + are for convenience; they store the prior and posterior of the + current epoch. Read Only. + P_prior : numpy.array(dim_x, dim_x) + Prior (predicted) state covariance matrix. Read Only. + x_post : numpy.array(dim_x, 1) + Posterior (updated) state estimate. Read Only. + P_post : numpy.array(dim_x, dim_x) + Posterior (updated) state covariance matrix. Read Only. + z : numpy.array + Last measurement used in update(). Read only. + R : numpy.array(dim_z, dim_z) + Measurement noise covariance matrix. Also known as the + observation covariance. + Q : numpy.array(dim_x, dim_x) + Process noise covariance matrix. Also known as the transition + covariance. + F : numpy.array() + State Transition matrix. Also known as `A` in some formulation. + H : numpy.array(dim_z, dim_x) + Measurement function. Also known as the observation matrix, or as `C`. + y : numpy.array + Residual of the update step. Read only. + K : numpy.array(dim_x, dim_z) + Kalman gain of the update step. Read only. + S : numpy.array + System uncertainty (P projected to measurement space). Read only. + SI : numpy.array + Inverse system uncertainty. Read only. + log_likelihood : float + log-likelihood of the last measurement. Read only. + likelihood : float + likelihood of last measurement. Read only. + Computed from the log-likelihood. The log-likelihood can be very + small, meaning a large negative value such as -28000. Taking the + exp() of that results in 0.0, which can break typical algorithms + which multiply by this value, so by default we always return a + number >= sys.float_info.min. + mahalanobis : float + mahalanobis distance of the innovation. Read only. + inv : function, default numpy.linalg.inv + If you prefer another inverse function, such as the Moore-Penrose + pseudo inverse, set it to that instead: kf.inv = np.linalg.pinv + This is only used to invert self.S. If you know it is diagonal, you + might choose to set it to filterpy.common.inv_diagonal, which is + several times faster than numpy.linalg.inv for diagonal matrices. + alpha : float + Fading memory setting. 1.0 gives the normal Kalman filter, and + values slightly larger than 1.0 (such as 1.02) give a fading + memory effect - previous measurements have less influence on the + filter's estimates. This formulation of the Fading memory filter + (there are many) is due to Dan Simon [1]_. + References + ---------- + .. [1] Dan Simon. "Optimal State Estimation." John Wiley & Sons. + p. 208-212. (2006) + .. [2] Roger Labbe. "Kalman and Bayesian Filters in Python" + https://github.com/rlabbe/Kalman-and-Bayesian-Filters-in-Python + """ + + def __init__(self, dim_x, dim_z, dim_u=0): + if dim_x < 1: + raise ValueError('dim_x must be 1 or greater') + if dim_z < 1: + raise ValueError('dim_z must be 1 or greater') + if dim_u < 0: + raise ValueError('dim_u must be 0 or greater') + + self.dim_x = dim_x + self.dim_z = dim_z + self.dim_u = dim_u + + self.x = zeros((dim_x, 1)) # state + self.P = eye(dim_x) # uncertainty covariance + self.Q = eye(dim_x) # process uncertainty + self.B = None # control transition matrix + self.F = eye(dim_x) # state transition matrix + self.H = zeros((dim_z, dim_x)) # measurement function + self.R = eye(dim_z) # measurement uncertainty + self._alpha_sq = 1. # fading memory control + self.M = np.zeros((dim_x, dim_z)) # process-measurement cross correlation + self.z = np.array([[None]*self.dim_z]).T + + # gain and residual are computed during the innovation step. 
We + # save them so that in case you want to inspect them for various + # purposes + self.K = np.zeros((dim_x, dim_z)) # kalman gain + self.y = zeros((dim_z, 1)) + self.S = np.zeros((dim_z, dim_z)) # system uncertainty + self.SI = np.zeros((dim_z, dim_z)) # inverse system uncertainty + + # identity matrix. Do not alter this. + self._I = np.eye(dim_x) + + # these will always be a copy of x,P after predict() is called + self.x_prior = self.x.copy() + self.P_prior = self.P.copy() + + # these will always be a copy of x,P after update() is called + self.x_post = self.x.copy() + self.P_post = self.P.copy() + + # Only computed only if requested via property + self._log_likelihood = log(sys.float_info.min) + self._likelihood = sys.float_info.min + self._mahalanobis = None + + # keep all observations + self.history_obs = [] + + self.inv = np.linalg.inv + + self.attr_saved = None + self.observed = False + + + def predict(self, u=None, B=None, F=None, Q=None): + """ + Predict next state (prior) using the Kalman filter state propagation + equations. + Parameters + ---------- + u : np.array, default 0 + Optional control vector. + B : np.array(dim_x, dim_u), or None + Optional control transition matrix; a value of None + will cause the filter to use `self.B`. + F : np.array(dim_x, dim_x), or None + Optional state transition matrix; a value of None + will cause the filter to use `self.F`. + Q : np.array(dim_x, dim_x), scalar, or None + Optional process noise matrix; a value of None will cause the + filter to use `self.Q`. + """ + + if B is None: + B = self.B + if F is None: + F = self.F + if Q is None: + Q = self.Q + elif isscalar(Q): + Q = eye(self.dim_x) * Q + + + # x = Fx + Bu + if B is not None and u is not None: + self.x = dot(F, self.x) + dot(B, u) + else: + self.x = dot(F, self.x) + + # P = FPF' + Q + self.P = self._alpha_sq * dot(dot(F, self.P), F.T) + Q + + # save prior + self.x_prior = self.x.copy() + self.P_prior = self.P.copy() + + + + def freeze(self): + """ + Save the parameters before non-observation forward + """ + self.attr_saved = deepcopy(self.__dict__) + + + def unfreeze(self): + if self.attr_saved is not None: + new_history = deepcopy(self.history_obs) + self.__dict__ = self.attr_saved + # self.history_obs = new_history + self.history_obs = self.history_obs[:-1] + occur = [int(d is None) for d in new_history] + indices = np.where(np.array(occur)==0)[0] + index1 = indices[-2] + index2 = indices[-1] + box1 = new_history[index1] + x1, y1, s1, r1 = box1 + w1 = np.sqrt(s1 * r1) + h1 = np.sqrt(s1 / r1) + box2 = new_history[index2] + x2, y2, s2, r2 = box2 + w2 = np.sqrt(s2 * r2) + h2 = np.sqrt(s2 / r2) + time_gap = index2 - index1 + dx = (x2-x1)/time_gap + dy = (y2-y1)/time_gap + dw = (w2-w1)/time_gap + dh = (h2-h1)/time_gap + for i in range(index2 - index1): + """ + The default virtual trajectory generation is by linear + motion (constant speed hypothesis), you could modify this + part to implement your own. + """ + x = x1 + (i+1) * dx + y = y1 + (i+1) * dy + w = w1 + (i+1) * dw + h = h1 + (i+1) * dh + s = w * h + r = w / float(h) + new_box = np.array([x, y, s, r]).reshape((4, 1)) + """ + I still use predict-update loop here to refresh the parameters, + but this can be faster by directly modifying the internal parameters + as suggested in the paper. I keep this naive but slow way for + easy read and understanding + """ + self.update(new_box) + if not i == (index2-index1-1): + self.predict() + + + def update(self, z, R=None, H=None): + """ + Add a new measurement (z) to the Kalman filter. 
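+        In this OC-SORT variant the observation (including ``None`` for a
+        missed frame) is also appended to ``self.history_obs``; when a track
+        is re-observed after a gap, ``unfreeze()`` uses that history to
+        back-fill a virtual trajectory before this update is applied.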
+ If z is None, nothing is computed. However, x_post and P_post are + updated with the prior (x_prior, P_prior), and self.z is set to None. + Parameters + ---------- + z : (dim_z, 1): array_like + measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be convertible to a column vector. + If you pass in a value of H, z must be a column vector the + of the correct size. + R : np.array, scalar, or None + Optionally provide R to override the measurement noise for this + one call, otherwise self.R will be used. + H : np.array, or None + Optionally provide H to override the measurement function for this + one call, otherwise self.H will be used. + """ + + # set to None to force recompute + self._log_likelihood = None + self._likelihood = None + self._mahalanobis = None + + # append the observation + self.history_obs.append(z) + + if z is None: + if self.observed: + """ + Got no observation so freeze the current parameters for future + potential online smoothing. + """ + self.freeze() + self.observed = False + self.z = np.array([[None]*self.dim_z]).T + self.x_post = self.x.copy() + self.P_post = self.P.copy() + self.y = zeros((self.dim_z, 1)) + return + + # self.observed = True + if not self.observed: + """ + Get observation, use online smoothing to re-update parameters + """ + self.unfreeze() + self.observed = True + + if R is None: + R = self.R + elif isscalar(R): + R = eye(self.dim_z) * R + + if H is None: + z = reshape_z(z, self.dim_z, self.x.ndim) + H = self.H + + # y = z - Hx + # error (residual) between measurement and prediction + self.y = z - dot(H, self.x) + + # common subexpression for speed + PHT = dot(self.P, H.T) + + # S = HPH' + R + # project system uncertainty into measurement space + self.S = dot(H, PHT) + R + self.SI = self.inv(self.S) + # K = PH'inv(S) + # map system uncertainty into kalman gain + self.K = dot(PHT, self.SI) + + # x = x + Ky + # predict new x with residual scaled by the kalman gain + self.x = self.x + dot(self.K, self.y) + + # P = (I-KH)P(I-KH)' + KRK' + # This is more numerically stable + # and works for non-optimal K vs the equation + # P = (I-KH)P usually seen in the literature. + + I_KH = self._I - dot(self.K, H) + self.P = dot(dot(I_KH, self.P), I_KH.T) + dot(dot(self.K, R), self.K.T) + + # save measurement and posterior state + self.z = deepcopy(z) + self.x_post = self.x.copy() + self.P_post = self.P.copy() + + def predict_steadystate(self, u=0, B=None): + """ + Predict state (prior) using the Kalman filter state propagation + equations. Only x is updated, P is left unchanged. See + update_steadstate() for a longer explanation of when to use this + method. + Parameters + ---------- + u : np.array + Optional control vector. If non-zero, it is multiplied by B + to create the control input into the system. + B : np.array(dim_x, dim_u), or None + Optional control transition matrix; a value of None + will cause the filter to use `self.B`. + """ + + if B is None: + B = self.B + + # x = Fx + Bu + if B is not None: + self.x = dot(self.F, self.x) + dot(B, u) + else: + self.x = dot(self.F, self.x) + + # save prior + self.x_prior = self.x.copy() + self.P_prior = self.P.copy() + + def update_steadystate(self, z): + """ + Add a new measurement (z) to the Kalman filter without recomputing + the Kalman gain K, the state covariance P, or the system + uncertainty S. + You can use this for LTI systems since the Kalman gain and covariance + converge to a fixed value. 
Precompute these and assign them explicitly, + or run the Kalman filter using the normal predict()/update(0 cycle + until they converge. + The main advantage of this call is speed. We do significantly less + computation, notably avoiding a costly matrix inversion. + Use in conjunction with predict_steadystate(), otherwise P will grow + without bound. + Parameters + ---------- + z : (dim_z, 1): array_like + measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be convertible to a column vector. + Examples + -------- + >>> cv = kinematic_kf(dim=3, order=2) # 3D const velocity filter + >>> # let filter converge on representative data, then save k and P + >>> for i in range(100): + >>> cv.predict() + >>> cv.update([i, i, i]) + >>> saved_k = np.copy(cv.K) + >>> saved_P = np.copy(cv.P) + later on: + >>> cv = kinematic_kf(dim=3, order=2) # 3D const velocity filter + >>> cv.K = np.copy(saved_K) + >>> cv.P = np.copy(saved_P) + >>> for i in range(100): + >>> cv.predict_steadystate() + >>> cv.update_steadystate([i, i, i]) + """ + + # set to None to force recompute + self._log_likelihood = None + self._likelihood = None + self._mahalanobis = None + + if z is None: + self.z = np.array([[None]*self.dim_z]).T + self.x_post = self.x.copy() + self.P_post = self.P.copy() + self.y = zeros((self.dim_z, 1)) + return + + z = reshape_z(z, self.dim_z, self.x.ndim) + + # y = z - Hx + # error (residual) between measurement and prediction + self.y = z - dot(self.H, self.x) + + # x = x + Ky + # predict new x with residual scaled by the kalman gain + self.x = self.x + dot(self.K, self.y) + + self.z = deepcopy(z) + self.x_post = self.x.copy() + self.P_post = self.P.copy() + + # set to None to force recompute + self._log_likelihood = None + self._likelihood = None + self._mahalanobis = None + + def update_correlated(self, z, R=None, H=None): + """ Add a new measurement (z) to the Kalman filter assuming that + process noise and measurement noise are correlated as defined in + the `self.M` matrix. + A partial derivation can be found in [1] + If z is None, nothing is changed. + Parameters + ---------- + z : (dim_z, 1): array_like + measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be convertible to a column vector. + R : np.array, scalar, or None + Optionally provide R to override the measurement noise for this + one call, otherwise self.R will be used. + H : np.array, or None + Optionally provide H to override the measurement function for this + one call, otherwise self.H will be used. + References + ---------- + .. [1] Bulut, Y. (2011). Applied Kalman filter theory (Doctoral dissertation, Northeastern University). + http://people.duke.edu/~hpgavin/SystemID/References/Balut-KalmanFilter-PhD-NEU-2011.pdf + """ + + # set to None to force recompute + self._log_likelihood = None + self._likelihood = None + self._mahalanobis = None + + if z is None: + self.z = np.array([[None]*self.dim_z]).T + self.x_post = self.x.copy() + self.P_post = self.P.copy() + self.y = zeros((self.dim_z, 1)) + return + + if R is None: + R = self.R + elif isscalar(R): + R = eye(self.dim_z) * R + + # rename for readability and a tiny extra bit of speed + if H is None: + z = reshape_z(z, self.dim_z, self.x.ndim) + H = self.H + + # handle special case: if z is in form [[z]] but x is not a column + # vector dimensions will not match + if self.x.ndim == 1 and shape(z) == (1, 1): + z = z[0] + + if shape(z) == (): # is it scalar, e.g. 
z=3 or z=np.array(3) + z = np.asarray([z]) + + # y = z - Hx + # error (residual) between measurement and prediction + self.y = z - dot(H, self.x) + + # common subexpression for speed + PHT = dot(self.P, H.T) + + # project system uncertainty into measurement space + self.S = dot(H, PHT) + dot(H, self.M) + dot(self.M.T, H.T) + R + self.SI = self.inv(self.S) + + # K = PH'inv(S) + # map system uncertainty into kalman gain + self.K = dot(PHT + self.M, self.SI) + + # x = x + Ky + # predict new x with residual scaled by the kalman gain + self.x = self.x + dot(self.K, self.y) + self.P = self.P - dot(self.K, dot(H, self.P) + self.M.T) + + self.z = deepcopy(z) + self.x_post = self.x.copy() + self.P_post = self.P.copy() + + def batch_filter(self, zs, Fs=None, Qs=None, Hs=None, + Rs=None, Bs=None, us=None, update_first=False, + saver=None): + """ Batch processes a sequences of measurements. + Parameters + ---------- + zs : list-like + list of measurements at each time step `self.dt`. Missing + measurements must be represented by `None`. + Fs : None, list-like, default=None + optional value or list of values to use for the state transition + matrix F. + If Fs is None then self.F is used for all epochs. + Otherwise it must contain a list-like list of F's, one for + each epoch. This allows you to have varying F per epoch. + Qs : None, np.array or list-like, default=None + optional value or list of values to use for the process error + covariance Q. + If Qs is None then self.Q is used for all epochs. + Otherwise it must contain a list-like list of Q's, one for + each epoch. This allows you to have varying Q per epoch. + Hs : None, np.array or list-like, default=None + optional list of values to use for the measurement matrix H. + If Hs is None then self.H is used for all epochs. + If Hs contains a single matrix, then it is used as H for all + epochs. + Otherwise it must contain a list-like list of H's, one for + each epoch. This allows you to have varying H per epoch. + Rs : None, np.array or list-like, default=None + optional list of values to use for the measurement error + covariance R. + If Rs is None then self.R is used for all epochs. + Otherwise it must contain a list-like list of R's, one for + each epoch. This allows you to have varying R per epoch. + Bs : None, np.array or list-like, default=None + optional list of values to use for the control transition matrix B. + If Bs is None then self.B is used for all epochs. + Otherwise it must contain a list-like list of B's, one for + each epoch. This allows you to have varying B per epoch. + us : None, np.array or list-like, default=None + optional list of values to use for the control input vector; + If us is None then None is used for all epochs (equivalent to 0, + or no control input). + Otherwise it must contain a list-like list of u's, one for + each epoch. + update_first : bool, optional, default=False + controls whether the order of operations is update followed by + predict, or predict followed by update. Default is predict->update. + saver : filterpy.common.Saver, optional + filterpy.common.Saver object. If provided, saver.save() will be + called after every epoch + Returns + ------- + means : np.array((n,dim_x,1)) + array of the state for each time step after the update. Each entry + is an np.array. In other words `means[k,:]` is the state at step + `k`. + covariance : np.array((n,dim_x,dim_x)) + array of the covariances for each time step after the update. + In other words `covariance[k,:,:]` is the covariance at step `k`. 
+ means_predictions : np.array((n,dim_x,1)) + array of the state for each time step after the predictions. Each + entry is an np.array. In other words `means[k,:]` is the state at + step `k`. + covariance_predictions : np.array((n,dim_x,dim_x)) + array of the covariances for each time step after the prediction. + In other words `covariance[k,:,:]` is the covariance at step `k`. + Examples + -------- + .. code-block:: Python + # this example demonstrates tracking a measurement where the time + # between measurement varies, as stored in dts. This requires + # that F be recomputed for each epoch. The output is then smoothed + # with an RTS smoother. + zs = [t + random.randn()*4 for t in range (40)] + Fs = [np.array([[1., dt], [0, 1]] for dt in dts] + (mu, cov, _, _) = kf.batch_filter(zs, Fs=Fs) + (xs, Ps, Ks, Pps) = kf.rts_smoother(mu, cov, Fs=Fs) + """ + + #pylint: disable=too-many-statements + n = np.size(zs, 0) + if Fs is None: + Fs = [self.F] * n + if Qs is None: + Qs = [self.Q] * n + if Hs is None: + Hs = [self.H] * n + if Rs is None: + Rs = [self.R] * n + if Bs is None: + Bs = [self.B] * n + if us is None: + us = [0] * n + + # mean estimates from Kalman Filter + if self.x.ndim == 1: + means = zeros((n, self.dim_x)) + means_p = zeros((n, self.dim_x)) + else: + means = zeros((n, self.dim_x, 1)) + means_p = zeros((n, self.dim_x, 1)) + + # state covariances from Kalman Filter + covariances = zeros((n, self.dim_x, self.dim_x)) + covariances_p = zeros((n, self.dim_x, self.dim_x)) + + if update_first: + for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)): + + self.update(z, R=R, H=H) + means[i, :] = self.x + covariances[i, :, :] = self.P + + self.predict(u=u, B=B, F=F, Q=Q) + means_p[i, :] = self.x + covariances_p[i, :, :] = self.P + + if saver is not None: + saver.save() + else: + for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)): + + self.predict(u=u, B=B, F=F, Q=Q) + means_p[i, :] = self.x + covariances_p[i, :, :] = self.P + + self.update(z, R=R, H=H) + means[i, :] = self.x + covariances[i, :, :] = self.P + + if saver is not None: + saver.save() + + return (means, covariances, means_p, covariances_p) + + def rts_smoother(self, Xs, Ps, Fs=None, Qs=None, inv=np.linalg.inv): + """ + Runs the Rauch-Tung-Striebel Kalman smoother on a set of + means and covariances computed by a Kalman filter. The usual input + would come from the output of `KalmanFilter.batch_filter()`. + Parameters + ---------- + Xs : numpy.array + array of the means (state variable x) of the output of a Kalman + filter. + Ps : numpy.array + array of the covariances of the output of a kalman filter. + Fs : list-like collection of numpy.array, optional + State transition matrix of the Kalman filter at each time step. + Optional, if not provided the filter's self.F will be used + Qs : list-like collection of numpy.array, optional + Process noise of the Kalman filter at each time step. Optional, + if not provided the filter's self.Q will be used + inv : function, default numpy.linalg.inv + If you prefer another inverse function, such as the Moore-Penrose + pseudo inverse, set it to that instead: kf.inv = np.linalg.pinv + Returns + ------- + x : numpy.ndarray + smoothed means + P : numpy.ndarray + smoothed state covariances + K : numpy.ndarray + smoother gain at each step + Pp : numpy.ndarray + Predicted state covariances + Examples + -------- + .. 
code-block:: Python + zs = [t + random.randn()*4 for t in range (40)] + (mu, cov, _, _) = kalman.batch_filter(zs) + (x, P, K, Pp) = rts_smoother(mu, cov, kf.F, kf.Q) + """ + + if len(Xs) != len(Ps): + raise ValueError('length of Xs and Ps must be the same') + + n = Xs.shape[0] + dim_x = Xs.shape[1] + + if Fs is None: + Fs = [self.F] * n + if Qs is None: + Qs = [self.Q] * n + + # smoother gain + K = zeros((n, dim_x, dim_x)) + + x, P, Pp = Xs.copy(), Ps.copy(), Ps.copy() + for k in range(n-2, -1, -1): + Pp[k] = dot(dot(Fs[k+1], P[k]), Fs[k+1].T) + Qs[k+1] + + #pylint: disable=bad-whitespace + K[k] = dot(dot(P[k], Fs[k+1].T), inv(Pp[k])) + x[k] += dot(K[k], x[k+1] - dot(Fs[k+1], x[k])) + P[k] += dot(dot(K[k], P[k+1] - Pp[k]), K[k].T) + + return (x, P, K, Pp) + + def get_prediction(self, u=None, B=None, F=None, Q=None): + """ + Predict next state (prior) using the Kalman filter state propagation + equations and returns it without modifying the object. + Parameters + ---------- + u : np.array, default 0 + Optional control vector. + B : np.array(dim_x, dim_u), or None + Optional control transition matrix; a value of None + will cause the filter to use `self.B`. + F : np.array(dim_x, dim_x), or None + Optional state transition matrix; a value of None + will cause the filter to use `self.F`. + Q : np.array(dim_x, dim_x), scalar, or None + Optional process noise matrix; a value of None will cause the + filter to use `self.Q`. + Returns + ------- + (x, P) : tuple + State vector and covariance array of the prediction. + """ + + if B is None: + B = self.B + if F is None: + F = self.F + if Q is None: + Q = self.Q + elif isscalar(Q): + Q = eye(self.dim_x) * Q + + # x = Fx + Bu + if B is not None and u is not None: + x = dot(F, self.x) + dot(B, u) + else: + x = dot(F, self.x) + + # P = FPF' + Q + P = self._alpha_sq * dot(dot(F, self.P), F.T) + Q + + return x, P + + def get_update(self, z=None): + """ + Computes the new estimate based on measurement `z` and returns it + without altering the state of the filter. + Parameters + ---------- + z : (dim_z, 1): array_like + measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be convertible to a column vector. + Returns + ------- + (x, P) : tuple + State vector and covariance array of the update. + """ + + if z is None: + return self.x, self.P + z = reshape_z(z, self.dim_z, self.x.ndim) + + R = self.R + H = self.H + P = self.P + x = self.x + + # error (residual) between measurement and prediction + y = z - dot(H, x) + + # common subexpression for speed + PHT = dot(P, H.T) + + # project system uncertainty into measurement space + S = dot(H, PHT) + R + + # map system uncertainty into kalman gain + K = dot(PHT, self.inv(S)) + + # predict new x with residual scaled by the kalman gain + x = x + dot(K, y) + + # P = (I-KH)P(I-KH)' + KRK' + I_KH = self._I - dot(K, H) + P = dot(dot(I_KH, P), I_KH.T) + dot(dot(K, R), K.T) + + return x, P + + def residual_of(self, z): + """ + Returns the residual for the given measurement (z). Does not alter + the state of the filter. + """ + z = reshape_z(z, self.dim_z, self.x.ndim) + return z - dot(self.H, self.x_prior) + + def measurement_of_state(self, x): + """ + Helper function that converts a state into a measurement. + Parameters + ---------- + x : np.array + kalman state vector + Returns + ------- + z : (dim_z, 1): array_like + measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be convertible to a column vector. 
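+        Examples
+        --------
+        Illustrative sketch only (it assumes ``kf`` is an already configured
+        instance of this class):
+        .. code::
+            z_pred = kf.measurement_of_state(kf.x)  # same as dot(kf.H, kf.x)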
+ """ + + return dot(self.H, x) + + @property + def log_likelihood(self): + """ + log-likelihood of the last measurement. + """ + if self._log_likelihood is None: + self._log_likelihood = logpdf(x=self.y, cov=self.S) + return self._log_likelihood + + @property + def likelihood(self): + """ + Computed from the log-likelihood. The log-likelihood can be very + small, meaning a large negative value such as -28000. Taking the + exp() of that results in 0.0, which can break typical algorithms + which multiply by this value, so by default we always return a + number >= sys.float_info.min. + """ + if self._likelihood is None: + self._likelihood = exp(self.log_likelihood) + if self._likelihood == 0: + self._likelihood = sys.float_info.min + return self._likelihood + + @property + def mahalanobis(self): + """" + Mahalanobis distance of measurement. E.g. 3 means measurement + was 3 standard deviations away from the predicted value. + Returns + ------- + mahalanobis : float + """ + if self._mahalanobis is None: + self._mahalanobis = sqrt(float(dot(dot(self.y.T, self.SI), self.y))) + return self._mahalanobis + + @property + def alpha(self): + """ + Fading memory setting. 1.0 gives the normal Kalman filter, and + values slightly larger than 1.0 (such as 1.02) give a fading + memory effect - previous measurements have less influence on the + filter's estimates. This formulation of the Fading memory filter + (there are many) is due to Dan Simon [1]_. + """ + return self._alpha_sq**.5 + + def log_likelihood_of(self, z): + """ + log likelihood of the measurement `z`. This should only be called + after a call to update(). Calling after predict() will yield an + incorrect result.""" + + if z is None: + return log(sys.float_info.min) + return logpdf(z, dot(self.H, self.x), self.S) + + @alpha.setter + def alpha(self, value): + if not np.isscalar(value) or value < 1: + raise ValueError('alpha must be a float greater than 1') + + self._alpha_sq = value**2 + + def __repr__(self): + return '\n'.join([ + 'KalmanFilter object', + pretty_str('dim_x', self.dim_x), + pretty_str('dim_z', self.dim_z), + pretty_str('dim_u', self.dim_u), + pretty_str('x', self.x), + pretty_str('P', self.P), + pretty_str('x_prior', self.x_prior), + pretty_str('P_prior', self.P_prior), + pretty_str('x_post', self.x_post), + pretty_str('P_post', self.P_post), + pretty_str('F', self.F), + pretty_str('Q', self.Q), + pretty_str('R', self.R), + pretty_str('H', self.H), + pretty_str('K', self.K), + pretty_str('y', self.y), + pretty_str('S', self.S), + pretty_str('SI', self.SI), + pretty_str('M', self.M), + pretty_str('B', self.B), + pretty_str('z', self.z), + pretty_str('log-likelihood', self.log_likelihood), + pretty_str('likelihood', self.likelihood), + pretty_str('mahalanobis', self.mahalanobis), + pretty_str('alpha', self.alpha), + pretty_str('inv', self.inv) + ]) + + def test_matrix_dimensions(self, z=None, H=None, R=None, F=None, Q=None): + """ + Performs a series of asserts to check that the size of everything + is what it should be. This can help you debug problems in your design. + If you pass in H, R, F, Q those will be used instead of this object's + value for those matrices. + Testing `z` (the measurement) is problamatic. x is a vector, and can be + implemented as either a 1D array or as a nx1 column vector. Thus Hx + can be of different shapes. Then, if Hx is a single value, it can + be either a 1D array or 2D vector. 
If either is true, z can reasonably + be a scalar (either '3' or np.array('3') are scalars under this + definition), a 1D, 1 element array, or a 2D, 1 element array. You are + allowed to pass in any combination that works. + """ + + if H is None: + H = self.H + if R is None: + R = self.R + if F is None: + F = self.F + if Q is None: + Q = self.Q + x = self.x + P = self.P + + assert x.ndim == 1 or x.ndim == 2, \ + "x must have one or two dimensions, but has {}".format(x.ndim) + + if x.ndim == 1: + assert x.shape[0] == self.dim_x, \ + "Shape of x must be ({},{}), but is {}".format( + self.dim_x, 1, x.shape) + else: + assert x.shape == (self.dim_x, 1), \ + "Shape of x must be ({},{}), but is {}".format( + self.dim_x, 1, x.shape) + + assert P.shape == (self.dim_x, self.dim_x), \ + "Shape of P must be ({},{}), but is {}".format( + self.dim_x, self.dim_x, P.shape) + + assert Q.shape == (self.dim_x, self.dim_x), \ + "Shape of Q must be ({},{}), but is {}".format( + self.dim_x, self.dim_x, P.shape) + + assert F.shape == (self.dim_x, self.dim_x), \ + "Shape of F must be ({},{}), but is {}".format( + self.dim_x, self.dim_x, F.shape) + + assert np.ndim(H) == 2, \ + "Shape of H must be (dim_z, {}), but is {}".format( + P.shape[0], shape(H)) + + assert H.shape[1] == P.shape[0], \ + "Shape of H must be (dim_z, {}), but is {}".format( + P.shape[0], H.shape) + + # shape of R must be the same as HPH' + hph_shape = (H.shape[0], H.shape[0]) + r_shape = shape(R) + + if H.shape[0] == 1: + # r can be scalar, 1D, or 2D in this case + assert r_shape in [(), (1,), (1, 1)], \ + "R must be scalar or one element array, but is shaped {}".format( + r_shape) + else: + assert r_shape == hph_shape, \ + "shape of R should be {} but it is {}".format(hph_shape, r_shape) + + + if z is not None: + z_shape = shape(z) + else: + z_shape = (self.dim_z, 1) + + # H@x must have shape of z + Hx = dot(H, x) + + if z_shape == (): # scalar or np.array(scalar) + assert Hx.ndim == 1 or shape(Hx) == (1, 1), \ + "shape of z should be {}, not {} for the given H".format( + shape(Hx), z_shape) + + elif shape(Hx) == (1,): + assert z_shape[0] == 1, 'Shape of z must be {} for the given H'.format(shape(Hx)) + + else: + assert (z_shape == shape(Hx) or + (len(z_shape) == 1 and shape(Hx) == (z_shape[0], 1))), \ + "shape of z should be {}, not {} for the given H".format( + shape(Hx), z_shape) + + if np.ndim(Hx) > 1 and shape(Hx) != (1, 1): + assert shape(Hx) == z_shape, \ + 'shape of z should be {} for the given H, but it is {}'.format( + shape(Hx), z_shape) + + +def update(x, P, z, R, H=None, return_all=False): + """ + Add a new measurement (z) to the Kalman filter. If z is None, nothing + is changed. + This can handle either the multidimensional or unidimensional case. If + all parameters are floats instead of arrays the filter will still work, + and return floats for x, P as the result. + update(1, 2, 1, 1, 1) # univariate + update(x, P, 1 + Parameters + ---------- + x : numpy.array(dim_x, 1), or float + State estimate vector + P : numpy.array(dim_x, dim_x), or float + Covariance matrix + z : (dim_z, 1): array_like + measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be convertible to a column vector. + R : numpy.array(dim_z, dim_z), or float + Measurement noise matrix + H : numpy.array(dim_x, dim_x), or float, optional + Measurement function. If not provided, a value of 1 is assumed. + return_all : bool, default False + If true, y, K, S, and log_likelihood are returned, otherwise + only x and P are returned. 
+ Returns + ------- + x : numpy.array + Posterior state estimate vector + P : numpy.array + Posterior covariance matrix + y : numpy.array or scalar + Residua. Difference between measurement and state in measurement space + K : numpy.array + Kalman gain + S : numpy.array + System uncertainty in measurement space + log_likelihood : float + log likelihood of the measurement + """ + + #pylint: disable=bare-except + + if z is None: + if return_all: + return x, P, None, None, None, None + return x, P + + if H is None: + H = np.array([1]) + + if np.isscalar(H): + H = np.array([H]) + + Hx = np.atleast_1d(dot(H, x)) + z = reshape_z(z, Hx.shape[0], x.ndim) + + # error (residual) between measurement and prediction + y = z - Hx + + # project system uncertainty into measurement space + S = dot(dot(H, P), H.T) + R + + + # map system uncertainty into kalman gain + try: + K = dot(dot(P, H.T), linalg.inv(S)) + except: + # can't invert a 1D array, annoyingly + K = dot(dot(P, H.T), 1./S) + + + # predict new x with residual scaled by the kalman gain + x = x + dot(K, y) + + # P = (I-KH)P(I-KH)' + KRK' + KH = dot(K, H) + + try: + I_KH = np.eye(KH.shape[0]) - KH + except: + I_KH = np.array([1 - KH]) + P = dot(dot(I_KH, P), I_KH.T) + dot(dot(K, R), K.T) + + + if return_all: + # compute log likelihood + log_likelihood = logpdf(z, dot(H, x), S) + return x, P, y, K, S, log_likelihood + return x, P + + +def update_steadystate(x, z, K, H=None): + """ + Add a new measurement (z) to the Kalman filter. If z is None, nothing + is changed. + Parameters + ---------- + x : numpy.array(dim_x, 1), or float + State estimate vector + z : (dim_z, 1): array_like + measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be convertible to a column vector. + K : numpy.array, or float + Kalman gain matrix + H : numpy.array(dim_x, dim_x), or float, optional + Measurement function. If not provided, a value of 1 is assumed. + Returns + ------- + x : numpy.array + Posterior state estimate vector + Examples + -------- + This can handle either the multidimensional or unidimensional case. If + all parameters are floats instead of arrays the filter will still work, + and return floats for x, P as the result. + >>> update_steadystate(1, 2, 1) # univariate + >>> update_steadystate(x, P, z, H) + """ + + + if z is None: + return x + + if H is None: + H = np.array([1]) + + if np.isscalar(H): + H = np.array([H]) + + Hx = np.atleast_1d(dot(H, x)) + z = reshape_z(z, Hx.shape[0], x.ndim) + + # error (residual) between measurement and prediction + y = z - Hx + + # estimate new x with residual scaled by the kalman gain + return x + dot(K, y) + + +def predict(x, P, F=1, Q=0, u=0, B=1, alpha=1.): + """ + Predict next state (prior) using the Kalman filter state propagation + equations. + Parameters + ---------- + x : numpy.array + State estimate vector + P : numpy.array + Covariance matrix + F : numpy.array() + State Transition matrix + Q : numpy.array, Optional + Process noise matrix + u : numpy.array, Optional, default 0. + Control vector. If non-zero, it is multiplied by B + to create the control input into the system. + B : numpy.array, optional, default 0. + Control transition matrix. + alpha : float, Optional, default=1.0 + Fading memory setting. 1.0 gives the normal Kalman filter, and + values slightly larger than 1.0 (such as 1.02) give a fading + memory effect - previous measurements have less influence on the + filter's estimates. 
This formulation of the Fading memory filter + (there are many) is due to Dan Simon + Returns + ------- + x : numpy.array + Prior state estimate vector + P : numpy.array + Prior covariance matrix + """ + + if np.isscalar(F): + F = np.array(F) + x = dot(F, x) + dot(B, u) + P = (alpha * alpha) * dot(dot(F, P), F.T) + Q + + return x, P + + +def predict_steadystate(x, F=1, u=0, B=1): + """ + Predict next state (prior) using the Kalman filter state propagation + equations. This steady state form only computes x, assuming that the + covariance is constant. + Parameters + ---------- + x : numpy.array + State estimate vector + P : numpy.array + Covariance matrix + F : numpy.array() + State Transition matrix + u : numpy.array, Optional, default 0. + Control vector. If non-zero, it is multiplied by B + to create the control input into the system. + B : numpy.array, optional, default 0. + Control transition matrix. + Returns + ------- + x : numpy.array + Prior state estimate vector + """ + + if np.isscalar(F): + F = np.array(F) + x = dot(F, x) + dot(B, u) + + return x + + + +def batch_filter(x, P, zs, Fs, Qs, Hs, Rs, Bs=None, us=None, + update_first=False, saver=None): + """ + Batch processes a sequences of measurements. + Parameters + ---------- + zs : list-like + list of measurements at each time step. Missing measurements must be + represented by None. + Fs : list-like + list of values to use for the state transition matrix matrix. + Qs : list-like + list of values to use for the process error + covariance. + Hs : list-like + list of values to use for the measurement matrix. + Rs : list-like + list of values to use for the measurement error + covariance. + Bs : list-like, optional + list of values to use for the control transition matrix; + a value of None in any position will cause the filter + to use `self.B` for that time step. + us : list-like, optional + list of values to use for the control input vector; + a value of None in any position will cause the filter to use + 0 for that time step. + update_first : bool, optional + controls whether the order of operations is update followed by + predict, or predict followed by update. Default is predict->update. + saver : filterpy.common.Saver, optional + filterpy.common.Saver object. If provided, saver.save() will be + called after every epoch + Returns + ------- + means : np.array((n,dim_x,1)) + array of the state for each time step after the update. Each entry + is an np.array. In other words `means[k,:]` is the state at step + `k`. + covariance : np.array((n,dim_x,dim_x)) + array of the covariances for each time step after the update. + In other words `covariance[k,:,:]` is the covariance at step `k`. + means_predictions : np.array((n,dim_x,1)) + array of the state for each time step after the predictions. Each + entry is an np.array. In other words `means[k,:]` is the state at + step `k`. + covariance_predictions : np.array((n,dim_x,dim_x)) + array of the covariances for each time step after the prediction. + In other words `covariance[k,:,:]` is the covariance at step `k`. + Examples + -------- + .. 
code-block:: Python + zs = [t + random.randn()*4 for t in range (40)] + Fs = [kf.F for t in range (40)] + Hs = [kf.H for t in range (40)] + (mu, cov, _, _) = kf.batch_filter(zs, Rs=R_list, Fs=Fs, Hs=Hs, Qs=None, + Bs=None, us=None, update_first=False) + (xs, Ps, Ks, Pps) = kf.rts_smoother(mu, cov, Fs=Fs, Qs=None) + """ + + n = np.size(zs, 0) + dim_x = x.shape[0] + + # mean estimates from Kalman Filter + if x.ndim == 1: + means = zeros((n, dim_x)) + means_p = zeros((n, dim_x)) + else: + means = zeros((n, dim_x, 1)) + means_p = zeros((n, dim_x, 1)) + + # state covariances from Kalman Filter + covariances = zeros((n, dim_x, dim_x)) + covariances_p = zeros((n, dim_x, dim_x)) + + if us is None: + us = [0.] * n + Bs = [0.] * n + + if update_first: + for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)): + + x, P = update(x, P, z, R=R, H=H) + means[i, :] = x + covariances[i, :, :] = P + + x, P = predict(x, P, u=u, B=B, F=F, Q=Q) + means_p[i, :] = x + covariances_p[i, :, :] = P + if saver is not None: + saver.save() + else: + for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)): + + x, P = predict(x, P, u=u, B=B, F=F, Q=Q) + means_p[i, :] = x + covariances_p[i, :, :] = P + + x, P = update(x, P, z, R=R, H=H) + means[i, :] = x + covariances[i, :, :] = P + if saver is not None: + saver.save() + + return (means, covariances, means_p, covariances_p) + + + +def rts_smoother(Xs, Ps, Fs, Qs): + """ + Runs the Rauch-Tung-Striebel Kalman smoother on a set of + means and covariances computed by a Kalman filter. The usual input + would come from the output of `KalmanFilter.batch_filter()`. + Parameters + ---------- + Xs : numpy.array + array of the means (state variable x) of the output of a Kalman + filter. + Ps : numpy.array + array of the covariances of the output of a kalman filter. + Fs : list-like collection of numpy.array + State transition matrix of the Kalman filter at each time step. + Qs : list-like collection of numpy.array, optional + Process noise of the Kalman filter at each time step. + Returns + ------- + x : numpy.ndarray + smoothed means + P : numpy.ndarray + smoothed state covariances + K : numpy.ndarray + smoother gain at each step + pP : numpy.ndarray + predicted state covariances + Examples + -------- + .. 
code-block:: Python + zs = [t + random.randn()*4 for t in range (40)] + (mu, cov, _, _) = kalman.batch_filter(zs) + (x, P, K, pP) = rts_smoother(mu, cov, kf.F, kf.Q) + """ + + if len(Xs) != len(Ps): + raise ValueError('length of Xs and Ps must be the same') + + n = Xs.shape[0] + dim_x = Xs.shape[1] + + # smoother gain + K = zeros((n, dim_x, dim_x)) + x, P, pP = Xs.copy(), Ps.copy(), Ps.copy() + + for k in range(n-2, -1, -1): + pP[k] = dot(dot(Fs[k], P[k]), Fs[k].T) + Qs[k] + + #pylint: disable=bad-whitespace + K[k] = dot(dot(P[k], Fs[k].T), linalg.inv(pP[k])) + x[k] += dot(K[k], x[k+1] - dot(Fs[k], x[k])) + P[k] += dot(dot(K[k], P[k+1] - pP[k]), K[k].T) + + return (x, P, K, pP) \ No newline at end of file diff --git a/feeder/trackers/ocsort/ocsort.py b/feeder/trackers/ocsort/ocsort.py new file mode 100644 index 0000000..f4eddf0 --- /dev/null +++ b/feeder/trackers/ocsort/ocsort.py @@ -0,0 +1,328 @@ +""" + This script is adopted from the SORT script by Alex Bewley alex@bewley.ai +""" +from __future__ import print_function + +import numpy as np +from .association import * +from ultralytics.yolo.utils.ops import xywh2xyxy + + +def k_previous_obs(observations, cur_age, k): + if len(observations) == 0: + return [-1, -1, -1, -1, -1] + for i in range(k): + dt = k - i + if cur_age - dt in observations: + return observations[cur_age-dt] + max_age = max(observations.keys()) + return observations[max_age] + + +def convert_bbox_to_z(bbox): + """ + Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form + [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is + the aspect ratio + """ + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + x = bbox[0] + w/2. + y = bbox[1] + h/2. + s = w * h # scale is just area + r = w / float(h+1e-6) + return np.array([x, y, s, r]).reshape((4, 1)) + + +def convert_x_to_bbox(x, score=None): + """ + Takes a bounding box in the centre form [x,y,s,r] and returns it in the form + [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right + """ + w = np.sqrt(x[2] * x[3]) + h = x[2] / w + if(score == None): + return np.array([x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2.]).reshape((1, 4)) + else: + return np.array([x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2., score]).reshape((1, 5)) + + +def speed_direction(bbox1, bbox2): + cx1, cy1 = (bbox1[0]+bbox1[2]) / 2.0, (bbox1[1]+bbox1[3])/2.0 + cx2, cy2 = (bbox2[0]+bbox2[2]) / 2.0, (bbox2[1]+bbox2[3])/2.0 + speed = np.array([cy2-cy1, cx2-cx1]) + norm = np.sqrt((cy2-cy1)**2 + (cx2-cx1)**2) + 1e-6 + return speed / norm + + +class KalmanBoxTracker(object): + """ + This class represents the internal state of individual tracked objects observed as bbox. + """ + count = 0 + + def __init__(self, bbox, cls, delta_t=3, orig=False): + """ + Initialises a tracker using initial bounding box. + + """ + # define constant velocity model + if not orig: + from .kalmanfilter import KalmanFilterNew as KalmanFilter + self.kf = KalmanFilter(dim_x=7, dim_z=4) + else: + from filterpy.kalman import KalmanFilter + self.kf = KalmanFilter(dim_x=7, dim_z=4) + self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0], [0, 0, 1, 0, 0, 0, 1], [ + 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 1]]) + self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]]) + + self.kf.R[2:, 2:] *= 10. + self.kf.P[4:, 4:] *= 1000. # give high uncertainty to the unobservable initial velocities + self.kf.P *= 10. 
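+        # Note (added for clarity): the 7-dim state is [u, v, s, r, du, dv, ds] --
+        # box centre, scale (area) and aspect ratio plus their velocities; the aspect
+        # ratio r is modelled as constant and has no velocity term. The small Q
+        # entries below keep the velocity components slowly varying between updates.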
+ self.kf.Q[-1, -1] *= 0.01 + self.kf.Q[4:, 4:] *= 0.01 + + self.kf.x[:4] = convert_bbox_to_z(bbox) + self.time_since_update = 0 + self.id = KalmanBoxTracker.count + KalmanBoxTracker.count += 1 + self.history = [] + self.hits = 0 + self.hit_streak = 0 + self.age = 0 + self.conf = bbox[-1] + self.cls = cls + """ + NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of + function k_previous_obs. It is ugly and I do not like it. But to support generate observation array in a + fast and unified way, which you would see below k_observations = np.array([k_previous_obs(...]]), let's bear it for now. + """ + self.last_observation = np.array([-1, -1, -1, -1, -1]) # placeholder + self.observations = dict() + self.history_observations = [] + self.velocity = None + self.delta_t = delta_t + + def update(self, bbox, cls): + """ + Updates the state vector with observed bbox. + """ + + if bbox is not None: + self.conf = bbox[-1] + self.cls = cls + if self.last_observation.sum() >= 0: # no previous observation + previous_box = None + for i in range(self.delta_t): + dt = self.delta_t - i + if self.age - dt in self.observations: + previous_box = self.observations[self.age-dt] + break + if previous_box is None: + previous_box = self.last_observation + """ + Estimate the track speed direction with observations \Delta t steps away + """ + self.velocity = speed_direction(previous_box, bbox) + + """ + Insert new observations. This is a ugly way to maintain both self.observations + and self.history_observations. Bear it for the moment. + """ + self.last_observation = bbox + self.observations[self.age] = bbox + self.history_observations.append(bbox) + + self.time_since_update = 0 + self.history = [] + self.hits += 1 + self.hit_streak += 1 + self.kf.update(convert_bbox_to_z(bbox)) + else: + self.kf.update(bbox) + + def predict(self): + """ + Advances the state vector and returns the predicted bounding box estimate. + """ + if((self.kf.x[6]+self.kf.x[2]) <= 0): + self.kf.x[6] *= 0.0 + + self.kf.predict() + self.age += 1 + if(self.time_since_update > 0): + self.hit_streak = 0 + self.time_since_update += 1 + self.history.append(convert_x_to_bbox(self.kf.x)) + return self.history[-1] + + def get_state(self): + """ + Returns the current bounding box estimate. + """ + return convert_x_to_bbox(self.kf.x) + + +""" + We support multiple ways for association cost calculation, by default + we use IoU. GIoU may have better performance in some situations. We note + that we hardly normalize the cost by all methods to (0,1) which may not be + the best practice. +""" +ASSO_FUNCS = { "iou": iou_batch, + "giou": giou_batch, + "ciou": ciou_batch, + "diou": diou_batch, + "ct_dist": ct_dist} + + +class OCSort(object): + def __init__(self, det_thresh, max_age=30, min_hits=3, + iou_threshold=0.3, delta_t=3, asso_func="iou", inertia=0.2, use_byte=False): + """ + Sets key parameters for SORT + """ + self.max_age = max_age + self.min_hits = min_hits + self.iou_threshold = iou_threshold + self.trackers = [] + self.frame_count = 0 + self.det_thresh = det_thresh + self.delta_t = delta_t + self.asso_func = ASSO_FUNCS[asso_func] + self.inertia = inertia + self.use_byte = use_byte + KalmanBoxTracker.count = 0 + + def update(self, dets, _): + """ + Params: + dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...] + Requires: this method must be called once for each frame even with empty detections (use np.empty((0, 5)) for frames without detections). 
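+        Example (sketch; in this integration dets is expected to be a torch tensor of
+        shape (N, 6) with columns [x1, y1, x2, y2, score, class]):
+            outputs = tracker.update(dets, None)  # -> (M, 7) array [x1, y1, x2, y2, id, class, conf]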
+ Returns the a similar array, where the last column is the object ID. + NOTE: The number of objects returned may differ from the number of detections provided. + """ + + self.frame_count += 1 + + xyxys = dets[:, 0:4] + confs = dets[:, 4] + clss = dets[:, 5] + + classes = clss.numpy() + xyxys = xyxys.numpy() + confs = confs.numpy() + + output_results = np.column_stack((xyxys, confs, classes)) + + inds_low = confs > 0.1 + inds_high = confs < self.det_thresh + inds_second = np.logical_and(inds_low, inds_high) # self.det_thresh > score > 0.1, for second matching + dets_second = output_results[inds_second] # detections for second matching + remain_inds = confs > self.det_thresh + dets = output_results[remain_inds] + + # get predicted locations from existing trackers. + trks = np.zeros((len(self.trackers), 5)) + to_del = [] + ret = [] + for t, trk in enumerate(trks): + pos = self.trackers[t].predict()[0] + trk[:] = [pos[0], pos[1], pos[2], pos[3], 0] + if np.any(np.isnan(pos)): + to_del.append(t) + trks = np.ma.compress_rows(np.ma.masked_invalid(trks)) + for t in reversed(to_del): + self.trackers.pop(t) + + velocities = np.array( + [trk.velocity if trk.velocity is not None else np.array((0, 0)) for trk in self.trackers]) + last_boxes = np.array([trk.last_observation for trk in self.trackers]) + k_observations = np.array( + [k_previous_obs(trk.observations, trk.age, self.delta_t) for trk in self.trackers]) + + """ + First round of association + """ + matched, unmatched_dets, unmatched_trks = associate( + dets, trks, self.iou_threshold, velocities, k_observations, self.inertia) + for m in matched: + self.trackers[m[1]].update(dets[m[0], :5], dets[m[0], 5]) + + """ + Second round of associaton by OCR + """ + # BYTE association + if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[0] > 0: + u_trks = trks[unmatched_trks] + iou_left = self.asso_func(dets_second, u_trks) # iou between low score detections and unmatched tracks + iou_left = np.array(iou_left) + if iou_left.max() > self.iou_threshold: + """ + NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may + get a higher performance especially on MOT17/MOT20 datasets. But we keep it + uniform here for simplicity + """ + matched_indices = linear_assignment(-iou_left) + to_remove_trk_indices = [] + for m in matched_indices: + det_ind, trk_ind = m[0], unmatched_trks[m[1]] + if iou_left[m[0], m[1]] < self.iou_threshold: + continue + self.trackers[trk_ind].update(dets_second[det_ind, :5], dets_second[det_ind, 5]) + to_remove_trk_indices.append(trk_ind) + unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices)) + + if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0: + left_dets = dets[unmatched_dets] + left_trks = last_boxes[unmatched_trks] + iou_left = self.asso_func(left_dets, left_trks) + iou_left = np.array(iou_left) + if iou_left.max() > self.iou_threshold: + """ + NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may + get a higher performance especially on MOT17/MOT20 datasets. 
But we keep it + uniform here for simplicity + """ + rematched_indices = linear_assignment(-iou_left) + to_remove_det_indices = [] + to_remove_trk_indices = [] + for m in rematched_indices: + det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]] + if iou_left[m[0], m[1]] < self.iou_threshold: + continue + self.trackers[trk_ind].update(dets[det_ind, :5], dets[det_ind, 5]) + to_remove_det_indices.append(det_ind) + to_remove_trk_indices.append(trk_ind) + unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices)) + unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices)) + + for m in unmatched_trks: + self.trackers[m].update(None, None) + + # create and initialise new trackers for unmatched detections + for i in unmatched_dets: + trk = KalmanBoxTracker(dets[i, :5], dets[i, 5], delta_t=self.delta_t) + self.trackers.append(trk) + i = len(self.trackers) + for trk in reversed(self.trackers): + if trk.last_observation.sum() < 0: + d = trk.get_state()[0] + else: + """ + this is optional to use the recent observation or the kalman filter prediction, + we didn't notice significant difference here + """ + d = trk.last_observation[:4] + if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits): + # +1 as MOT benchmark requires positive + ret.append(np.concatenate((d, [trk.id+1], [trk.cls], [trk.conf])).reshape(1, -1)) + i -= 1 + # remove dead tracklet + if(trk.time_since_update > self.max_age): + self.trackers.pop(i) + if(len(ret) > 0): + return np.concatenate(ret) + return np.empty((0, 5)) diff --git a/feeder/trackers/reid_export.py b/feeder/trackers/reid_export.py new file mode 100644 index 0000000..9ef8d13 --- /dev/null +++ b/feeder/trackers/reid_export.py @@ -0,0 +1,313 @@ +import argparse + +import os +# limit the number of cpus used by high performance libraries +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["OPENBLAS_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["VECLIB_MAXIMUM_THREADS"] = "1" +os.environ["NUMEXPR_NUM_THREADS"] = "1" + +import sys +import numpy as np +from pathlib import Path +import torch +import time +import platform +import pandas as pd +import subprocess +import torch.backends.cudnn as cudnn +from torch.utils.mobile_optimizer import optimize_for_mobile + +FILE = Path(__file__).resolve() +ROOT = FILE.parents[0].parents[0] # yolov5 strongsort root directory +WEIGHTS = ROOT / 'weights' + + +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) # add ROOT to PATH +if str(ROOT / 'yolov5') not in sys.path: + sys.path.append(str(ROOT / 'yolov5')) # add yolov5 ROOT to PATH + +ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative + +import logging +from ultralytics.yolo.utils.torch_utils import select_device +from ultralytics.yolo.utils import LOGGER, colorstr, ops +from ultralytics.yolo.utils.checks import check_requirements, check_version +from trackers.strongsort.deep.models import build_model +from trackers.strongsort.deep.reid_model_factory import get_model_name, load_pretrained_weights + + +def file_size(path): + # Return file/dir size (MB) + path = Path(path) + if path.is_file(): + return path.stat().st_size / 1E6 + elif path.is_dir(): + return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6 + else: + return 0.0 + + +def export_formats(): + # YOLOv5 export formats + x = [ + ['PyTorch', '-', '.pt', True, True], + ['TorchScript', 'torchscript', '.torchscript', True, True], + ['ONNX', 'onnx', '.onnx', True, True], + ['OpenVINO', 
'openvino', '_openvino_model', True, False], + ['TensorRT', 'engine', '.engine', False, True], + ['TensorFlow Lite', 'tflite', '.tflite', True, False], + ] + return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU']) + + +def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')): + # YOLOv5 TorchScript model export + try: + LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...') + f = file.with_suffix('.torchscript') + + ts = torch.jit.trace(model, im, strict=False) + if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html + optimize_for_mobile(ts)._save_for_lite_interpreter(str(f)) + else: + ts.save(str(f)) + + LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') + return f + except Exception as e: + LOGGER.info(f'{prefix} export failure: {e}') + + +def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')): + # ONNX export + try: + check_requirements(('onnx',)) + import onnx + + f = file.with_suffix('.onnx') + LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...') + + if dynamic: + dynamic = {'images': {0: 'batch'}} # shape(1,3,640,640) + dynamic['output'] = {0: 'batch'} # shape(1,25200,85) + + torch.onnx.export( + model.cpu() if dynamic else model, # --dynamic only compatible with cpu + im.cpu() if dynamic else im, + f, + verbose=False, + opset_version=opset, + do_constant_folding=True, + input_names=['images'], + output_names=['output'], + dynamic_axes=dynamic or None + ) + # Checks + model_onnx = onnx.load(f) # load onnx model + onnx.checker.check_model(model_onnx) # check onnx model + onnx.save(model_onnx, f) + + # Simplify + if simplify: + try: + cuda = torch.cuda.is_available() + check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1')) + import onnxsim + + LOGGER.info(f'simplifying with onnx-simplifier {onnxsim.__version__}...') + model_onnx, check = onnxsim.simplify(model_onnx) + assert check, 'assert check failed' + onnx.save(model_onnx, f) + except Exception as e: + LOGGER.info(f'simplifier failure: {e}') + LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') + return f + except Exception as e: + LOGGER.info(f'export failure: {e}') + + + +def export_openvino(file, half, prefix=colorstr('OpenVINO:')): + # YOLOv5 OpenVINO export + check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/ + import openvino.inference_engine as ie + try: + LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...') + f = str(file).replace('.pt', f'_openvino_model{os.sep}') + + cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}" + subprocess.check_output(cmd.split()) # export + except Exception as e: + LOGGER.info(f'export failure: {e}') + LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') + return f + + +def export_tflite(file, half, prefix=colorstr('TFLite:')): + # YOLOv5 OpenVINO export + try: + check_requirements(('openvino2tensorflow', 'tensorflow', 'tensorflow_datasets')) # requires openvino-dev: https://pypi.org/project/openvino-dev/ + import openvino.inference_engine as ie + LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...') + output = Path(str(file).replace(f'_openvino_model{os.sep}', f'_tflite_model{os.sep}')) + modelxml = list(Path(file).glob('*.xml'))[0] + cmd = f"openvino2tensorflow \ + --model_path 
{modelxml} \ + --model_output_path {output} \ + --output_pb \ + --output_saved_model \ + --output_no_quant_float32_tflite \ + --output_dynamic_range_quant_tflite" + subprocess.check_output(cmd.split()) # export + + LOGGER.info(f'{prefix} export success, results saved in {output} ({file_size(f):.1f} MB)') + return f + except Exception as e: + LOGGER.info(f'\n{prefix} export failure: {e}') + + +def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): + # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt + try: + assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`' + try: + import tensorrt as trt + except Exception: + if platform.system() == 'Linux': + check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',)) + import tensorrt as trt + + if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012 + grid = model.model[-1].anchor_grid + model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid] + export_onnx(model, im, file, 12, dynamic, simplify) # opset 12 + model.model[-1].anchor_grid = grid + else: # TensorRT >= 8 + check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0 + export_onnx(model, im, file, 12, dynamic, simplify) # opset 13 + onnx = file.with_suffix('.onnx') + + LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') + assert onnx.exists(), f'failed to export ONNX file: {onnx}' + f = file.with_suffix('.engine') # TensorRT engine file + logger = trt.Logger(trt.Logger.INFO) + if verbose: + logger.min_severity = trt.Logger.Severity.VERBOSE + + builder = trt.Builder(logger) + config = builder.create_builder_config() + config.max_workspace_size = workspace * 1 << 30 + # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice + + flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) + network = builder.create_network(flag) + parser = trt.OnnxParser(network, logger) + if not parser.parse_from_file(str(onnx)): + raise RuntimeError(f'failed to load ONNX file: {onnx}') + + inputs = [network.get_input(i) for i in range(network.num_inputs)] + outputs = [network.get_output(i) for i in range(network.num_outputs)] + LOGGER.info(f'{prefix} Network Description:') + for inp in inputs: + LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}') + for out in outputs: + LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}') + + if dynamic: + if im.shape[0] <= 1: + LOGGER.warning(f"{prefix}WARNING: --dynamic model requires maximum --batch-size argument") + profile = builder.create_optimization_profile() + for inp in inputs: + profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape) + config.add_optimization_profile(profile) + + LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}') + if builder.platform_has_fast_fp16 and half: + config.set_flag(trt.BuilderFlag.FP16) + with builder.build_engine(network, config) as engine, open(f, 'wb') as t: + t.write(engine.serialize()) + LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') + return f + except Exception as e: + LOGGER.info(f'\n{prefix} export failure: {e}') + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="ReID 
export") + parser.add_argument('--batch-size', type=int, default=1, help='batch size') + parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[256, 128], help='image (h, w)') + parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile') + parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes') + parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model') + parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version') + parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)') + parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log') + parser.add_argument('--weights', nargs='+', type=str, default=WEIGHTS / 'osnet_x0_25_msmt17.pt', help='model.pt path(s)') + parser.add_argument('--half', action='store_true', help='FP16 half-precision export') + parser.add_argument('--include', + nargs='+', + default=['torchscript'], + help='torchscript, onnx, openvino, engine') + args = parser.parse_args() + + t = time.time() + + include = [x.lower() for x in args.include] # to lowercase + fmts = tuple(export_formats()['Argument'][1:]) # --include arguments + flags = [x in include for x in fmts] + assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}' + jit, onnx, openvino, engine, tflite = flags # export booleans + + args.device = select_device(args.device) + if args.half: + assert args.device.type != 'cpu', '--half only compatible with GPU export, i.e. use --device 0' + assert not args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both' + + if type(args.weights) is list: + args.weights = Path(args.weights[0]) + + model = build_model( + get_model_name(args.weights), + num_classes=1, + pretrained=not (args.weights and args.weights.is_file() and args.weights.suffix == '.pt'), + use_gpu=args.device + ).to(args.device) + load_pretrained_weights(model, args.weights) + model.eval() + + if args.optimize: + assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. 
use --device cpu' + + im = torch.zeros(args.batch_size, 3, args.imgsz[0], args.imgsz[1]).to(args.device) # image size(1,3,640,480) BCHW iDetection + for _ in range(2): + y = model(im) # dry runs + if args.half: + im, model = im.half(), model.half() # to FP16 + shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape + LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {args.weights} with output shape {shape} ({file_size(args.weights):.1f} MB)") + + # Exports + f = [''] * len(fmts) # exported filenames + if jit: + f[0] = export_torchscript(model, im, args.weights, args.optimize) # opset 12 + if engine: # TensorRT required before ONNX + f[1] = export_engine(model, im, args.weights, args.half, args.dynamic, args.simplify, args.workspace, args.verbose) + if onnx: # OpenVINO requires ONNX + f[2] = export_onnx(model, im, args.weights, args.opset, args.dynamic, args.simplify) # opset 12 + if openvino: + f[3] = export_openvino(args.weights, args.half) + if tflite: + export_tflite(f, False) + + # Finish + f = [str(x) for x in f if x] # filter out '' and None + if any(f): + LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)' + f"\nResults saved to {colorstr('bold', args.weights.parent.resolve())}" + f"\nVisualize: https://netron.app") + diff --git a/feeder/trackers/strongsort/__init__.py b/feeder/trackers/strongsort/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feeder/trackers/strongsort/configs/strongsort.yaml b/feeder/trackers/strongsort/configs/strongsort.yaml new file mode 100644 index 0000000..c4fa8b6 --- /dev/null +++ b/feeder/trackers/strongsort/configs/strongsort.yaml @@ -0,0 +1,11 @@ +strongsort: + ecc: true + ema_alpha: 0.8962157769329083 + max_age: 40 + max_dist: 0.1594374041012136 + max_iou_dist: 0.5431835667667874 + max_unmatched_preds: 0 + mc_lambda: 0.995 + n_init: 3 + nn_budget: 100 + conf_thres: 0.5122620708221085 diff --git a/feeder/trackers/strongsort/deep/checkpoint/.gitkeep b/feeder/trackers/strongsort/deep/checkpoint/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/feeder/trackers/strongsort/deep/checkpoint/osnet_x0_25_market1501.pth b/feeder/trackers/strongsort/deep/checkpoint/osnet_x0_25_market1501.pth new file mode 100644 index 0000000..7fffc34 Binary files /dev/null and b/feeder/trackers/strongsort/deep/checkpoint/osnet_x0_25_market1501.pth differ diff --git a/feeder/trackers/strongsort/deep/checkpoint/osnet_x0_25_msmt17.pth b/feeder/trackers/strongsort/deep/checkpoint/osnet_x0_25_msmt17.pth new file mode 100644 index 0000000..f80a348 Binary files /dev/null and b/feeder/trackers/strongsort/deep/checkpoint/osnet_x0_25_msmt17.pth differ diff --git a/feeder/trackers/strongsort/deep/checkpoint/osnet_x1_0_msmt17.pth b/feeder/trackers/strongsort/deep/checkpoint/osnet_x1_0_msmt17.pth new file mode 100644 index 0000000..078ad76 Binary files /dev/null and b/feeder/trackers/strongsort/deep/checkpoint/osnet_x1_0_msmt17.pth differ diff --git a/feeder/trackers/strongsort/deep/models/__init__.py b/feeder/trackers/strongsort/deep/models/__init__.py new file mode 100644 index 0000000..3c60ba6 --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/__init__.py @@ -0,0 +1,122 @@ +from __future__ import absolute_import +import torch + +from .pcb import * +from .mlfn import * +from .hacnn import * +from .osnet import * +from .senet import * +from .mudeep import * +from .nasnet import * +from .resnet import * +from .densenet import * +from .xception import * +from .osnet_ain import * +from .resnetmid import * +from 
.shufflenet import * +from .squeezenet import * +from .inceptionv4 import * +from .mobilenetv2 import * +from .resnet_ibn_a import * +from .resnet_ibn_b import * +from .shufflenetv2 import * +from .inceptionresnetv2 import * + +__model_factory = { + # image classification models + 'resnet18': resnet18, + 'resnet34': resnet34, + 'resnet50': resnet50, + 'resnet101': resnet101, + 'resnet152': resnet152, + 'resnext50_32x4d': resnext50_32x4d, + 'resnext101_32x8d': resnext101_32x8d, + 'resnet50_fc512': resnet50_fc512, + 'se_resnet50': se_resnet50, + 'se_resnet50_fc512': se_resnet50_fc512, + 'se_resnet101': se_resnet101, + 'se_resnext50_32x4d': se_resnext50_32x4d, + 'se_resnext101_32x4d': se_resnext101_32x4d, + 'densenet121': densenet121, + 'densenet169': densenet169, + 'densenet201': densenet201, + 'densenet161': densenet161, + 'densenet121_fc512': densenet121_fc512, + 'inceptionresnetv2': inceptionresnetv2, + 'inceptionv4': inceptionv4, + 'xception': xception, + 'resnet50_ibn_a': resnet50_ibn_a, + 'resnet50_ibn_b': resnet50_ibn_b, + # lightweight models + 'nasnsetmobile': nasnetamobile, + 'mobilenetv2_x1_0': mobilenetv2_x1_0, + 'mobilenetv2_x1_4': mobilenetv2_x1_4, + 'shufflenet': shufflenet, + 'squeezenet1_0': squeezenet1_0, + 'squeezenet1_0_fc512': squeezenet1_0_fc512, + 'squeezenet1_1': squeezenet1_1, + 'shufflenet_v2_x0_5': shufflenet_v2_x0_5, + 'shufflenet_v2_x1_0': shufflenet_v2_x1_0, + 'shufflenet_v2_x1_5': shufflenet_v2_x1_5, + 'shufflenet_v2_x2_0': shufflenet_v2_x2_0, + # reid-specific models + 'mudeep': MuDeep, + 'resnet50mid': resnet50mid, + 'hacnn': HACNN, + 'pcb_p6': pcb_p6, + 'pcb_p4': pcb_p4, + 'mlfn': mlfn, + 'osnet_x1_0': osnet_x1_0, + 'osnet_x0_75': osnet_x0_75, + 'osnet_x0_5': osnet_x0_5, + 'osnet_x0_25': osnet_x0_25, + 'osnet_ibn_x1_0': osnet_ibn_x1_0, + 'osnet_ain_x1_0': osnet_ain_x1_0, + 'osnet_ain_x0_75': osnet_ain_x0_75, + 'osnet_ain_x0_5': osnet_ain_x0_5, + 'osnet_ain_x0_25': osnet_ain_x0_25 +} + + +def show_avai_models(): + """Displays available models. + + Examples:: + >>> from torchreid import models + >>> models.show_avai_models() + """ + print(list(__model_factory.keys())) + + +def build_model( + name, num_classes, loss='softmax', pretrained=True, use_gpu=True +): + """A function wrapper for building a model. + + Args: + name (str): model name. + num_classes (int): number of training identities. + loss (str, optional): loss function to optimize the model. Currently + supports "softmax" and "triplet". Default is "softmax". + pretrained (bool, optional): whether to load ImageNet-pretrained weights. + Default is True. + use_gpu (bool, optional): whether to use gpu. Default is True. + + Returns: + nn.Module + + Examples:: + >>> from torchreid import models + >>> model = models.build_model('resnet50', 751, loss='softmax') + """ + avai_models = list(__model_factory.keys()) + if name not in avai_models: + raise KeyError( + 'Unknown model: {}. 
Must be one of {}'.format(name, avai_models) + ) + return __model_factory[name]( + num_classes=num_classes, + loss=loss, + pretrained=pretrained, + use_gpu=use_gpu + ) diff --git a/feeder/trackers/strongsort/deep/models/densenet.py b/feeder/trackers/strongsort/deep/models/densenet.py new file mode 100644 index 0000000..a1d9b7e --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/densenet.py @@ -0,0 +1,380 @@ +""" +Code source: https://github.com/pytorch/vision +""" +from __future__ import division, absolute_import +import re +from collections import OrderedDict +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils import model_zoo + +__all__ = [ + 'densenet121', 'densenet169', 'densenet201', 'densenet161', + 'densenet121_fc512' +] + +model_urls = { + 'densenet121': + 'https://download.pytorch.org/models/densenet121-a639ec97.pth', + 'densenet169': + 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', + 'densenet201': + 'https://download.pytorch.org/models/densenet201-c1103571.pth', + 'densenet161': + 'https://download.pytorch.org/models/densenet161-8d451a50.pth', +} + + +class _DenseLayer(nn.Sequential): + + def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): + super(_DenseLayer, self).__init__() + self.add_module('norm1', nn.BatchNorm2d(num_input_features)), + self.add_module('relu1', nn.ReLU(inplace=True)), + self.add_module( + 'conv1', + nn.Conv2d( + num_input_features, + bn_size * growth_rate, + kernel_size=1, + stride=1, + bias=False + ) + ), + self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), + self.add_module('relu2', nn.ReLU(inplace=True)), + self.add_module( + 'conv2', + nn.Conv2d( + bn_size * growth_rate, + growth_rate, + kernel_size=3, + stride=1, + padding=1, + bias=False + ) + ), + self.drop_rate = drop_rate + + def forward(self, x): + new_features = super(_DenseLayer, self).forward(x) + if self.drop_rate > 0: + new_features = F.dropout( + new_features, p=self.drop_rate, training=self.training + ) + return torch.cat([x, new_features], 1) + + +class _DenseBlock(nn.Sequential): + + def __init__( + self, num_layers, num_input_features, bn_size, growth_rate, drop_rate + ): + super(_DenseBlock, self).__init__() + for i in range(num_layers): + layer = _DenseLayer( + num_input_features + i*growth_rate, growth_rate, bn_size, + drop_rate + ) + self.add_module('denselayer%d' % (i+1), layer) + + +class _Transition(nn.Sequential): + + def __init__(self, num_input_features, num_output_features): + super(_Transition, self).__init__() + self.add_module('norm', nn.BatchNorm2d(num_input_features)) + self.add_module('relu', nn.ReLU(inplace=True)) + self.add_module( + 'conv', + nn.Conv2d( + num_input_features, + num_output_features, + kernel_size=1, + stride=1, + bias=False + ) + ) + self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) + + +class DenseNet(nn.Module): + """Densely connected network. + + Reference: + Huang et al. Densely Connected Convolutional Networks. CVPR 2017. + + Public keys: + - ``densenet121``: DenseNet121. + - ``densenet169``: DenseNet169. + - ``densenet201``: DenseNet201. + - ``densenet161``: DenseNet161. + - ``densenet121_fc512``: DenseNet121 + FC. 
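+    Examples::
+        >>> # illustrative only -- 751 identities corresponds to Market-1501
+        >>> import torch
+        >>> model = densenet121(num_classes=751, loss='softmax', pretrained=False).eval()
+        >>> v = model(torch.randn(1, 3, 256, 128))  # (1, 1024) pooled feature vector in eval mode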
+ """ + + def __init__( + self, + num_classes, + loss, + growth_rate=32, + block_config=(6, 12, 24, 16), + num_init_features=64, + bn_size=4, + drop_rate=0, + fc_dims=None, + dropout_p=None, + **kwargs + ): + + super(DenseNet, self).__init__() + self.loss = loss + + # First convolution + self.features = nn.Sequential( + OrderedDict( + [ + ( + 'conv0', + nn.Conv2d( + 3, + num_init_features, + kernel_size=7, + stride=2, + padding=3, + bias=False + ) + ), + ('norm0', nn.BatchNorm2d(num_init_features)), + ('relu0', nn.ReLU(inplace=True)), + ( + 'pool0', + nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + ), + ] + ) + ) + + # Each denseblock + num_features = num_init_features + for i, num_layers in enumerate(block_config): + block = _DenseBlock( + num_layers=num_layers, + num_input_features=num_features, + bn_size=bn_size, + growth_rate=growth_rate, + drop_rate=drop_rate + ) + self.features.add_module('denseblock%d' % (i+1), block) + num_features = num_features + num_layers*growth_rate + if i != len(block_config) - 1: + trans = _Transition( + num_input_features=num_features, + num_output_features=num_features // 2 + ) + self.features.add_module('transition%d' % (i+1), trans) + num_features = num_features // 2 + + # Final batch norm + self.features.add_module('norm5', nn.BatchNorm2d(num_features)) + + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.feature_dim = num_features + self.fc = self._construct_fc_layer(fc_dims, num_features, dropout_p) + + # Linear layer + self.classifier = nn.Linear(self.feature_dim, num_classes) + + self._init_params() + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + """Constructs fully connected layer. + + Args: + fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed + input_dim (int): input dimension + dropout_p (float): dropout probability, if None, dropout is unused + """ + if fc_dims is None: + self.feature_dim = input_dim + return None + + assert isinstance( + fc_dims, (list, tuple) + ), 'fc_dims must be either list or tuple, but got {}'.format( + type(fc_dims) + ) + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + f = self.features(x) + f = F.relu(f, inplace=True) + v = self.global_avgpool(f) + v = v.view(v.size(0), -1) + + if self.fc is not None: + v = self.fc(v) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError('Unsupported loss: {}'.format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. 
+ """ + pretrain_dict = model_zoo.load_url(model_url) + + # '.'s are no longer allowed in module names, but pervious _DenseLayer + # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. + # They are also in the checkpoints in model_urls. This pattern is used + # to find such keys. + pattern = re.compile( + r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$' + ) + for key in list(pretrain_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + pretrain_dict[new_key] = pretrain_dict[key] + del pretrain_dict[key] + + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +""" +Dense network configurations: +-- +densenet121: num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16) +densenet169: num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32) +densenet201: num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32) +densenet161: num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24) +""" + + +def densenet121(num_classes, loss='softmax', pretrained=True, **kwargs): + model = DenseNet( + num_classes=num_classes, + loss=loss, + num_init_features=64, + growth_rate=32, + block_config=(6, 12, 24, 16), + fc_dims=None, + dropout_p=None, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['densenet121']) + return model + + +def densenet169(num_classes, loss='softmax', pretrained=True, **kwargs): + model = DenseNet( + num_classes=num_classes, + loss=loss, + num_init_features=64, + growth_rate=32, + block_config=(6, 12, 32, 32), + fc_dims=None, + dropout_p=None, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['densenet169']) + return model + + +def densenet201(num_classes, loss='softmax', pretrained=True, **kwargs): + model = DenseNet( + num_classes=num_classes, + loss=loss, + num_init_features=64, + growth_rate=32, + block_config=(6, 12, 48, 32), + fc_dims=None, + dropout_p=None, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['densenet201']) + return model + + +def densenet161(num_classes, loss='softmax', pretrained=True, **kwargs): + model = DenseNet( + num_classes=num_classes, + loss=loss, + num_init_features=96, + growth_rate=48, + block_config=(6, 12, 36, 24), + fc_dims=None, + dropout_p=None, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['densenet161']) + return model + + +def densenet121_fc512(num_classes, loss='softmax', pretrained=True, **kwargs): + model = DenseNet( + num_classes=num_classes, + loss=loss, + num_init_features=64, + growth_rate=32, + block_config=(6, 12, 24, 16), + fc_dims=[512], + dropout_p=None, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['densenet121']) + return model diff --git a/feeder/trackers/strongsort/deep/models/hacnn.py b/feeder/trackers/strongsort/deep/models/hacnn.py new file mode 100644 index 0000000..f21cc82 --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/hacnn.py @@ -0,0 +1,414 @@ +from __future__ import division, absolute_import +import torch +from torch import nn +from torch.nn import functional as F + +__all__ = ['HACNN'] + + +class ConvBlock(nn.Module): + """Basic convolutional block. + + convolution + batch normalization + relu. + + Args: + in_c (int): number of input channels. 
+ out_c (int): number of output channels. + k (int or tuple): kernel size. + s (int or tuple): stride. + p (int or tuple): padding. + """ + + def __init__(self, in_c, out_c, k, s=1, p=0): + super(ConvBlock, self).__init__() + self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p) + self.bn = nn.BatchNorm2d(out_c) + + def forward(self, x): + return F.relu(self.bn(self.conv(x))) + + +class InceptionA(nn.Module): + + def __init__(self, in_channels, out_channels): + super(InceptionA, self).__init__() + mid_channels = out_channels // 4 + + self.stream1 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, p=1), + ) + self.stream2 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, p=1), + ) + self.stream3 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, p=1), + ) + self.stream4 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1), + ConvBlock(in_channels, mid_channels, 1), + ) + + def forward(self, x): + s1 = self.stream1(x) + s2 = self.stream2(x) + s3 = self.stream3(x) + s4 = self.stream4(x) + y = torch.cat([s1, s2, s3, s4], dim=1) + return y + + +class InceptionB(nn.Module): + + def __init__(self, in_channels, out_channels): + super(InceptionB, self).__init__() + mid_channels = out_channels // 4 + + self.stream1 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, s=2, p=1), + ) + self.stream2 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, p=1), + ConvBlock(mid_channels, mid_channels, 3, s=2, p=1), + ) + self.stream3 = nn.Sequential( + nn.MaxPool2d(3, stride=2, padding=1), + ConvBlock(in_channels, mid_channels * 2, 1), + ) + + def forward(self, x): + s1 = self.stream1(x) + s2 = self.stream2(x) + s3 = self.stream3(x) + y = torch.cat([s1, s2, s3], dim=1) + return y + + +class SpatialAttn(nn.Module): + """Spatial Attention (Sec. 3.1.I.1)""" + + def __init__(self): + super(SpatialAttn, self).__init__() + self.conv1 = ConvBlock(1, 1, 3, s=2, p=1) + self.conv2 = ConvBlock(1, 1, 1) + + def forward(self, x): + # global cross-channel averaging + x = x.mean(1, keepdim=True) + # 3-by-3 conv + x = self.conv1(x) + # bilinear resizing + x = F.upsample( + x, (x.size(2) * 2, x.size(3) * 2), + mode='bilinear', + align_corners=True + ) + # scaling conv + x = self.conv2(x) + return x + + +class ChannelAttn(nn.Module): + """Channel Attention (Sec. 3.1.I.2)""" + + def __init__(self, in_channels, reduction_rate=16): + super(ChannelAttn, self).__init__() + assert in_channels % reduction_rate == 0 + self.conv1 = ConvBlock(in_channels, in_channels // reduction_rate, 1) + self.conv2 = ConvBlock(in_channels // reduction_rate, in_channels, 1) + + def forward(self, x): + # squeeze operation (global average pooling) + x = F.avg_pool2d(x, x.size()[2:]) + # excitation operation (2 conv layers) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class SoftAttn(nn.Module): + """Soft Attention (Sec. 3.1.I) + + Aim: Spatial Attention + Channel Attention + + Output: attention maps with shape identical to input. 
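+    The spatial map (B, 1, H, W) and the channel map (B, C, 1, 1) are combined by
+    broadcast multiplication and squashed with a sigmoid, so every element of the
+    returned map lies in (0, 1).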
+ """ + + def __init__(self, in_channels): + super(SoftAttn, self).__init__() + self.spatial_attn = SpatialAttn() + self.channel_attn = ChannelAttn(in_channels) + self.conv = ConvBlock(in_channels, in_channels, 1) + + def forward(self, x): + y_spatial = self.spatial_attn(x) + y_channel = self.channel_attn(x) + y = y_spatial * y_channel + y = torch.sigmoid(self.conv(y)) + return y + + +class HardAttn(nn.Module): + """Hard Attention (Sec. 3.1.II)""" + + def __init__(self, in_channels): + super(HardAttn, self).__init__() + self.fc = nn.Linear(in_channels, 4 * 2) + self.init_params() + + def init_params(self): + self.fc.weight.data.zero_() + self.fc.bias.data.copy_( + torch.tensor( + [0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float + ) + ) + + def forward(self, x): + # squeeze operation (global average pooling) + x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1)) + # predict transformation parameters + theta = torch.tanh(self.fc(x)) + theta = theta.view(-1, 4, 2) + return theta + + +class HarmAttn(nn.Module): + """Harmonious Attention (Sec. 3.1)""" + + def __init__(self, in_channels): + super(HarmAttn, self).__init__() + self.soft_attn = SoftAttn(in_channels) + self.hard_attn = HardAttn(in_channels) + + def forward(self, x): + y_soft_attn = self.soft_attn(x) + theta = self.hard_attn(x) + return y_soft_attn, theta + + +class HACNN(nn.Module): + """Harmonious Attention Convolutional Neural Network. + + Reference: + Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018. + + Public keys: + - ``hacnn``: HACNN. + """ + + # Args: + # num_classes (int): number of classes to predict + # nchannels (list): number of channels AFTER concatenation + # feat_dim (int): feature dimension for a single stream + # learn_region (bool): whether to learn region features (i.e. 
local branch) + + def __init__( + self, + num_classes, + loss='softmax', + nchannels=[128, 256, 384], + feat_dim=512, + learn_region=True, + use_gpu=True, + **kwargs + ): + super(HACNN, self).__init__() + self.loss = loss + self.learn_region = learn_region + self.use_gpu = use_gpu + + self.conv = ConvBlock(3, 32, 3, s=2, p=1) + + # Construct Inception + HarmAttn blocks + # ============== Block 1 ============== + self.inception1 = nn.Sequential( + InceptionA(32, nchannels[0]), + InceptionB(nchannels[0], nchannels[0]), + ) + self.ha1 = HarmAttn(nchannels[0]) + + # ============== Block 2 ============== + self.inception2 = nn.Sequential( + InceptionA(nchannels[0], nchannels[1]), + InceptionB(nchannels[1], nchannels[1]), + ) + self.ha2 = HarmAttn(nchannels[1]) + + # ============== Block 3 ============== + self.inception3 = nn.Sequential( + InceptionA(nchannels[1], nchannels[2]), + InceptionB(nchannels[2], nchannels[2]), + ) + self.ha3 = HarmAttn(nchannels[2]) + + self.fc_global = nn.Sequential( + nn.Linear(nchannels[2], feat_dim), + nn.BatchNorm1d(feat_dim), + nn.ReLU(), + ) + self.classifier_global = nn.Linear(feat_dim, num_classes) + + if self.learn_region: + self.init_scale_factors() + self.local_conv1 = InceptionB(32, nchannels[0]) + self.local_conv2 = InceptionB(nchannels[0], nchannels[1]) + self.local_conv3 = InceptionB(nchannels[1], nchannels[2]) + self.fc_local = nn.Sequential( + nn.Linear(nchannels[2] * 4, feat_dim), + nn.BatchNorm1d(feat_dim), + nn.ReLU(), + ) + self.classifier_local = nn.Linear(feat_dim, num_classes) + self.feat_dim = feat_dim * 2 + else: + self.feat_dim = feat_dim + + def init_scale_factors(self): + # initialize scale factors (s_w, s_h) for four regions + self.scale_factors = [] + self.scale_factors.append( + torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float) + ) + self.scale_factors.append( + torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float) + ) + self.scale_factors.append( + torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float) + ) + self.scale_factors.append( + torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float) + ) + + def stn(self, x, theta): + """Performs spatial transform + + x: (batch, channel, height, width) + theta: (batch, 2, 3) + """ + grid = F.affine_grid(theta, x.size()) + x = F.grid_sample(x, grid) + return x + + def transform_theta(self, theta_i, region_idx): + """Transforms theta to include (s_w, s_h), resulting in (batch, 2, 3)""" + scale_factors = self.scale_factors[region_idx] + theta = torch.zeros(theta_i.size(0), 2, 3) + theta[:, :, :2] = scale_factors + theta[:, :, -1] = theta_i + if self.use_gpu: + theta = theta.cuda() + return theta + + def forward(self, x): + assert x.size(2) == 160 and x.size(3) == 64, \ + 'Input size does not match, expected (160, 64) but got ({}, {})'.format(x.size(2), x.size(3)) + x = self.conv(x) + + # ============== Block 1 ============== + # global branch + x1 = self.inception1(x) + x1_attn, x1_theta = self.ha1(x1) + x1_out = x1 * x1_attn + # local branch + if self.learn_region: + x1_local_list = [] + for region_idx in range(4): + x1_theta_i = x1_theta[:, region_idx, :] + x1_theta_i = self.transform_theta(x1_theta_i, region_idx) + x1_trans_i = self.stn(x, x1_theta_i) + x1_trans_i = F.upsample( + x1_trans_i, (24, 28), mode='bilinear', align_corners=True + ) + x1_local_i = self.local_conv1(x1_trans_i) + x1_local_list.append(x1_local_i) + + # ============== Block 2 ============== + # Block 2 + # global branch + x2 = self.inception2(x1_out) + x2_attn, x2_theta = self.ha2(x2) + x2_out = x2 * x2_attn + # local branch 
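+        # (each of the four learned regions is warped with its own STN parameters,
+        # added to the matching regional feature from block 1, then refined by the
+        # block-2 local inception branch)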
+ if self.learn_region: + x2_local_list = [] + for region_idx in range(4): + x2_theta_i = x2_theta[:, region_idx, :] + x2_theta_i = self.transform_theta(x2_theta_i, region_idx) + x2_trans_i = self.stn(x1_out, x2_theta_i) + x2_trans_i = F.upsample( + x2_trans_i, (12, 14), mode='bilinear', align_corners=True + ) + x2_local_i = x2_trans_i + x1_local_list[region_idx] + x2_local_i = self.local_conv2(x2_local_i) + x2_local_list.append(x2_local_i) + + # ============== Block 3 ============== + # Block 3 + # global branch + x3 = self.inception3(x2_out) + x3_attn, x3_theta = self.ha3(x3) + x3_out = x3 * x3_attn + # local branch + if self.learn_region: + x3_local_list = [] + for region_idx in range(4): + x3_theta_i = x3_theta[:, region_idx, :] + x3_theta_i = self.transform_theta(x3_theta_i, region_idx) + x3_trans_i = self.stn(x2_out, x3_theta_i) + x3_trans_i = F.upsample( + x3_trans_i, (6, 7), mode='bilinear', align_corners=True + ) + x3_local_i = x3_trans_i + x2_local_list[region_idx] + x3_local_i = self.local_conv3(x3_local_i) + x3_local_list.append(x3_local_i) + + # ============== Feature generation ============== + # global branch + x_global = F.avg_pool2d(x3_out, + x3_out.size()[2:] + ).view(x3_out.size(0), x3_out.size(1)) + x_global = self.fc_global(x_global) + # local branch + if self.learn_region: + x_local_list = [] + for region_idx in range(4): + x_local_i = x3_local_list[region_idx] + x_local_i = F.avg_pool2d(x_local_i, + x_local_i.size()[2:] + ).view(x_local_i.size(0), -1) + x_local_list.append(x_local_i) + x_local = torch.cat(x_local_list, 1) + x_local = self.fc_local(x_local) + + if not self.training: + # l2 normalization before concatenation + if self.learn_region: + x_global = x_global / x_global.norm(p=2, dim=1, keepdim=True) + x_local = x_local / x_local.norm(p=2, dim=1, keepdim=True) + return torch.cat([x_global, x_local], 1) + else: + return x_global + + prelogits_global = self.classifier_global(x_global) + if self.learn_region: + prelogits_local = self.classifier_local(x_local) + + if self.loss == 'softmax': + if self.learn_region: + return (prelogits_global, prelogits_local) + else: + return prelogits_global + + elif self.loss == 'triplet': + if self.learn_region: + return (prelogits_global, prelogits_local), (x_global, x_local) + else: + return prelogits_global, x_global + + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) diff --git a/feeder/trackers/strongsort/deep/models/inceptionresnetv2.py b/feeder/trackers/strongsort/deep/models/inceptionresnetv2.py new file mode 100644 index 0000000..03e4034 --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/inceptionresnetv2.py @@ -0,0 +1,361 @@ +""" +Code imported from https://github.com/Cadene/pretrained-models.pytorch +""" +from __future__ import division, absolute_import +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + +__all__ = ['inceptionresnetv2'] + +pretrained_settings = { + 'inceptionresnetv2': { + 'imagenet': { + 'url': + 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth', + 'input_space': 'RGB', + 'input_size': [3, 299, 299], + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1000 + }, + 'imagenet+background': { + 'url': + 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth', + 'input_space': 'RGB', + 'input_size': [3, 299, 299], + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1001 + } + } +} + + +class BasicConv2d(nn.Module): + + 
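+    # Conv2d -> BatchNorm2d -> ReLU helper reused by every Inception-ResNet block below.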
def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=False + ) # verify bias false + self.bn = nn.BatchNorm2d( + out_planes, + eps=0.001, # value found in tensorflow + momentum=0.1, # default pytorch value + affine=True + ) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Mixed_5b(nn.Module): + + def __init__(self): + super(Mixed_5b, self).__init__() + + self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(192, 48, kernel_size=1, stride=1), + BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(192, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), + BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(192, 64, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block35(nn.Module): + + def __init__(self, scale=1.0): + super(Block35, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(320, 32, kernel_size=1, stride=1), + BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(320, 32, kernel_size=1, stride=1), + BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1), + BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1) + ) + + self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.relu(out) + return out + + +class Mixed_6a(nn.Module): + + def __init__(self): + super(Mixed_6a, self).__init__() + + self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2) + + self.branch1 = nn.Sequential( + BasicConv2d(320, 256, kernel_size=1, stride=1), + BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), + BasicConv2d(256, 384, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class Block17(nn.Module): + + def __init__(self, scale=1.0): + super(Block17, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(1088, 128, kernel_size=1, stride=1), + BasicConv2d( + 128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3) + ), + BasicConv2d( + 160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0) + ) + ) + + self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.relu(out) + return out + + +class Mixed_7a(nn.Module): + + def __init__(self): + 
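+        # Reduction block: three strided convolutional branches plus a max-pooling
+        # branch, concatenated along channels (1088 -> 2080) while halving the
+        # spatial resolution.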
super(Mixed_7a, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 384, kernel_size=3, stride=2) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 288, kernel_size=3, stride=2) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1), + BasicConv2d(288, 320, kernel_size=3, stride=2) + ) + + self.branch3 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block8(nn.Module): + + def __init__(self, scale=1.0, noReLU=False): + super(Block8, self).__init__() + + self.scale = scale + self.noReLU = noReLU + + self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(2080, 192, kernel_size=1, stride=1), + BasicConv2d( + 192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1) + ), + BasicConv2d( + 224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0) + ) + ) + + self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) + if not self.noReLU: + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + if not self.noReLU: + out = self.relu(out) + return out + + +# ---------------- +# Model Definition +# ---------------- +class InceptionResNetV2(nn.Module): + """Inception-ResNet-V2. + + Reference: + Szegedy et al. Inception-v4, Inception-ResNet and the Impact of Residual + Connections on Learning. AAAI 2017. + + Public keys: + - ``inceptionresnetv2``: Inception-ResNet-V2. 
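+    The ImageNet weights referenced in ``pretrained_settings`` above were trained on
+    299 x 299 RGB inputs normalised with a per-channel mean and std of 0.5.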
+ """ + + def __init__(self, num_classes, loss='softmax', **kwargs): + super(InceptionResNetV2, self).__init__() + self.loss = loss + + # Modules + self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) + self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) + self.conv2d_2b = BasicConv2d( + 32, 64, kernel_size=3, stride=1, padding=1 + ) + self.maxpool_3a = nn.MaxPool2d(3, stride=2) + self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) + self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) + self.maxpool_5a = nn.MaxPool2d(3, stride=2) + self.mixed_5b = Mixed_5b() + self.repeat = nn.Sequential( + Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17), + Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17), + Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17), + Block35(scale=0.17) + ) + self.mixed_6a = Mixed_6a() + self.repeat_1 = nn.Sequential( + Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10), + Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10), + Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10), + Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10), + Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10), + Block17(scale=0.10), Block17(scale=0.10), Block17(scale=0.10), + Block17(scale=0.10), Block17(scale=0.10) + ) + self.mixed_7a = Mixed_7a() + self.repeat_2 = nn.Sequential( + Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20), + Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20), + Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20) + ) + + self.block8 = Block8(noReLU=True) + self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(1536, num_classes) + + def load_imagenet_weights(self): + settings = pretrained_settings['inceptionresnetv2']['imagenet'] + pretrain_dict = model_zoo.load_url(settings['url']) + model_dict = self.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + self.load_state_dict(model_dict) + + def featuremaps(self, x): + x = self.conv2d_1a(x) + x = self.conv2d_2a(x) + x = self.conv2d_2b(x) + x = self.maxpool_3a(x) + x = self.conv2d_3b(x) + x = self.conv2d_4a(x) + x = self.maxpool_5a(x) + x = self.mixed_5b(x) + x = self.repeat(x) + x = self.mixed_6a(x) + x = self.repeat_1(x) + x = self.mixed_7a(x) + x = self.repeat_2(x) + x = self.block8(x) + x = self.conv2d_7b(x) + return x + + def forward(self, x): + f = self.featuremaps(x) + v = self.global_avgpool(f) + v = v.view(v.size(0), -1) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError('Unsupported loss: {}'.format(self.loss)) + + +def inceptionresnetv2(num_classes, loss='softmax', pretrained=True, **kwargs): + model = InceptionResNetV2(num_classes=num_classes, loss=loss, **kwargs) + if pretrained: + model.load_imagenet_weights() + return model diff --git a/feeder/trackers/strongsort/deep/models/inceptionv4.py b/feeder/trackers/strongsort/deep/models/inceptionv4.py new file mode 100644 index 0000000..b14916f --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/inceptionv4.py @@ -0,0 +1,381 @@ +from __future__ import division, absolute_import +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + +__all__ = ['inceptionv4'] 
+""" +Code imported from https://github.com/Cadene/pretrained-models.pytorch +""" + +pretrained_settings = { + 'inceptionv4': { + 'imagenet': { + 'url': + 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth', + 'input_space': 'RGB', + 'input_size': [3, 299, 299], + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1000 + }, + 'imagenet+background': { + 'url': + 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth', + 'input_space': 'RGB', + 'input_size': [3, 299, 299], + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1001 + } + } +} + + +class BasicConv2d(nn.Module): + + def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=False + ) # verify bias false + self.bn = nn.BatchNorm2d( + out_planes, + eps=0.001, # value found in tensorflow + momentum=0.1, # default pytorch value + affine=True + ) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Mixed_3a(nn.Module): + + def __init__(self): + super(Mixed_3a, self).__init__() + self.maxpool = nn.MaxPool2d(3, stride=2) + self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2) + + def forward(self, x): + x0 = self.maxpool(x) + x1 = self.conv(x) + out = torch.cat((x0, x1), 1) + return out + + +class Mixed_4a(nn.Module): + + def __init__(self): + super(Mixed_4a, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(160, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(160, 64, kernel_size=1, stride=1), + BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(64, 96, kernel_size=(3, 3), stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + return out + + +class Mixed_5a(nn.Module): + + def __init__(self): + super(Mixed_5a, self).__init__() + self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2) + self.maxpool = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.conv(x) + x1 = self.maxpool(x) + out = torch.cat((x0, x1), 1) + return out + + +class Inception_A(nn.Module): + + def __init__(self): + super(Inception_A, self).__init__() + self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(384, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(384, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), + BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(384, 96, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Reduction_A(nn.Module): + + def __init__(self): + super(Reduction_A, self).__init__() + self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2) + + self.branch1 = nn.Sequential( + BasicConv2d(384, 192, kernel_size=1, stride=1), + 
BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1), + BasicConv2d(224, 256, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class Inception_B(nn.Module): + + def __init__(self): + super(Inception_B, self).__init__() + self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d( + 192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3) + ), + BasicConv2d( + 224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0) + ) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d( + 192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0) + ), + BasicConv2d( + 192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3) + ), + BasicConv2d( + 224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0) + ), + BasicConv2d( + 224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3) + ) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(1024, 128, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Reduction_B(nn.Module): + + def __init__(self): + super(Reduction_B, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 192, kernel_size=3, stride=2) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(1024, 256, kernel_size=1, stride=1), + BasicConv2d( + 256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3) + ), + BasicConv2d( + 256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0) + ), BasicConv2d(320, 320, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class Inception_C(nn.Module): + + def __init__(self): + super(Inception_C, self).__init__() + + self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) + + self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) + self.branch1_1a = BasicConv2d( + 384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1) + ) + self.branch1_1b = BasicConv2d( + 384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0) + ) + + self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) + self.branch2_1 = BasicConv2d( + 384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0) + ) + self.branch2_2 = BasicConv2d( + 448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1) + ) + self.branch2_3a = BasicConv2d( + 512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1) + ) + self.branch2_3b = BasicConv2d( + 512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(1536, 256, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + + x1_0 = self.branch1_0(x) + x1_1a = self.branch1_1a(x1_0) + x1_1b = self.branch1_1b(x1_0) + x1 = torch.cat((x1_1a, x1_1b), 1) + + x2_0 = self.branch2_0(x) + x2_1 = self.branch2_1(x2_0) + x2_2 = self.branch2_2(x2_1) + x2_3a = self.branch2_3a(x2_2) + x2_3b = self.branch2_3b(x2_2) + x2 = torch.cat((x2_3a, x2_3b), 1) + + x3 = self.branch3(x) + + out = 
torch.cat((x0, x1, x2, x3), 1) + return out + + +class InceptionV4(nn.Module): + """Inception-v4. + + Reference: + Szegedy et al. Inception-v4, Inception-ResNet and the Impact of Residual + Connections on Learning. AAAI 2017. + + Public keys: + - ``inceptionv4``: InceptionV4. + """ + + def __init__(self, num_classes, loss, **kwargs): + super(InceptionV4, self).__init__() + self.loss = loss + + self.features = nn.Sequential( + BasicConv2d(3, 32, kernel_size=3, stride=2), + BasicConv2d(32, 32, kernel_size=3, stride=1), + BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), + Mixed_3a(), + Mixed_4a(), + Mixed_5a(), + Inception_A(), + Inception_A(), + Inception_A(), + Inception_A(), + Reduction_A(), # Mixed_6a + Inception_B(), + Inception_B(), + Inception_B(), + Inception_B(), + Inception_B(), + Inception_B(), + Inception_B(), + Reduction_B(), # Mixed_7a + Inception_C(), + Inception_C(), + Inception_C() + ) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(1536, num_classes) + + def forward(self, x): + f = self.features(x) + v = self.global_avgpool(f) + v = v.view(v.size(0), -1) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError('Unsupported loss: {}'.format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. + """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def inceptionv4(num_classes, loss='softmax', pretrained=True, **kwargs): + model = InceptionV4(num_classes, loss, **kwargs) + if pretrained: + model_url = pretrained_settings['inceptionv4']['imagenet']['url'] + init_pretrained_weights(model, model_url) + return model diff --git a/feeder/trackers/strongsort/deep/models/mlfn.py b/feeder/trackers/strongsort/deep/models/mlfn.py new file mode 100644 index 0000000..ac7e126 --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/mlfn.py @@ -0,0 +1,269 @@ +from __future__ import division, absolute_import +import torch +import torch.utils.model_zoo as model_zoo +from torch import nn +from torch.nn import functional as F + +__all__ = ['mlfn'] + +model_urls = { + # training epoch = 5, top1 = 51.6 + 'imagenet': + 'https://mega.nz/#!YHxAhaxC!yu9E6zWl0x5zscSouTdbZu8gdFFytDdl-RAdD2DEfpk', +} + + +class MLFNBlock(nn.Module): + + def __init__( + self, in_channels, out_channels, stride, fsm_channels, groups=32 + ): + super(MLFNBlock, self).__init__() + self.groups = groups + mid_channels = out_channels // 2 + + # Factor Modules + self.fm_conv1 = nn.Conv2d(in_channels, mid_channels, 1, bias=False) + self.fm_bn1 = nn.BatchNorm2d(mid_channels) + self.fm_conv2 = nn.Conv2d( + mid_channels, + mid_channels, + 3, + stride=stride, + padding=1, + bias=False, + groups=self.groups + ) + self.fm_bn2 = nn.BatchNorm2d(mid_channels) + self.fm_conv3 = nn.Conv2d(mid_channels, out_channels, 1, bias=False) + self.fm_bn3 = nn.BatchNorm2d(out_channels) + + # Factor Selection Module + self.fsm = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, fsm_channels[0], 1), + nn.BatchNorm2d(fsm_channels[0]), + nn.ReLU(inplace=True), + nn.Conv2d(fsm_channels[0], fsm_channels[1], 1), + 
nn.BatchNorm2d(fsm_channels[1]), + nn.ReLU(inplace=True), + nn.Conv2d(fsm_channels[1], self.groups, 1), + nn.BatchNorm2d(self.groups), + nn.Sigmoid(), + ) + + self.downsample = None + if in_channels != out_channels or stride > 1: + self.downsample = nn.Sequential( + nn.Conv2d( + in_channels, out_channels, 1, stride=stride, bias=False + ), + nn.BatchNorm2d(out_channels), + ) + + def forward(self, x): + residual = x + s = self.fsm(x) + + # reduce dimension + x = self.fm_conv1(x) + x = self.fm_bn1(x) + x = F.relu(x, inplace=True) + + # group convolution + x = self.fm_conv2(x) + x = self.fm_bn2(x) + x = F.relu(x, inplace=True) + + # factor selection + b, c = x.size(0), x.size(1) + n = c // self.groups + ss = s.repeat(1, n, 1, 1) # from (b, g, 1, 1) to (b, g*n=c, 1, 1) + ss = ss.view(b, n, self.groups, 1, 1) + ss = ss.permute(0, 2, 1, 3, 4).contiguous() + ss = ss.view(b, c, 1, 1) + x = ss * x + + # recover dimension + x = self.fm_conv3(x) + x = self.fm_bn3(x) + x = F.relu(x, inplace=True) + + if self.downsample is not None: + residual = self.downsample(residual) + + return F.relu(residual + x, inplace=True), s + + +class MLFN(nn.Module): + """Multi-Level Factorisation Net. + + Reference: + Chang et al. Multi-Level Factorisation Net for + Person Re-Identification. CVPR 2018. + + Public keys: + - ``mlfn``: MLFN (Multi-Level Factorisation Net). + """ + + def __init__( + self, + num_classes, + loss='softmax', + groups=32, + channels=[64, 256, 512, 1024, 2048], + embed_dim=1024, + **kwargs + ): + super(MLFN, self).__init__() + self.loss = loss + self.groups = groups + + # first convolutional layer + self.conv1 = nn.Conv2d(3, channels[0], 7, stride=2, padding=3) + self.bn1 = nn.BatchNorm2d(channels[0]) + self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) + + # main body + self.feature = nn.ModuleList( + [ + # layer 1-3 + MLFNBlock(channels[0], channels[1], 1, [128, 64], self.groups), + MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups), + MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups), + # layer 4-7 + MLFNBlock( + channels[1], channels[2], 2, [256, 128], self.groups + ), + MLFNBlock( + channels[2], channels[2], 1, [256, 128], self.groups + ), + MLFNBlock( + channels[2], channels[2], 1, [256, 128], self.groups + ), + MLFNBlock( + channels[2], channels[2], 1, [256, 128], self.groups + ), + # layer 8-13 + MLFNBlock( + channels[2], channels[3], 2, [512, 128], self.groups + ), + MLFNBlock( + channels[3], channels[3], 1, [512, 128], self.groups + ), + MLFNBlock( + channels[3], channels[3], 1, [512, 128], self.groups + ), + MLFNBlock( + channels[3], channels[3], 1, [512, 128], self.groups + ), + MLFNBlock( + channels[3], channels[3], 1, [512, 128], self.groups + ), + MLFNBlock( + channels[3], channels[3], 1, [512, 128], self.groups + ), + # layer 14-16 + MLFNBlock( + channels[3], channels[4], 2, [512, 128], self.groups + ), + MLFNBlock( + channels[4], channels[4], 1, [512, 128], self.groups + ), + MLFNBlock( + channels[4], channels[4], 1, [512, 128], self.groups + ), + ] + ) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + + # projection functions + self.fc_x = nn.Sequential( + nn.Conv2d(channels[4], embed_dim, 1, bias=False), + nn.BatchNorm2d(embed_dim), + nn.ReLU(inplace=True), + ) + self.fc_s = nn.Sequential( + nn.Conv2d(self.groups * 16, embed_dim, 1, bias=False), + nn.BatchNorm2d(embed_dim), + nn.ReLU(inplace=True), + ) + + self.classifier = nn.Linear(embed_dim, num_classes) + + self.init_params() + + def init_params(self): + for m in self.modules(): + if isinstance(m, 
nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x, inplace=True) + x = self.maxpool(x) + + s_hat = [] + for block in self.feature: + x, s = block(x) + s_hat.append(s) + s_hat = torch.cat(s_hat, 1) + + x = self.global_avgpool(x) + x = self.fc_x(x) + s_hat = self.fc_s(s_hat) + + v = (x+s_hat) * 0.5 + v = v.view(v.size(0), -1) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError('Unsupported loss: {}'.format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. + """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def mlfn(num_classes, loss='softmax', pretrained=True, **kwargs): + model = MLFN(num_classes, loss, **kwargs) + if pretrained: + # init_pretrained_weights(model, model_urls['imagenet']) + import warnings + warnings.warn( + 'The imagenet pretrained weights need to be manually downloaded from {}' + .format(model_urls['imagenet']) + ) + return model diff --git a/feeder/trackers/strongsort/deep/models/mobilenetv2.py b/feeder/trackers/strongsort/deep/models/mobilenetv2.py new file mode 100644 index 0000000..c451ef8 --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/mobilenetv2.py @@ -0,0 +1,274 @@ +from __future__ import division, absolute_import +import torch.utils.model_zoo as model_zoo +from torch import nn +from torch.nn import functional as F + +__all__ = ['mobilenetv2_x1_0', 'mobilenetv2_x1_4'] + +model_urls = { + # 1.0: top-1 71.3 + 'mobilenetv2_x1_0': + 'https://mega.nz/#!NKp2wAIA!1NH1pbNzY_M2hVk_hdsxNM1NUOWvvGPHhaNr-fASF6c', + # 1.4: top-1 73.9 + 'mobilenetv2_x1_4': + 'https://mega.nz/#!RGhgEIwS!xN2s2ZdyqI6vQ3EwgmRXLEW3khr9tpXg96G9SUJugGk', +} + + +class ConvBlock(nn.Module): + """Basic convolutional block. + + convolution (bias discarded) + batch normalization + relu6. + + Args: + in_c (int): number of input channels. + out_c (int): number of output channels. + k (int or tuple): kernel size. + s (int or tuple): stride. + p (int or tuple): padding. + g (int): number of blocked connections from input channels + to output channels (default: 1). 
+ """ + + def __init__(self, in_c, out_c, k, s=1, p=0, g=1): + super(ConvBlock, self).__init__() + self.conv = nn.Conv2d( + in_c, out_c, k, stride=s, padding=p, bias=False, groups=g + ) + self.bn = nn.BatchNorm2d(out_c) + + def forward(self, x): + return F.relu6(self.bn(self.conv(x))) + + +class Bottleneck(nn.Module): + + def __init__(self, in_channels, out_channels, expansion_factor, stride=1): + super(Bottleneck, self).__init__() + mid_channels = in_channels * expansion_factor + self.use_residual = stride == 1 and in_channels == out_channels + self.conv1 = ConvBlock(in_channels, mid_channels, 1) + self.dwconv2 = ConvBlock( + mid_channels, mid_channels, 3, stride, 1, g=mid_channels + ) + self.conv3 = nn.Sequential( + nn.Conv2d(mid_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + ) + + def forward(self, x): + m = self.conv1(x) + m = self.dwconv2(m) + m = self.conv3(m) + if self.use_residual: + return x + m + else: + return m + + +class MobileNetV2(nn.Module): + """MobileNetV2. + + Reference: + Sandler et al. MobileNetV2: Inverted Residuals and + Linear Bottlenecks. CVPR 2018. + + Public keys: + - ``mobilenetv2_x1_0``: MobileNetV2 x1.0. + - ``mobilenetv2_x1_4``: MobileNetV2 x1.4. + """ + + def __init__( + self, + num_classes, + width_mult=1, + loss='softmax', + fc_dims=None, + dropout_p=None, + **kwargs + ): + super(MobileNetV2, self).__init__() + self.loss = loss + self.in_channels = int(32 * width_mult) + self.feature_dim = int(1280 * width_mult) if width_mult > 1 else 1280 + + # construct layers + self.conv1 = ConvBlock(3, self.in_channels, 3, s=2, p=1) + self.conv2 = self._make_layer( + Bottleneck, 1, int(16 * width_mult), 1, 1 + ) + self.conv3 = self._make_layer( + Bottleneck, 6, int(24 * width_mult), 2, 2 + ) + self.conv4 = self._make_layer( + Bottleneck, 6, int(32 * width_mult), 3, 2 + ) + self.conv5 = self._make_layer( + Bottleneck, 6, int(64 * width_mult), 4, 2 + ) + self.conv6 = self._make_layer( + Bottleneck, 6, int(96 * width_mult), 3, 1 + ) + self.conv7 = self._make_layer( + Bottleneck, 6, int(160 * width_mult), 3, 2 + ) + self.conv8 = self._make_layer( + Bottleneck, 6, int(320 * width_mult), 1, 1 + ) + self.conv9 = ConvBlock(self.in_channels, self.feature_dim, 1) + + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = self._construct_fc_layer( + fc_dims, self.feature_dim, dropout_p + ) + self.classifier = nn.Linear(self.feature_dim, num_classes) + + self._init_params() + + def _make_layer(self, block, t, c, n, s): + # t: expansion factor + # c: output channels + # n: number of blocks + # s: stride for first layer + layers = [] + layers.append(block(self.in_channels, c, t, s)) + self.in_channels = c + for i in range(1, n): + layers.append(block(self.in_channels, c, t)) + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + """Constructs fully connected layer. 
+ + Args: + fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed + input_dim (int): input dimension + dropout_p (float): dropout probability, if None, dropout is unused + """ + if fc_dims is None: + self.feature_dim = input_dim + return None + + assert isinstance( + fc_dims, (list, tuple) + ), 'fc_dims must be either list or tuple, but got {}'.format( + type(fc_dims) + ) + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.conv5(x) + x = self.conv6(x) + x = self.conv7(x) + x = self.conv8(x) + x = self.conv9(x) + return x + + def forward(self, x): + f = self.featuremaps(x) + v = self.global_avgpool(f) + v = v.view(v.size(0), -1) + + if self.fc is not None: + v = self.fc(v) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. 
+ """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def mobilenetv2_x1_0(num_classes, loss, pretrained=True, **kwargs): + model = MobileNetV2( + num_classes, + loss=loss, + width_mult=1, + fc_dims=None, + dropout_p=None, + **kwargs + ) + if pretrained: + # init_pretrained_weights(model, model_urls['mobilenetv2_x1_0']) + import warnings + warnings.warn( + 'The imagenet pretrained weights need to be manually downloaded from {}' + .format(model_urls['mobilenetv2_x1_0']) + ) + return model + + +def mobilenetv2_x1_4(num_classes, loss, pretrained=True, **kwargs): + model = MobileNetV2( + num_classes, + loss=loss, + width_mult=1.4, + fc_dims=None, + dropout_p=None, + **kwargs + ) + if pretrained: + # init_pretrained_weights(model, model_urls['mobilenetv2_x1_4']) + import warnings + warnings.warn( + 'The imagenet pretrained weights need to be manually downloaded from {}' + .format(model_urls['mobilenetv2_x1_4']) + ) + return model diff --git a/feeder/trackers/strongsort/deep/models/mudeep.py b/feeder/trackers/strongsort/deep/models/mudeep.py new file mode 100644 index 0000000..ddbca67 --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/mudeep.py @@ -0,0 +1,206 @@ +from __future__ import division, absolute_import +import torch +from torch import nn +from torch.nn import functional as F + +__all__ = ['MuDeep'] + + +class ConvBlock(nn.Module): + """Basic convolutional block. + + convolution + batch normalization + relu. + + Args: + in_c (int): number of input channels. + out_c (int): number of output channels. + k (int or tuple): kernel size. + s (int or tuple): stride. + p (int or tuple): padding. 
+ """ + + def __init__(self, in_c, out_c, k, s, p): + super(ConvBlock, self).__init__() + self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p) + self.bn = nn.BatchNorm2d(out_c) + + def forward(self, x): + return F.relu(self.bn(self.conv(x))) + + +class ConvLayers(nn.Module): + """Preprocessing layers.""" + + def __init__(self): + super(ConvLayers, self).__init__() + self.conv1 = ConvBlock(3, 48, k=3, s=1, p=1) + self.conv2 = ConvBlock(48, 96, k=3, s=1, p=1) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.maxpool(x) + return x + + +class MultiScaleA(nn.Module): + """Multi-scale stream layer A (Sec.3.1)""" + + def __init__(self): + super(MultiScaleA, self).__init__() + self.stream1 = nn.Sequential( + ConvBlock(96, 96, k=1, s=1, p=0), + ConvBlock(96, 24, k=3, s=1, p=1), + ) + self.stream2 = nn.Sequential( + nn.AvgPool2d(kernel_size=3, stride=1, padding=1), + ConvBlock(96, 24, k=1, s=1, p=0), + ) + self.stream3 = ConvBlock(96, 24, k=1, s=1, p=0) + self.stream4 = nn.Sequential( + ConvBlock(96, 16, k=1, s=1, p=0), + ConvBlock(16, 24, k=3, s=1, p=1), + ConvBlock(24, 24, k=3, s=1, p=1), + ) + + def forward(self, x): + s1 = self.stream1(x) + s2 = self.stream2(x) + s3 = self.stream3(x) + s4 = self.stream4(x) + y = torch.cat([s1, s2, s3, s4], dim=1) + return y + + +class Reduction(nn.Module): + """Reduction layer (Sec.3.1)""" + + def __init__(self): + super(Reduction, self).__init__() + self.stream1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.stream2 = ConvBlock(96, 96, k=3, s=2, p=1) + self.stream3 = nn.Sequential( + ConvBlock(96, 48, k=1, s=1, p=0), + ConvBlock(48, 56, k=3, s=1, p=1), + ConvBlock(56, 64, k=3, s=2, p=1), + ) + + def forward(self, x): + s1 = self.stream1(x) + s2 = self.stream2(x) + s3 = self.stream3(x) + y = torch.cat([s1, s2, s3], dim=1) + return y + + +class MultiScaleB(nn.Module): + """Multi-scale stream layer B (Sec.3.1)""" + + def __init__(self): + super(MultiScaleB, self).__init__() + self.stream1 = nn.Sequential( + nn.AvgPool2d(kernel_size=3, stride=1, padding=1), + ConvBlock(256, 256, k=1, s=1, p=0), + ) + self.stream2 = nn.Sequential( + ConvBlock(256, 64, k=1, s=1, p=0), + ConvBlock(64, 128, k=(1, 3), s=1, p=(0, 1)), + ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)), + ) + self.stream3 = ConvBlock(256, 256, k=1, s=1, p=0) + self.stream4 = nn.Sequential( + ConvBlock(256, 64, k=1, s=1, p=0), + ConvBlock(64, 64, k=(1, 3), s=1, p=(0, 1)), + ConvBlock(64, 128, k=(3, 1), s=1, p=(1, 0)), + ConvBlock(128, 128, k=(1, 3), s=1, p=(0, 1)), + ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)), + ) + + def forward(self, x): + s1 = self.stream1(x) + s2 = self.stream2(x) + s3 = self.stream3(x) + s4 = self.stream4(x) + return s1, s2, s3, s4 + + +class Fusion(nn.Module): + """Saliency-based learning fusion layer (Sec.3.2)""" + + def __init__(self): + super(Fusion, self).__init__() + self.a1 = nn.Parameter(torch.rand(1, 256, 1, 1)) + self.a2 = nn.Parameter(torch.rand(1, 256, 1, 1)) + self.a3 = nn.Parameter(torch.rand(1, 256, 1, 1)) + self.a4 = nn.Parameter(torch.rand(1, 256, 1, 1)) + + # We add an average pooling layer to reduce the spatial dimension + # of feature maps, which differs from the original paper. 
+ self.avgpool = nn.AvgPool2d(kernel_size=4, stride=4, padding=0) + + def forward(self, x1, x2, x3, x4): + s1 = self.a1.expand_as(x1) * x1 + s2 = self.a2.expand_as(x2) * x2 + s3 = self.a3.expand_as(x3) * x3 + s4 = self.a4.expand_as(x4) * x4 + y = self.avgpool(s1 + s2 + s3 + s4) + return y + + +class MuDeep(nn.Module): + """Multiscale deep neural network. + + Reference: + Qian et al. Multi-scale Deep Learning Architectures + for Person Re-identification. ICCV 2017. + + Public keys: + - ``mudeep``: Multiscale deep neural network. + """ + + def __init__(self, num_classes, loss='softmax', **kwargs): + super(MuDeep, self).__init__() + self.loss = loss + + self.block1 = ConvLayers() + self.block2 = MultiScaleA() + self.block3 = Reduction() + self.block4 = MultiScaleB() + self.block5 = Fusion() + + # Due to this fully connected layer, input image has to be fixed + # in shape, i.e. (3, 256, 128), such that the last convolutional feature + # maps are of shape (256, 16, 8). If input shape is changed, + # the input dimension of this layer has to be changed accordingly. + self.fc = nn.Sequential( + nn.Linear(256 * 16 * 8, 4096), + nn.BatchNorm1d(4096), + nn.ReLU(), + ) + self.classifier = nn.Linear(4096, num_classes) + self.feat_dim = 4096 + + def featuremaps(self, x): + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + x = self.block5(*x) + return x + + def forward(self, x): + x = self.featuremaps(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + y = self.classifier(x) + + if not self.training: + return x + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, x + else: + raise KeyError('Unsupported loss: {}'.format(self.loss)) diff --git a/feeder/trackers/strongsort/deep/models/nasnet.py b/feeder/trackers/strongsort/deep/models/nasnet.py new file mode 100644 index 0000000..b1f31de --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/nasnet.py @@ -0,0 +1,1131 @@ +from __future__ import division, absolute_import +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo + +__all__ = ['nasnetamobile'] +""" +NASNet Mobile +Thanks to Anastasiia (https://github.com/DagnyT) for the great help, support and motivation! 
+ + +------------------------------------------------------------------------------------ + Architecture | Top-1 Acc | Top-5 Acc | Multiply-Adds | Params (M) +------------------------------------------------------------------------------------ +| NASNet-A (4 @ 1056) | 74.08% | 91.74% | 564 M | 5.3 | +------------------------------------------------------------------------------------ +# References: + - [Learning Transferable Architectures for Scalable Image Recognition] + (https://arxiv.org/abs/1707.07012) +""" +""" +Code imported from https://github.com/Cadene/pretrained-models.pytorch +""" + +pretrained_settings = { + 'nasnetamobile': { + 'imagenet': { + # 'url': 'https://github.com/veronikayurchuk/pretrained-models.pytorch/releases/download/v1.0/nasnetmobile-7e03cead.pth.tar', + 'url': + 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetamobile-7e03cead.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], # resize 256 + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1000 + }, + # 'imagenet+background': { + # # 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth', + # 'input_space': 'RGB', + # 'input_size': [3, 224, 224], # resize 256 + # 'input_range': [0, 1], + # 'mean': [0.5, 0.5, 0.5], + # 'std': [0.5, 0.5, 0.5], + # 'num_classes': 1001 + # } + } +} + + +class MaxPoolPad(nn.Module): + + def __init__(self): + super(MaxPoolPad, self).__init__() + self.pad = nn.ZeroPad2d((1, 0, 1, 0)) + self.pool = nn.MaxPool2d(3, stride=2, padding=1) + + def forward(self, x): + x = self.pad(x) + x = self.pool(x) + x = x[:, :, 1:, 1:].contiguous() + return x + + +class AvgPoolPad(nn.Module): + + def __init__(self, stride=2, padding=1): + super(AvgPoolPad, self).__init__() + self.pad = nn.ZeroPad2d((1, 0, 1, 0)) + self.pool = nn.AvgPool2d( + 3, stride=stride, padding=padding, count_include_pad=False + ) + + def forward(self, x): + x = self.pad(x) + x = self.pool(x) + x = x[:, :, 1:, 1:].contiguous() + return x + + +class SeparableConv2d(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + dw_kernel, + dw_stride, + dw_padding, + bias=False + ): + super(SeparableConv2d, self).__init__() + self.depthwise_conv2d = nn.Conv2d( + in_channels, + in_channels, + dw_kernel, + stride=dw_stride, + padding=dw_padding, + bias=bias, + groups=in_channels + ) + self.pointwise_conv2d = nn.Conv2d( + in_channels, out_channels, 1, stride=1, bias=bias + ) + + def forward(self, x): + x = self.depthwise_conv2d(x) + x = self.pointwise_conv2d(x) + return x + + +class BranchSeparables(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + name=None, + bias=False + ): + super(BranchSeparables, self).__init__() + self.relu = nn.ReLU() + self.separable_1 = SeparableConv2d( + in_channels, in_channels, kernel_size, stride, padding, bias=bias + ) + self.bn_sep_1 = nn.BatchNorm2d( + in_channels, eps=0.001, momentum=0.1, affine=True + ) + self.relu1 = nn.ReLU() + self.separable_2 = SeparableConv2d( + in_channels, out_channels, kernel_size, 1, padding, bias=bias + ) + self.bn_sep_2 = nn.BatchNorm2d( + out_channels, eps=0.001, momentum=0.1, affine=True + ) + self.name = name + + def forward(self, x): + x = self.relu(x) + if self.name == 'specific': + x = nn.ZeroPad2d((1, 0, 1, 0))(x) + x = self.separable_1(x) + if self.name == 'specific': + x = x[:, :, 1:, 1:].contiguous() + + x = self.bn_sep_1(x) + x = self.relu1(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class 
BranchSeparablesStem(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias=False + ): + super(BranchSeparablesStem, self).__init__() + self.relu = nn.ReLU() + self.separable_1 = SeparableConv2d( + in_channels, out_channels, kernel_size, stride, padding, bias=bias + ) + self.bn_sep_1 = nn.BatchNorm2d( + out_channels, eps=0.001, momentum=0.1, affine=True + ) + self.relu1 = nn.ReLU() + self.separable_2 = SeparableConv2d( + out_channels, out_channels, kernel_size, 1, padding, bias=bias + ) + self.bn_sep_2 = nn.BatchNorm2d( + out_channels, eps=0.001, momentum=0.1, affine=True + ) + + def forward(self, x): + x = self.relu(x) + x = self.separable_1(x) + x = self.bn_sep_1(x) + x = self.relu1(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class BranchSeparablesReduction(BranchSeparables): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + z_padding=1, + bias=False + ): + BranchSeparables.__init__( + self, in_channels, out_channels, kernel_size, stride, padding, bias + ) + self.padding = nn.ZeroPad2d((z_padding, 0, z_padding, 0)) + + def forward(self, x): + x = self.relu(x) + x = self.padding(x) + x = self.separable_1(x) + x = x[:, :, 1:, 1:].contiguous() + x = self.bn_sep_1(x) + x = self.relu1(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class CellStem0(nn.Module): + + def __init__(self, stem_filters, num_filters=42): + super(CellStem0, self).__init__() + self.num_filters = num_filters + self.stem_filters = stem_filters + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module( + 'conv', + nn.Conv2d( + self.stem_filters, self.num_filters, 1, stride=1, bias=False + ) + ) + self.conv_1x1.add_module( + 'bn', + nn.BatchNorm2d( + self.num_filters, eps=0.001, momentum=0.1, affine=True + ) + ) + + self.comb_iter_0_left = BranchSeparables( + self.num_filters, self.num_filters, 5, 2, 2 + ) + self.comb_iter_0_right = BranchSeparablesStem( + self.stem_filters, self.num_filters, 7, 2, 3, bias=False + ) + + self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_1_right = BranchSeparablesStem( + self.stem_filters, self.num_filters, 7, 2, 3, bias=False + ) + + self.comb_iter_2_left = nn.AvgPool2d( + 3, stride=2, padding=1, count_include_pad=False + ) + self.comb_iter_2_right = BranchSeparablesStem( + self.stem_filters, self.num_filters, 5, 2, 2, bias=False + ) + + self.comb_iter_3_right = nn.AvgPool2d( + 3, stride=1, padding=1, count_include_pad=False + ) + + self.comb_iter_4_left = BranchSeparables( + self.num_filters, self.num_filters, 3, 1, 1, bias=False + ) + self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + + def forward(self, x): + x1 = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x1) + x_comb_iter_0_right = self.comb_iter_0_right(x) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x1) + x_comb_iter_1_right = self.comb_iter_1_right(x) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x1) + x_comb_iter_2_right = self.comb_iter_2_right(x) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x1) + x_comb_iter_4 = 
x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat( + [x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1 + ) + return x_out + + +class CellStem1(nn.Module): + + def __init__(self, stem_filters, num_filters): + super(CellStem1, self).__init__() + self.num_filters = num_filters + self.stem_filters = stem_filters + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module( + 'conv', + nn.Conv2d( + 2 * self.num_filters, + self.num_filters, + 1, + stride=1, + bias=False + ) + ) + self.conv_1x1.add_module( + 'bn', + nn.BatchNorm2d( + self.num_filters, eps=0.001, momentum=0.1, affine=True + ) + ) + + self.relu = nn.ReLU() + self.path_1 = nn.Sequential() + self.path_1.add_module( + 'avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False) + ) + self.path_1.add_module( + 'conv', + nn.Conv2d( + self.stem_filters, + self.num_filters // 2, + 1, + stride=1, + bias=False + ) + ) + self.path_2 = nn.ModuleList() + self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1))) + self.path_2.add_module( + 'avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False) + ) + self.path_2.add_module( + 'conv', + nn.Conv2d( + self.stem_filters, + self.num_filters // 2, + 1, + stride=1, + bias=False + ) + ) + + self.final_path_bn = nn.BatchNorm2d( + self.num_filters, eps=0.001, momentum=0.1, affine=True + ) + + self.comb_iter_0_left = BranchSeparables( + self.num_filters, + self.num_filters, + 5, + 2, + 2, + name='specific', + bias=False + ) + self.comb_iter_0_right = BranchSeparables( + self.num_filters, + self.num_filters, + 7, + 2, + 3, + name='specific', + bias=False + ) + + # self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_1_left = MaxPoolPad() + self.comb_iter_1_right = BranchSeparables( + self.num_filters, + self.num_filters, + 7, + 2, + 3, + name='specific', + bias=False + ) + + # self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) + self.comb_iter_2_left = AvgPoolPad() + self.comb_iter_2_right = BranchSeparables( + self.num_filters, + self.num_filters, + 5, + 2, + 2, + name='specific', + bias=False + ) + + self.comb_iter_3_right = nn.AvgPool2d( + 3, stride=1, padding=1, count_include_pad=False + ) + + self.comb_iter_4_left = BranchSeparables( + self.num_filters, + self.num_filters, + 3, + 1, + 1, + name='specific', + bias=False + ) + # self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_4_right = MaxPoolPad() + + def forward(self, x_conv0, x_stem_0): + x_left = self.conv_1x1(x_stem_0) + + x_relu = self.relu(x_conv0) + # path 1 + x_path1 = self.path_1(x_relu) + # path 2 + x_path2 = self.path_2.pad(x_relu) + x_path2 = x_path2[:, :, 1:, 1:] + x_path2 = self.path_2.avgpool(x_path2) + x_path2 = self.path_2.conv(x_path2) + # final path + x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + + x_comb_iter_0_left = self.comb_iter_0_left(x_left) + x_comb_iter_0_right = self.comb_iter_0_right(x_right) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_right) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_left) + x_comb_iter_2_right = self.comb_iter_2_right(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = 
self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_left) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat( + [x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1 + ) + return x_out + + +class FirstCell(nn.Module): + + def __init__( + self, in_channels_left, out_channels_left, in_channels_right, + out_channels_right + ): + super(FirstCell, self).__init__() + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module( + 'conv', + nn.Conv2d( + in_channels_right, out_channels_right, 1, stride=1, bias=False + ) + ) + self.conv_1x1.add_module( + 'bn', + nn.BatchNorm2d( + out_channels_right, eps=0.001, momentum=0.1, affine=True + ) + ) + + self.relu = nn.ReLU() + self.path_1 = nn.Sequential() + self.path_1.add_module( + 'avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False) + ) + self.path_1.add_module( + 'conv', + nn.Conv2d( + in_channels_left, out_channels_left, 1, stride=1, bias=False + ) + ) + self.path_2 = nn.ModuleList() + self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1))) + self.path_2.add_module( + 'avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False) + ) + self.path_2.add_module( + 'conv', + nn.Conv2d( + in_channels_left, out_channels_left, 1, stride=1, bias=False + ) + ) + + self.final_path_bn = nn.BatchNorm2d( + out_channels_left * 2, eps=0.001, momentum=0.1, affine=True + ) + + self.comb_iter_0_left = BranchSeparables( + out_channels_right, out_channels_right, 5, 1, 2, bias=False + ) + self.comb_iter_0_right = BranchSeparables( + out_channels_right, out_channels_right, 3, 1, 1, bias=False + ) + + self.comb_iter_1_left = BranchSeparables( + out_channels_right, out_channels_right, 5, 1, 2, bias=False + ) + self.comb_iter_1_right = BranchSeparables( + out_channels_right, out_channels_right, 3, 1, 1, bias=False + ) + + self.comb_iter_2_left = nn.AvgPool2d( + 3, stride=1, padding=1, count_include_pad=False + ) + + self.comb_iter_3_left = nn.AvgPool2d( + 3, stride=1, padding=1, count_include_pad=False + ) + self.comb_iter_3_right = nn.AvgPool2d( + 3, stride=1, padding=1, count_include_pad=False + ) + + self.comb_iter_4_left = BranchSeparables( + out_channels_right, out_channels_right, 3, 1, 1, bias=False + ) + + def forward(self, x, x_prev): + x_relu = self.relu(x_prev) + # path 1 + x_path1 = self.path_1(x_relu) + # path 2 + x_path2 = self.path_2.pad(x_relu) + x_path2 = x_path2[:, :, 1:, 1:] + x_path2 = self.path_2.avgpool(x_path2) + x_path2 = self.path_2.conv(x_path2) + # final path + x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = torch.cat( + [ + x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, + x_comb_iter_3, x_comb_iter_4 + ], 1 + ) + return x_out + + +class NormalCell(nn.Module): + 
+ def __init__( + self, in_channels_left, out_channels_left, in_channels_right, + out_channels_right + ): + super(NormalCell, self).__init__() + self.conv_prev_1x1 = nn.Sequential() + self.conv_prev_1x1.add_module('relu', nn.ReLU()) + self.conv_prev_1x1.add_module( + 'conv', + nn.Conv2d( + in_channels_left, out_channels_left, 1, stride=1, bias=False + ) + ) + self.conv_prev_1x1.add_module( + 'bn', + nn.BatchNorm2d( + out_channels_left, eps=0.001, momentum=0.1, affine=True + ) + ) + + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module( + 'conv', + nn.Conv2d( + in_channels_right, out_channels_right, 1, stride=1, bias=False + ) + ) + self.conv_1x1.add_module( + 'bn', + nn.BatchNorm2d( + out_channels_right, eps=0.001, momentum=0.1, affine=True + ) + ) + + self.comb_iter_0_left = BranchSeparables( + out_channels_right, out_channels_right, 5, 1, 2, bias=False + ) + self.comb_iter_0_right = BranchSeparables( + out_channels_left, out_channels_left, 3, 1, 1, bias=False + ) + + self.comb_iter_1_left = BranchSeparables( + out_channels_left, out_channels_left, 5, 1, 2, bias=False + ) + self.comb_iter_1_right = BranchSeparables( + out_channels_left, out_channels_left, 3, 1, 1, bias=False + ) + + self.comb_iter_2_left = nn.AvgPool2d( + 3, stride=1, padding=1, count_include_pad=False + ) + + self.comb_iter_3_left = nn.AvgPool2d( + 3, stride=1, padding=1, count_include_pad=False + ) + self.comb_iter_3_right = nn.AvgPool2d( + 3, stride=1, padding=1, count_include_pad=False + ) + + self.comb_iter_4_left = BranchSeparables( + out_channels_right, out_channels_right, 3, 1, 1, bias=False + ) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = torch.cat( + [ + x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, + x_comb_iter_3, x_comb_iter_4 + ], 1 + ) + return x_out + + +class ReductionCell0(nn.Module): + + def __init__( + self, in_channels_left, out_channels_left, in_channels_right, + out_channels_right + ): + super(ReductionCell0, self).__init__() + self.conv_prev_1x1 = nn.Sequential() + self.conv_prev_1x1.add_module('relu', nn.ReLU()) + self.conv_prev_1x1.add_module( + 'conv', + nn.Conv2d( + in_channels_left, out_channels_left, 1, stride=1, bias=False + ) + ) + self.conv_prev_1x1.add_module( + 'bn', + nn.BatchNorm2d( + out_channels_left, eps=0.001, momentum=0.1, affine=True + ) + ) + + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module( + 'conv', + nn.Conv2d( + in_channels_right, out_channels_right, 1, stride=1, bias=False + ) + ) + self.conv_1x1.add_module( + 'bn', + nn.BatchNorm2d( + out_channels_right, eps=0.001, momentum=0.1, affine=True + ) + ) + + self.comb_iter_0_left = BranchSeparablesReduction( + out_channels_right, 
out_channels_right, 5, 2, 2, bias=False + ) + self.comb_iter_0_right = BranchSeparablesReduction( + out_channels_right, out_channels_right, 7, 2, 3, bias=False + ) + + self.comb_iter_1_left = MaxPoolPad() + self.comb_iter_1_right = BranchSeparablesReduction( + out_channels_right, out_channels_right, 7, 2, 3, bias=False + ) + + self.comb_iter_2_left = AvgPoolPad() + self.comb_iter_2_right = BranchSeparablesReduction( + out_channels_right, out_channels_right, 5, 2, 2, bias=False + ) + + self.comb_iter_3_right = nn.AvgPool2d( + 3, stride=1, padding=1, count_include_pad=False + ) + + self.comb_iter_4_left = BranchSeparablesReduction( + out_channels_right, out_channels_right, 3, 1, 1, bias=False + ) + self.comb_iter_4_right = MaxPoolPad() + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat( + [x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1 + ) + return x_out + + +class ReductionCell1(nn.Module): + + def __init__( + self, in_channels_left, out_channels_left, in_channels_right, + out_channels_right + ): + super(ReductionCell1, self).__init__() + self.conv_prev_1x1 = nn.Sequential() + self.conv_prev_1x1.add_module('relu', nn.ReLU()) + self.conv_prev_1x1.add_module( + 'conv', + nn.Conv2d( + in_channels_left, out_channels_left, 1, stride=1, bias=False + ) + ) + self.conv_prev_1x1.add_module( + 'bn', + nn.BatchNorm2d( + out_channels_left, eps=0.001, momentum=0.1, affine=True + ) + ) + + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module( + 'conv', + nn.Conv2d( + in_channels_right, out_channels_right, 1, stride=1, bias=False + ) + ) + self.conv_1x1.add_module( + 'bn', + nn.BatchNorm2d( + out_channels_right, eps=0.001, momentum=0.1, affine=True + ) + ) + + self.comb_iter_0_left = BranchSeparables( + out_channels_right, + out_channels_right, + 5, + 2, + 2, + name='specific', + bias=False + ) + self.comb_iter_0_right = BranchSeparables( + out_channels_right, + out_channels_right, + 7, + 2, + 3, + name='specific', + bias=False + ) + + # self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_1_left = MaxPoolPad() + self.comb_iter_1_right = BranchSeparables( + out_channels_right, + out_channels_right, + 7, + 2, + 3, + name='specific', + bias=False + ) + + # self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) + self.comb_iter_2_left = AvgPoolPad() + self.comb_iter_2_right = BranchSeparables( + out_channels_right, + out_channels_right, + 5, + 2, + 2, + name='specific', + bias=False + ) + + self.comb_iter_3_right = nn.AvgPool2d( + 3, stride=1, padding=1, count_include_pad=False + ) + + self.comb_iter_4_left = BranchSeparables( 
+ out_channels_right, + out_channels_right, + 3, + 1, + 1, + name='specific', + bias=False + ) + # self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_4_right = MaxPoolPad() + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat( + [x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1 + ) + return x_out + + +class NASNetAMobile(nn.Module): + """Neural Architecture Search (NAS). + + Reference: + Zoph et al. Learning Transferable Architectures + for Scalable Image Recognition. CVPR 2018. + + Public keys: + - ``nasnetamobile``: NASNet-A Mobile. + """ + + def __init__( + self, + num_classes, + loss, + stem_filters=32, + penultimate_filters=1056, + filters_multiplier=2, + **kwargs + ): + super(NASNetAMobile, self).__init__() + self.stem_filters = stem_filters + self.penultimate_filters = penultimate_filters + self.filters_multiplier = filters_multiplier + self.loss = loss + + filters = self.penultimate_filters // 24 + # 24 is default value for the architecture + + self.conv0 = nn.Sequential() + self.conv0.add_module( + 'conv', + nn.Conv2d( + in_channels=3, + out_channels=self.stem_filters, + kernel_size=3, + padding=0, + stride=2, + bias=False + ) + ) + self.conv0.add_module( + 'bn', + nn.BatchNorm2d( + self.stem_filters, eps=0.001, momentum=0.1, affine=True + ) + ) + + self.cell_stem_0 = CellStem0( + self.stem_filters, num_filters=filters // (filters_multiplier**2) + ) + self.cell_stem_1 = CellStem1( + self.stem_filters, num_filters=filters // filters_multiplier + ) + + self.cell_0 = FirstCell( + in_channels_left=filters, + out_channels_left=filters // 2, # 1, 0.5 + in_channels_right=2 * filters, + out_channels_right=filters + ) # 2, 1 + self.cell_1 = NormalCell( + in_channels_left=2 * filters, + out_channels_left=filters, # 2, 1 + in_channels_right=6 * filters, + out_channels_right=filters + ) # 6, 1 + self.cell_2 = NormalCell( + in_channels_left=6 * filters, + out_channels_left=filters, # 6, 1 + in_channels_right=6 * filters, + out_channels_right=filters + ) # 6, 1 + self.cell_3 = NormalCell( + in_channels_left=6 * filters, + out_channels_left=filters, # 6, 1 + in_channels_right=6 * filters, + out_channels_right=filters + ) # 6, 1 + + self.reduction_cell_0 = ReductionCell0( + in_channels_left=6 * filters, + out_channels_left=2 * filters, # 6, 2 + in_channels_right=6 * filters, + out_channels_right=2 * filters + ) # 6, 2 + + self.cell_6 = FirstCell( + in_channels_left=6 * filters, + out_channels_left=filters, # 6, 1 + in_channels_right=8 * filters, + out_channels_right=2 * filters + ) # 8, 2 + self.cell_7 = NormalCell( + in_channels_left=8 * filters, + out_channels_left=2 * filters, # 8, 2 + 
in_channels_right=12 * filters, + out_channels_right=2 * filters + ) # 12, 2 + self.cell_8 = NormalCell( + in_channels_left=12 * filters, + out_channels_left=2 * filters, # 12, 2 + in_channels_right=12 * filters, + out_channels_right=2 * filters + ) # 12, 2 + self.cell_9 = NormalCell( + in_channels_left=12 * filters, + out_channels_left=2 * filters, # 12, 2 + in_channels_right=12 * filters, + out_channels_right=2 * filters + ) # 12, 2 + + self.reduction_cell_1 = ReductionCell1( + in_channels_left=12 * filters, + out_channels_left=4 * filters, # 12, 4 + in_channels_right=12 * filters, + out_channels_right=4 * filters + ) # 12, 4 + + self.cell_12 = FirstCell( + in_channels_left=12 * filters, + out_channels_left=2 * filters, # 12, 2 + in_channels_right=16 * filters, + out_channels_right=4 * filters + ) # 16, 4 + self.cell_13 = NormalCell( + in_channels_left=16 * filters, + out_channels_left=4 * filters, # 16, 4 + in_channels_right=24 * filters, + out_channels_right=4 * filters + ) # 24, 4 + self.cell_14 = NormalCell( + in_channels_left=24 * filters, + out_channels_left=4 * filters, # 24, 4 + in_channels_right=24 * filters, + out_channels_right=4 * filters + ) # 24, 4 + self.cell_15 = NormalCell( + in_channels_left=24 * filters, + out_channels_left=4 * filters, # 24, 4 + in_channels_right=24 * filters, + out_channels_right=4 * filters + ) # 24, 4 + + self.relu = nn.ReLU() + self.dropout = nn.Dropout() + self.classifier = nn.Linear(24 * filters, num_classes) + + self._init_params() + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def features(self, input): + x_conv0 = self.conv0(input) + x_stem_0 = self.cell_stem_0(x_conv0) + x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0) + + x_cell_0 = self.cell_0(x_stem_1, x_stem_0) + x_cell_1 = self.cell_1(x_cell_0, x_stem_1) + x_cell_2 = self.cell_2(x_cell_1, x_cell_0) + x_cell_3 = self.cell_3(x_cell_2, x_cell_1) + + x_reduction_cell_0 = self.reduction_cell_0(x_cell_3, x_cell_2) + + x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_3) + x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0) + x_cell_8 = self.cell_8(x_cell_7, x_cell_6) + x_cell_9 = self.cell_9(x_cell_8, x_cell_7) + + x_reduction_cell_1 = self.reduction_cell_1(x_cell_9, x_cell_8) + + x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_9) + x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1) + x_cell_14 = self.cell_14(x_cell_13, x_cell_12) + x_cell_15 = self.cell_15(x_cell_14, x_cell_13) + + x_cell_15 = self.relu(x_cell_15) + x_cell_15 = F.avg_pool2d( + x_cell_15, + x_cell_15.size()[2:] + ) # global average pool + x_cell_15 = x_cell_15.view(x_cell_15.size(0), -1) + x_cell_15 = self.dropout(x_cell_15) + + return x_cell_15 + + def forward(self, input): + v = self.features(input) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError('Unsupported loss: {}'.format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. 
+ + Layers that don't match with pretrained layers in name or size are kept unchanged. + """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def nasnetamobile(num_classes, loss='softmax', pretrained=True, **kwargs): + model = NASNetAMobile(num_classes, loss, **kwargs) + if pretrained: + model_url = pretrained_settings['nasnetamobile']['imagenet']['url'] + init_pretrained_weights(model, model_url) + return model diff --git a/feeder/trackers/strongsort/deep/models/osnet.py b/feeder/trackers/strongsort/deep/models/osnet.py new file mode 100644 index 0000000..b77388f --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/osnet.py @@ -0,0 +1,598 @@ +from __future__ import division, absolute_import +import warnings +import torch +from torch import nn +from torch.nn import functional as F + +__all__ = [ + 'osnet_x1_0', 'osnet_x0_75', 'osnet_x0_5', 'osnet_x0_25', 'osnet_ibn_x1_0' +] + +pretrained_urls = { + 'osnet_x1_0': + 'https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY', + 'osnet_x0_75': + 'https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq', + 'osnet_x0_5': + 'https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i', + 'osnet_x0_25': + 'https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs', + 'osnet_ibn_x1_0': + 'https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l' +} + + +########## +# Basic layers +########## +class ConvLayer(nn.Module): + """Convolution layer (conv + bn + relu).""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + groups=1, + IN=False + ): + super(ConvLayer, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=False, + groups=groups + ) + if IN: + self.bn = nn.InstanceNorm2d(out_channels, affine=True) + else: + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Conv1x1(nn.Module): + """1x1 convolution + bn + relu.""" + + def __init__(self, in_channels, out_channels, stride=1, groups=1): + super(Conv1x1, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 1, + stride=stride, + padding=0, + bias=False, + groups=groups + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Conv1x1Linear(nn.Module): + """1x1 convolution + bn (w/o non-linearity).""" + + def __init__(self, in_channels, out_channels, stride=1): + super(Conv1x1Linear, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, 1, stride=stride, padding=0, bias=False + ) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class Conv3x3(nn.Module): + """3x3 convolution + bn + relu.""" + + def __init__(self, in_channels, out_channels, stride=1, groups=1): + super(Conv3x3, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 3, + stride=stride, + padding=1, + bias=False, + groups=groups + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = 
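Aside on the nasnet.py module introduced above: the factory `nasnetamobile` wires the `loss` flag into `forward`, so the same object returns class logits while training and the pooled 1056-d feature in eval mode, and `init_pretrained_weights` only copies parameters whose names and shapes match. A minimal, illustrative sketch (not part of the patch; it assumes the repository root is importable as a package and uses `pretrained=False` to skip the download):

# Illustrative only: exercising the NASNetAMobile contract defined above.
import torch
from feeder.trackers.strongsort.deep.models.nasnet import nasnetamobile  # assumed import path

# penultimate_filters=1056 -> filters = 1056 // 24 = 44; final feature dim = 24 * 44 = 1056
net = nasnetamobile(num_classes=751, loss='softmax', pretrained=False)

net.train()
logits = net(torch.randn(2, 3, 256, 128))        # training mode: class logits (2, 751)
net.eval()
with torch.no_grad():
    feats = net(torch.randn(2, 3, 256, 128))     # eval mode: pooled features (2, 1056)
print(logits.shape, feats.shape)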
self.bn(x) + x = self.relu(x) + return x + + +class LightConv3x3(nn.Module): + """Lightweight 3x3 convolution. + + 1x1 (linear) + dw 3x3 (nonlinear). + """ + + def __init__(self, in_channels, out_channels): + super(LightConv3x3, self).__init__() + self.conv1 = nn.Conv2d( + in_channels, out_channels, 1, stride=1, padding=0, bias=False + ) + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + 3, + stride=1, + padding=1, + bias=False, + groups=out_channels + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.bn(x) + x = self.relu(x) + return x + + +########## +# Building blocks for omni-scale feature learning +########## +class ChannelGate(nn.Module): + """A mini-network that generates channel-wise gates conditioned on input tensor.""" + + def __init__( + self, + in_channels, + num_gates=None, + return_gates=False, + gate_activation='sigmoid', + reduction=16, + layer_norm=False + ): + super(ChannelGate, self).__init__() + if num_gates is None: + num_gates = in_channels + self.return_gates = return_gates + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Conv2d( + in_channels, + in_channels // reduction, + kernel_size=1, + bias=True, + padding=0 + ) + self.norm1 = None + if layer_norm: + self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1)) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Conv2d( + in_channels // reduction, + num_gates, + kernel_size=1, + bias=True, + padding=0 + ) + if gate_activation == 'sigmoid': + self.gate_activation = nn.Sigmoid() + elif gate_activation == 'relu': + self.gate_activation = nn.ReLU(inplace=True) + elif gate_activation == 'linear': + self.gate_activation = None + else: + raise RuntimeError( + "Unknown gate activation: {}".format(gate_activation) + ) + + def forward(self, x): + input = x + x = self.global_avgpool(x) + x = self.fc1(x) + if self.norm1 is not None: + x = self.norm1(x) + x = self.relu(x) + x = self.fc2(x) + if self.gate_activation is not None: + x = self.gate_activation(x) + if self.return_gates: + return x + return input * x + + +class OSBlock(nn.Module): + """Omni-scale feature learning block.""" + + def __init__( + self, + in_channels, + out_channels, + IN=False, + bottleneck_reduction=4, + **kwargs + ): + super(OSBlock, self).__init__() + mid_channels = out_channels // bottleneck_reduction + self.conv1 = Conv1x1(in_channels, mid_channels) + self.conv2a = LightConv3x3(mid_channels, mid_channels) + self.conv2b = nn.Sequential( + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + ) + self.conv2c = nn.Sequential( + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + ) + self.conv2d = nn.Sequential( + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + ) + self.gate = ChannelGate(mid_channels) + self.conv3 = Conv1x1Linear(mid_channels, out_channels) + self.downsample = None + if in_channels != out_channels: + self.downsample = Conv1x1Linear(in_channels, out_channels) + self.IN = None + if IN: + self.IN = nn.InstanceNorm2d(out_channels, affine=True) + + def forward(self, x): + identity = x + x1 = self.conv1(x) + x2a = self.conv2a(x1) + x2b = self.conv2b(x1) + x2c = self.conv2c(x1) + x2d = self.conv2d(x1) + x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d) + x3 
= self.conv3(x2) + if self.downsample is not None: + identity = self.downsample(identity) + out = x3 + identity + if self.IN is not None: + out = self.IN(out) + return F.relu(out) + + +########## +# Network architecture +########## +class OSNet(nn.Module): + """Omni-Scale Network. + + Reference: + - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019. + - Zhou et al. Learning Generalisable Omni-Scale Representations + for Person Re-Identification. TPAMI, 2021. + """ + + def __init__( + self, + num_classes, + blocks, + layers, + channels, + feature_dim=512, + loss='softmax', + IN=False, + **kwargs + ): + super(OSNet, self).__init__() + num_blocks = len(blocks) + assert num_blocks == len(layers) + assert num_blocks == len(channels) - 1 + self.loss = loss + self.feature_dim = feature_dim + + # convolutional backbone + self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN) + self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) + self.conv2 = self._make_layer( + blocks[0], + layers[0], + channels[0], + channels[1], + reduce_spatial_size=True, + IN=IN + ) + self.conv3 = self._make_layer( + blocks[1], + layers[1], + channels[1], + channels[2], + reduce_spatial_size=True + ) + self.conv4 = self._make_layer( + blocks[2], + layers[2], + channels[2], + channels[3], + reduce_spatial_size=False + ) + self.conv5 = Conv1x1(channels[3], channels[3]) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + # fully connected layer + self.fc = self._construct_fc_layer( + self.feature_dim, channels[3], dropout_p=None + ) + # identity classification layer + self.classifier = nn.Linear(self.feature_dim, num_classes) + + self._init_params() + + def _make_layer( + self, + block, + layer, + in_channels, + out_channels, + reduce_spatial_size, + IN=False + ): + layers = [] + + layers.append(block(in_channels, out_channels, IN=IN)) + for i in range(1, layer): + layers.append(block(out_channels, out_channels, IN=IN)) + + if reduce_spatial_size: + layers.append( + nn.Sequential( + Conv1x1(out_channels, out_channels), + nn.AvgPool2d(2, stride=2) + ) + ) + + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + if fc_dims is None or fc_dims < 0: + self.feature_dim = input_dim + return None + + if isinstance(fc_dims, int): + fc_dims = [fc_dims] + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.maxpool(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.conv5(x) + return x + + def forward(self, x, return_featuremaps=False): + x = self.featuremaps(x) + if return_featuremaps: + return x + v = self.global_avgpool(x) + v = v.view(v.size(0), -1) + if self.fc 
is not None: + v = self.fc(v) + if not self.training: + return v + y = self.classifier(v) + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + + +def init_pretrained_weights(model, key=''): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. + """ + import os + import errno + import gdown + from collections import OrderedDict + + def _get_torch_home(): + ENV_TORCH_HOME = 'TORCH_HOME' + ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' + DEFAULT_CACHE_DIR = '~/.cache' + torch_home = os.path.expanduser( + os.getenv( + ENV_TORCH_HOME, + os.path.join( + os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch' + ) + ) + ) + return torch_home + + torch_home = _get_torch_home() + model_dir = os.path.join(torch_home, 'checkpoints') + try: + os.makedirs(model_dir) + except OSError as e: + if e.errno == errno.EEXIST: + # Directory already exists, ignore. + pass + else: + # Unexpected OSError, re-raise. + raise + filename = key + '_imagenet.pth' + cached_file = os.path.join(model_dir, filename) + + if not os.path.exists(cached_file): + gdown.download(pretrained_urls[key], cached_file, quiet=False) + + state_dict = torch.load(cached_file) + model_dict = model.state_dict() + new_state_dict = OrderedDict() + matched_layers, discarded_layers = [], [] + + for k, v in state_dict.items(): + if k.startswith('module.'): + k = k[7:] # discard module. + + if k in model_dict and model_dict[k].size() == v.size(): + new_state_dict[k] = v + matched_layers.append(k) + else: + discarded_layers.append(k) + + model_dict.update(new_state_dict) + model.load_state_dict(model_dict) + + if len(matched_layers) == 0: + warnings.warn( + 'The pretrained weights from "{}" cannot be loaded, ' + 'please check the key names manually ' + '(** ignored and continue **)'.format(cached_file) + ) + else: + print( + 'Successfully loaded imagenet pretrained weights from "{}"'. + format(cached_file) + ) + if len(discarded_layers) > 0: + print( + '** The following layers are discarded ' + 'due to unmatched keys or layer size: {}'. 
+ format(discarded_layers) + ) + + +########## +# Instantiation +########## +def osnet_x1_0(num_classes=1000, pretrained=True, loss='softmax', **kwargs): + # standard size (width x1.0) + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[64, 256, 384, 512], + loss=loss, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, key='osnet_x1_0') + return model + + +def osnet_x0_75(num_classes=1000, pretrained=True, loss='softmax', **kwargs): + # medium size (width x0.75) + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[48, 192, 288, 384], + loss=loss, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, key='osnet_x0_75') + return model + + +def osnet_x0_5(num_classes=1000, pretrained=True, loss='softmax', **kwargs): + # tiny size (width x0.5) + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[32, 128, 192, 256], + loss=loss, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, key='osnet_x0_5') + return model + + +def osnet_x0_25(num_classes=1000, pretrained=True, loss='softmax', **kwargs): + # very tiny size (width x0.25) + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[16, 64, 96, 128], + loss=loss, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, key='osnet_x0_25') + return model + + +def osnet_ibn_x1_0( + num_classes=1000, pretrained=True, loss='softmax', **kwargs +): + # standard size (width x1.0) + IBN layer + # Ref: Pan et al. Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net. ECCV, 2018. + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[64, 256, 384, 512], + loss=loss, + IN=True, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, key='osnet_ibn_x1_0') + return model diff --git a/feeder/trackers/strongsort/deep/models/osnet_ain.py b/feeder/trackers/strongsort/deep/models/osnet_ain.py new file mode 100644 index 0000000..3f9f7bd --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/osnet_ain.py @@ -0,0 +1,609 @@ +from __future__ import division, absolute_import +import warnings +import torch +from torch import nn +from torch.nn import functional as F + +__all__ = [ + 'osnet_ain_x1_0', 'osnet_ain_x0_75', 'osnet_ain_x0_5', 'osnet_ain_x0_25' +] + +pretrained_urls = { + 'osnet_ain_x1_0': + 'https://drive.google.com/uc?id=1-CaioD9NaqbHK_kzSMW8VE4_3KcsRjEo', + 'osnet_ain_x0_75': + 'https://drive.google.com/uc?id=1apy0hpsMypqstfencdH-jKIUEFOW4xoM', + 'osnet_ain_x0_5': + 'https://drive.google.com/uc?id=1KusKvEYyKGDTUBVRxRiz55G31wkihB6l', + 'osnet_ain_x0_25': + 'https://drive.google.com/uc?id=1SxQt2AvmEcgWNhaRb2xC4rP6ZwVDP0Wt' +} + + +########## +# Basic layers +########## +class ConvLayer(nn.Module): + """Convolution layer (conv + bn + relu).""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + groups=1, + IN=False + ): + super(ConvLayer, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=False, + groups=groups + ) + if IN: + self.bn = nn.InstanceNorm2d(out_channels, affine=True) + else: + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return self.relu(x) + + +class Conv1x1(nn.Module): + """1x1 convolution + bn + relu.""" + + def __init__(self, 
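Aside on the osnet.py module above: inside StrongSORT these OSNet factories act as appearance-feature extractors, so the relevant path is eval mode, where `forward` returns the 512-d embedding from the fc head rather than logits. A minimal, illustrative sketch (not part of the patch; assumed import path, `pretrained=False` to avoid the gdown fetch):

# Illustrative only: using an OSNet factory defined above as a re-ID embedder.
import torch
import torch.nn.functional as F
from feeder.trackers.strongsort.deep.models.osnet import osnet_x0_25  # assumed import path

extractor = osnet_x0_25(num_classes=1000, pretrained=False)
extractor.eval()                                  # eval mode -> forward returns embeddings
crops = torch.randn(4, 3, 256, 128)               # four person crops
with torch.no_grad():
    emb = extractor(crops)                        # (4, 512): feature_dim defaults to 512
emb = F.normalize(emb, dim=1)                     # cosine-ready appearance features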
in_channels, out_channels, stride=1, groups=1): + super(Conv1x1, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 1, + stride=stride, + padding=0, + bias=False, + groups=groups + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return self.relu(x) + + +class Conv1x1Linear(nn.Module): + """1x1 convolution + bn (w/o non-linearity).""" + + def __init__(self, in_channels, out_channels, stride=1, bn=True): + super(Conv1x1Linear, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, 1, stride=stride, padding=0, bias=False + ) + self.bn = None + if bn: + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + return x + + +class Conv3x3(nn.Module): + """3x3 convolution + bn + relu.""" + + def __init__(self, in_channels, out_channels, stride=1, groups=1): + super(Conv3x3, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 3, + stride=stride, + padding=1, + bias=False, + groups=groups + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return self.relu(x) + + +class LightConv3x3(nn.Module): + """Lightweight 3x3 convolution. + + 1x1 (linear) + dw 3x3 (nonlinear). + """ + + def __init__(self, in_channels, out_channels): + super(LightConv3x3, self).__init__() + self.conv1 = nn.Conv2d( + in_channels, out_channels, 1, stride=1, padding=0, bias=False + ) + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + 3, + stride=1, + padding=1, + bias=False, + groups=out_channels + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.bn(x) + return self.relu(x) + + +class LightConvStream(nn.Module): + """Lightweight convolution stream.""" + + def __init__(self, in_channels, out_channels, depth): + super(LightConvStream, self).__init__() + assert depth >= 1, 'depth must be equal to or larger than 1, but got {}'.format( + depth + ) + layers = [] + layers += [LightConv3x3(in_channels, out_channels)] + for i in range(depth - 1): + layers += [LightConv3x3(out_channels, out_channels)] + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +########## +# Building blocks for omni-scale feature learning +########## +class ChannelGate(nn.Module): + """A mini-network that generates channel-wise gates conditioned on input tensor.""" + + def __init__( + self, + in_channels, + num_gates=None, + return_gates=False, + gate_activation='sigmoid', + reduction=16, + layer_norm=False + ): + super(ChannelGate, self).__init__() + if num_gates is None: + num_gates = in_channels + self.return_gates = return_gates + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Conv2d( + in_channels, + in_channels // reduction, + kernel_size=1, + bias=True, + padding=0 + ) + self.norm1 = None + if layer_norm: + self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1)) + self.relu = nn.ReLU() + self.fc2 = nn.Conv2d( + in_channels // reduction, + num_gates, + kernel_size=1, + bias=True, + padding=0 + ) + if gate_activation == 'sigmoid': + self.gate_activation = nn.Sigmoid() + elif gate_activation == 'relu': + self.gate_activation = nn.ReLU() + elif gate_activation == 'linear': + self.gate_activation = None + else: + raise RuntimeError( + "Unknown gate activation: {}".format(gate_activation) + ) + + def 
forward(self, x): + input = x + x = self.global_avgpool(x) + x = self.fc1(x) + if self.norm1 is not None: + x = self.norm1(x) + x = self.relu(x) + x = self.fc2(x) + if self.gate_activation is not None: + x = self.gate_activation(x) + if self.return_gates: + return x + return input * x + + +class OSBlock(nn.Module): + """Omni-scale feature learning block.""" + + def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs): + super(OSBlock, self).__init__() + assert T >= 1 + assert out_channels >= reduction and out_channels % reduction == 0 + mid_channels = out_channels // reduction + + self.conv1 = Conv1x1(in_channels, mid_channels) + self.conv2 = nn.ModuleList() + for t in range(1, T + 1): + self.conv2 += [LightConvStream(mid_channels, mid_channels, t)] + self.gate = ChannelGate(mid_channels) + self.conv3 = Conv1x1Linear(mid_channels, out_channels) + self.downsample = None + if in_channels != out_channels: + self.downsample = Conv1x1Linear(in_channels, out_channels) + + def forward(self, x): + identity = x + x1 = self.conv1(x) + x2 = 0 + for conv2_t in self.conv2: + x2_t = conv2_t(x1) + x2 = x2 + self.gate(x2_t) + x3 = self.conv3(x2) + if self.downsample is not None: + identity = self.downsample(identity) + out = x3 + identity + return F.relu(out) + + +class OSBlockINin(nn.Module): + """Omni-scale feature learning block with instance normalization.""" + + def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs): + super(OSBlockINin, self).__init__() + assert T >= 1 + assert out_channels >= reduction and out_channels % reduction == 0 + mid_channels = out_channels // reduction + + self.conv1 = Conv1x1(in_channels, mid_channels) + self.conv2 = nn.ModuleList() + for t in range(1, T + 1): + self.conv2 += [LightConvStream(mid_channels, mid_channels, t)] + self.gate = ChannelGate(mid_channels) + self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False) + self.downsample = None + if in_channels != out_channels: + self.downsample = Conv1x1Linear(in_channels, out_channels) + self.IN = nn.InstanceNorm2d(out_channels, affine=True) + + def forward(self, x): + identity = x + x1 = self.conv1(x) + x2 = 0 + for conv2_t in self.conv2: + x2_t = conv2_t(x1) + x2 = x2 + self.gate(x2_t) + x3 = self.conv3(x2) + x3 = self.IN(x3) # IN inside residual + if self.downsample is not None: + identity = self.downsample(identity) + out = x3 + identity + return F.relu(out) + + +########## +# Network architecture +########## +class OSNet(nn.Module): + """Omni-Scale Network. + + Reference: + - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019. + - Zhou et al. Learning Generalisable Omni-Scale Representations + for Person Re-Identification. TPAMI, 2021. 
+ """ + + def __init__( + self, + num_classes, + blocks, + layers, + channels, + feature_dim=512, + loss='softmax', + conv1_IN=False, + **kwargs + ): + super(OSNet, self).__init__() + num_blocks = len(blocks) + assert num_blocks == len(layers) + assert num_blocks == len(channels) - 1 + self.loss = loss + self.feature_dim = feature_dim + + # convolutional backbone + self.conv1 = ConvLayer( + 3, channels[0], 7, stride=2, padding=3, IN=conv1_IN + ) + self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) + self.conv2 = self._make_layer( + blocks[0], layers[0], channels[0], channels[1] + ) + self.pool2 = nn.Sequential( + Conv1x1(channels[1], channels[1]), nn.AvgPool2d(2, stride=2) + ) + self.conv3 = self._make_layer( + blocks[1], layers[1], channels[1], channels[2] + ) + self.pool3 = nn.Sequential( + Conv1x1(channels[2], channels[2]), nn.AvgPool2d(2, stride=2) + ) + self.conv4 = self._make_layer( + blocks[2], layers[2], channels[2], channels[3] + ) + self.conv5 = Conv1x1(channels[3], channels[3]) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + # fully connected layer + self.fc = self._construct_fc_layer( + self.feature_dim, channels[3], dropout_p=None + ) + # identity classification layer + self.classifier = nn.Linear(self.feature_dim, num_classes) + + self._init_params() + + def _make_layer(self, blocks, layer, in_channels, out_channels): + layers = [] + layers += [blocks[0](in_channels, out_channels)] + for i in range(1, len(blocks)): + layers += [blocks[i](out_channels, out_channels)] + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + if fc_dims is None or fc_dims < 0: + self.feature_dim = input_dim + return None + + if isinstance(fc_dims, int): + fc_dims = [fc_dims] + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU()) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.InstanceNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.maxpool(x) + x = self.conv2(x) + x = self.pool2(x) + x = self.conv3(x) + x = self.pool3(x) + x = self.conv4(x) + x = self.conv5(x) + return x + + def forward(self, x, return_featuremaps=False): + x = self.featuremaps(x) + if return_featuremaps: + return x + v = self.global_avgpool(x) + v = v.view(v.size(0), -1) + if self.fc is not None: + v = self.fc(v) + if not self.training: + return v + y = self.classifier(v) + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + + +def init_pretrained_weights(model, key=''): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. 
+ """ + import os + import errno + import gdown + from collections import OrderedDict + + def _get_torch_home(): + ENV_TORCH_HOME = 'TORCH_HOME' + ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' + DEFAULT_CACHE_DIR = '~/.cache' + torch_home = os.path.expanduser( + os.getenv( + ENV_TORCH_HOME, + os.path.join( + os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch' + ) + ) + ) + return torch_home + + torch_home = _get_torch_home() + model_dir = os.path.join(torch_home, 'checkpoints') + try: + os.makedirs(model_dir) + except OSError as e: + if e.errno == errno.EEXIST: + # Directory already exists, ignore. + pass + else: + # Unexpected OSError, re-raise. + raise + filename = key + '_imagenet.pth' + cached_file = os.path.join(model_dir, filename) + + if not os.path.exists(cached_file): + gdown.download(pretrained_urls[key], cached_file, quiet=False) + + state_dict = torch.load(cached_file) + model_dict = model.state_dict() + new_state_dict = OrderedDict() + matched_layers, discarded_layers = [], [] + + for k, v in state_dict.items(): + if k.startswith('module.'): + k = k[7:] # discard module. + + if k in model_dict and model_dict[k].size() == v.size(): + new_state_dict[k] = v + matched_layers.append(k) + else: + discarded_layers.append(k) + + model_dict.update(new_state_dict) + model.load_state_dict(model_dict) + + if len(matched_layers) == 0: + warnings.warn( + 'The pretrained weights from "{}" cannot be loaded, ' + 'please check the key names manually ' + '(** ignored and continue **)'.format(cached_file) + ) + else: + print( + 'Successfully loaded imagenet pretrained weights from "{}"'. + format(cached_file) + ) + if len(discarded_layers) > 0: + print( + '** The following layers are discarded ' + 'due to unmatched keys or layer size: {}'. + format(discarded_layers) + ) + + +########## +# Instantiation +########## +def osnet_ain_x1_0( + num_classes=1000, pretrained=True, loss='softmax', **kwargs +): + model = OSNet( + num_classes, + blocks=[ + [OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin], + [OSBlockINin, OSBlock] + ], + layers=[2, 2, 2], + channels=[64, 256, 384, 512], + loss=loss, + conv1_IN=True, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, key='osnet_ain_x1_0') + return model + + +def osnet_ain_x0_75( + num_classes=1000, pretrained=True, loss='softmax', **kwargs +): + model = OSNet( + num_classes, + blocks=[ + [OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin], + [OSBlockINin, OSBlock] + ], + layers=[2, 2, 2], + channels=[48, 192, 288, 384], + loss=loss, + conv1_IN=True, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, key='osnet_ain_x0_75') + return model + + +def osnet_ain_x0_5( + num_classes=1000, pretrained=True, loss='softmax', **kwargs +): + model = OSNet( + num_classes, + blocks=[ + [OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin], + [OSBlockINin, OSBlock] + ], + layers=[2, 2, 2], + channels=[32, 128, 192, 256], + loss=loss, + conv1_IN=True, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, key='osnet_ain_x0_5') + return model + + +def osnet_ain_x0_25( + num_classes=1000, pretrained=True, loss='softmax', **kwargs +): + model = OSNet( + num_classes, + blocks=[ + [OSBlockINin, OSBlockINin], [OSBlock, OSBlockINin], + [OSBlockINin, OSBlock] + ], + layers=[2, 2, 2], + channels=[16, 64, 96, 128], + loss=loss, + conv1_IN=True, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, key='osnet_ain_x0_25') + return model diff --git a/feeder/trackers/strongsort/deep/models/pcb.py b/feeder/trackers/strongsort/deep/models/pcb.py new 
file mode 100644 index 0000000..92c7414 --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/pcb.py @@ -0,0 +1,314 @@ +from __future__ import division, absolute_import +import torch.utils.model_zoo as model_zoo +from torch import nn +from torch.nn import functional as F + +__all__ = ['pcb_p6', 'pcb_p4'] + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False + ) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False + ) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class DimReduceLayer(nn.Module): + + def __init__(self, in_channels, out_channels, nonlinear): + super(DimReduceLayer, self).__init__() + layers = [] + layers.append( + nn.Conv2d( + in_channels, out_channels, 1, stride=1, padding=0, bias=False + ) + ) + layers.append(nn.BatchNorm2d(out_channels)) + + if nonlinear == 'relu': + layers.append(nn.ReLU(inplace=True)) + elif nonlinear == 'leakyrelu': + layers.append(nn.LeakyReLU(0.1)) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class PCB(nn.Module): + """Part-based Convolutional Baseline. + + Reference: + Sun et al. Beyond Part Models: Person Retrieval with Refined + Part Pooling (and A Strong Convolutional Baseline). ECCV 2018. + + Public keys: + - ``pcb_p4``: PCB with 4-part strips. + - ``pcb_p6``: PCB with 6-part strips. 
+ """ + + def __init__( + self, + num_classes, + loss, + block, + layers, + parts=6, + reduced_dim=256, + nonlinear='relu', + **kwargs + ): + self.inplanes = 64 + super(PCB, self).__init__() + self.loss = loss + self.parts = parts + self.feature_dim = 512 * block.expansion + + # backbone network + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=1) + + # pcb layers + self.parts_avgpool = nn.AdaptiveAvgPool2d((self.parts, 1)) + self.dropout = nn.Dropout(p=0.5) + self.conv5 = DimReduceLayer( + 512 * block.expansion, reduced_dim, nonlinear=nonlinear + ) + self.feature_dim = reduced_dim + self.classifier = nn.ModuleList( + [ + nn.Linear(self.feature_dim, num_classes) + for _ in range(self.parts) + ] + ) + + self._init_params() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def forward(self, x): + f = self.featuremaps(x) + v_g = self.parts_avgpool(f) + + if not self.training: + v_g = F.normalize(v_g, p=2, dim=1) + return v_g.view(v_g.size(0), -1) + + v_g = self.dropout(v_g) + v_h = self.conv5(v_g) + + y = [] + for i in range(self.parts): + v_h_i = v_h[:, :, i, :] + v_h_i = v_h_i.view(v_h_i.size(0), -1) + y_i = self.classifier[i](v_h_i) + y.append(y_i) + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + v_g = F.normalize(v_g, p=2, dim=1) + return y, v_g.view(v_g.size(0), -1) + else: + raise KeyError('Unsupported loss: {}'.format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. 
+ """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def pcb_p6(num_classes, loss='softmax', pretrained=True, **kwargs): + model = PCB( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 4, 6, 3], + last_stride=1, + parts=6, + reduced_dim=256, + nonlinear='relu', + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['resnet50']) + return model + + +def pcb_p4(num_classes, loss='softmax', pretrained=True, **kwargs): + model = PCB( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 4, 6, 3], + last_stride=1, + parts=4, + reduced_dim=256, + nonlinear='relu', + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['resnet50']) + return model diff --git a/feeder/trackers/strongsort/deep/models/resnet.py b/feeder/trackers/strongsort/deep/models/resnet.py new file mode 100644 index 0000000..63d7f43 --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/resnet.py @@ -0,0 +1,530 @@ +""" +Code source: https://github.com/pytorch/vision +""" +from __future__ import division, absolute_import +import torch.utils.model_zoo as model_zoo +from torch import nn + +__all__ = [ + 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', + 'resnext50_32x4d', 'resnext101_32x8d', 'resnet50_fc512' +] + +model_urls = { + 'resnet18': + 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': + 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': + 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': + 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': + 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': + 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': + 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation + ) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None + ): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64' + ) + if dilation > 1: + raise NotImplementedError( + "Dilation > 1 not supported in BasicBlock" + ) + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if 
self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None + ): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width/64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + """Residual network. + + Reference: + - He et al. Deep Residual Learning for Image Recognition. CVPR 2016. + - Xie et al. Aggregated Residual Transformations for Deep Neural Networks. CVPR 2017. + + Public keys: + - ``resnet18``: ResNet18. + - ``resnet34``: ResNet34. + - ``resnet50``: ResNet50. + - ``resnet101``: ResNet101. + - ``resnet152``: ResNet152. + - ``resnext50_32x4d``: ResNeXt50. + - ``resnext101_32x8d``: ResNeXt101. + - ``resnet50_fc512``: ResNet50 + FC. + """ + + def __init__( + self, + num_classes, + loss, + block, + layers, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + last_stride=2, + fc_dims=None, + dropout_p=None, + **kwargs + ): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + self.loss = loss + self.feature_dim = 512 * block.expansion + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}". 
+ format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1] + ) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=last_stride, + dilate=replace_stride_with_dilation[2] + ) + self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = self._construct_fc_layer( + fc_dims, 512 * block.expansion, dropout_p + ) + self.classifier = nn.Linear(self.feature_dim, num_classes) + + self._init_params() + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer + ) + ) + + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + """Constructs fully connected layer + + Args: + fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed + input_dim (int): input dimension + dropout_p (float): dropout probability, if None, dropout is unused + """ + if fc_dims is None: + self.feature_dim = input_dim + return None + + assert isinstance( + fc_dims, (list, tuple) + ), 'fc_dims must be either list or tuple, but got {}'.format( + type(fc_dims) + ) + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + 
nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def forward(self, x): + f = self.featuremaps(x) + v = self.global_avgpool(f) + v = v.view(v.size(0), -1) + + if self.fc is not None: + v = self.fc(v) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. + """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +"""ResNet""" + + +def resnet18(num_classes, loss='softmax', pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=BasicBlock, + layers=[2, 2, 2, 2], + last_stride=2, + fc_dims=None, + dropout_p=None, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['resnet18']) + return model + + +def resnet34(num_classes, loss='softmax', pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=BasicBlock, + layers=[3, 4, 6, 3], + last_stride=2, + fc_dims=None, + dropout_p=None, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['resnet34']) + return model + + +def resnet50(num_classes, loss='softmax', pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 4, 6, 3], + last_stride=2, + fc_dims=None, + dropout_p=None, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['resnet50']) + return model + + +def resnet101(num_classes, loss='softmax', pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 4, 23, 3], + last_stride=2, + fc_dims=None, + dropout_p=None, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['resnet101']) + return model + + +def resnet152(num_classes, loss='softmax', pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 8, 36, 3], + last_stride=2, + fc_dims=None, + dropout_p=None, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['resnet152']) + return model + + +"""ResNeXt""" + + +def resnext50_32x4d(num_classes, loss='softmax', pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 4, 6, 3], + last_stride=2, + fc_dims=None, + dropout_p=None, + groups=32, + width_per_group=4, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['resnext50_32x4d']) + return model + + +def resnext101_32x8d(num_classes, loss='softmax', pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 4, 23, 3], + last_stride=2, + fc_dims=None, + dropout_p=None, + groups=32, + width_per_group=8, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, 
model_urls['resnext101_32x8d']) + return model + + +""" +ResNet + FC +""" + + +def resnet50_fc512(num_classes, loss='softmax', pretrained=True, **kwargs): + model = ResNet( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 4, 6, 3], + last_stride=1, + fc_dims=[512], + dropout_p=None, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['resnet50']) + return model diff --git a/feeder/trackers/strongsort/deep/models/resnet_ibn_a.py b/feeder/trackers/strongsort/deep/models/resnet_ibn_a.py new file mode 100644 index 0000000..d198e7c --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/resnet_ibn_a.py @@ -0,0 +1,289 @@ +""" +Credit to https://github.com/XingangPan/IBN-Net. +""" +from __future__ import division, absolute_import +import math +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + +__all__ = ['resnet50_ibn_a'] + +model_urls = { + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class IBN(nn.Module): + + def __init__(self, planes): + super(IBN, self).__init__() + half1 = int(planes / 2) + self.half = half1 + half2 = planes - half1 + self.IN = nn.InstanceNorm2d(half1, affine=True) + self.BN = nn.BatchNorm2d(half2) + + def forward(self, x): + split = torch.split(x, self.half, 1) + out1 = self.IN(split[0].contiguous()) + out2 = self.BN(split[1].contiguous()) + out = torch.cat((out1, out2), 1) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + if ibn: + self.bn1 = IBN(planes) + else: + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False + ) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False + ) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class 
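Aside on the resnet.py module above: the `loss` flag controls what the training-mode forward returns; `'triplet'` yields both the classifier logits and the embedding, which is what metric-learning losses need, while `resnet50_fc512` adds a 512-d fc head on top of the 2048-d backbone. A minimal, illustrative sketch (not part of the patch; assumed import path, `pretrained=False`):

# Illustrative only: the loss-dependent return values of the ResNet wrappers above.
import torch
from feeder.trackers.strongsort.deep.models.resnet import resnet50_fc512  # assumed import path

net = resnet50_fc512(num_classes=751, loss='triplet', pretrained=False)
net.train()
y, v = net(torch.randn(4, 3, 256, 128))    # logits (4, 751) and 512-d embeddings (4, 512)
net.eval()
with torch.no_grad():
    v = net(torch.randn(4, 3, 256, 128))   # eval mode: embeddings only
print(y.shape, v.shape)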
ResNet(nn.Module): + """Residual network + IBN layer. + + Reference: + - He et al. Deep Residual Learning for Image Recognition. CVPR 2016. + - Pan et al. Two at Once: Enhancing Learning and Generalization + Capacities via IBN-Net. ECCV 2018. + """ + + def __init__( + self, + block, + layers, + num_classes=1000, + loss='softmax', + fc_dims=None, + dropout_p=None, + **kwargs + ): + scale = 64 + self.inplanes = scale + super(ResNet, self).__init__() + self.loss = loss + self.feature_dim = scale * 8 * block.expansion + + self.conv1 = nn.Conv2d( + 3, scale, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = nn.BatchNorm2d(scale) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, scale, layers[0]) + self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2) + self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2) + self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = self._construct_fc_layer( + fc_dims, scale * 8 * block.expansion, dropout_p + ) + self.classifier = nn.Linear(self.feature_dim, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.InstanceNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + ibn = True + if planes == 512: + ibn = False + layers.append(block(self.inplanes, planes, ibn, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, ibn)) + + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + """Constructs fully connected layer + + Args: + fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed + input_dim (int): input dimension + dropout_p (float): dropout probability, if None, dropout is unused + """ + if fc_dims is None: + self.feature_dim = input_dim + return None + + assert isinstance( + fc_dims, (list, tuple) + ), 'fc_dims must be either list or tuple, but got {}'.format( + type(fc_dims) + ) + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def forward(self, x): + f = self.featuremaps(x) + v = self.avgpool(f) + v = v.view(v.size(0), -1) + if self.fc is not None: + v = self.fc(v) + if not self.training: + return v + y = self.classifier(v) + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError("Unsupported loss: 
{}".format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. + """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def resnet50_ibn_a(num_classes, loss='softmax', pretrained=False, **kwargs): + model = ResNet( + Bottleneck, [3, 4, 6, 3], num_classes=num_classes, loss=loss, **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['resnet50']) + return model diff --git a/feeder/trackers/strongsort/deep/models/resnet_ibn_b.py b/feeder/trackers/strongsort/deep/models/resnet_ibn_b.py new file mode 100644 index 0000000..9881cc7 --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/resnet_ibn_b.py @@ -0,0 +1,274 @@ +""" +Credit to https://github.com/XingangPan/IBN-Net. +""" +from __future__ import division, absolute_import +import math +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + +__all__ = ['resnet50_ibn_b'] + +model_urls = { + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, IN=False): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False + ) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False + ) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.IN = None + if IN: + self.IN = nn.InstanceNorm2d(planes * 4, affine=True) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + if self.IN is not None: + out = self.IN(out) + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + 
"""Residual network + IBN layer. + + Reference: + - He et al. Deep Residual Learning for Image Recognition. CVPR 2016. + - Pan et al. Two at Once: Enhancing Learning and Generalization + Capacities via IBN-Net. ECCV 2018. + """ + + def __init__( + self, + block, + layers, + num_classes=1000, + loss='softmax', + fc_dims=None, + dropout_p=None, + **kwargs + ): + scale = 64 + self.inplanes = scale + super(ResNet, self).__init__() + self.loss = loss + self.feature_dim = scale * 8 * block.expansion + + self.conv1 = nn.Conv2d( + 3, scale, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = nn.InstanceNorm2d(scale, affine=True) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer( + block, scale, layers[0], stride=1, IN=True + ) + self.layer2 = self._make_layer( + block, scale * 2, layers[1], stride=2, IN=True + ) + self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2) + self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = self._construct_fc_layer( + fc_dims, scale * 8 * block.expansion, dropout_p + ) + self.classifier = nn.Linear(self.feature_dim, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.InstanceNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, IN=False): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks - 1): + layers.append(block(self.inplanes, planes)) + layers.append(block(self.inplanes, planes, IN=IN)) + + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + """Constructs fully connected layer + + Args: + fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed + input_dim (int): input dimension + dropout_p (float): dropout probability, if None, dropout is unused + """ + if fc_dims is None: + self.feature_dim = input_dim + return None + + assert isinstance( + fc_dims, (list, tuple) + ), 'fc_dims must be either list or tuple, but got {}'.format( + type(fc_dims) + ) + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def forward(self, x): + f = self.featuremaps(x) + v = self.avgpool(f) + v = v.view(v.size(0), -1) + if self.fc is not None: + v = self.fc(v) + if not self.training: + return v + y = self.classifier(v) + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + 
else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. + """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def resnet50_ibn_b(num_classes, loss='softmax', pretrained=False, **kwargs): + model = ResNet( + Bottleneck, [3, 4, 6, 3], num_classes=num_classes, loss=loss, **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['resnet50']) + return model diff --git a/feeder/trackers/strongsort/deep/models/resnetmid.py b/feeder/trackers/strongsort/deep/models/resnetmid.py new file mode 100644 index 0000000..017f6c6 --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/resnetmid.py @@ -0,0 +1,307 @@ +from __future__ import division, absolute_import +import torch +import torch.utils.model_zoo as model_zoo +from torch import nn + +__all__ = ['resnet50mid'] + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False + ) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False + ) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNetMid(nn.Module): + """Residual network 
+ mid-level features. + + Reference: + Yu et al. The Devil is in the Middle: Exploiting Mid-level Representations for + Cross-Domain Instance Matching. arXiv:1711.08106. + + Public keys: + - ``resnet50mid``: ResNet50 + mid-level feature fusion. + """ + + def __init__( + self, + num_classes, + loss, + block, + layers, + last_stride=2, + fc_dims=None, + **kwargs + ): + self.inplanes = 64 + super(ResNetMid, self).__init__() + self.loss = loss + self.feature_dim = 512 * block.expansion + + # backbone network + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=last_stride + ) + + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + assert fc_dims is not None + self.fc_fusion = self._construct_fc_layer( + fc_dims, 512 * block.expansion * 2 + ) + self.feature_dim += 512 * block.expansion + self.classifier = nn.Linear(self.feature_dim, num_classes) + + self._init_params() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + """Constructs fully connected layer + + Args: + fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed + input_dim (int): input dimension + dropout_p (float): dropout probability, if None, dropout is unused + """ + if fc_dims is None: + self.feature_dim = input_dim + return None + + assert isinstance( + fc_dims, (list, tuple) + ), 'fc_dims must be either list or tuple, but got {}'.format( + type(fc_dims) + ) + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x4a = self.layer4[0](x) + x4b = self.layer4[1](x4a) + x4c = self.layer4[2](x4b) + return x4a, x4b, x4c + + def forward(self, x): + x4a, x4b, 
x4c = self.featuremaps(x) + + v4a = self.global_avgpool(x4a) + v4b = self.global_avgpool(x4b) + v4c = self.global_avgpool(x4c) + v4ab = torch.cat([v4a, v4b], 1) + v4ab = v4ab.view(v4ab.size(0), -1) + v4ab = self.fc_fusion(v4ab) + v4c = v4c.view(v4c.size(0), -1) + v = torch.cat([v4ab, v4c], 1) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError('Unsupported loss: {}'.format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. + """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +""" +Residual network configurations: +-- +resnet18: block=BasicBlock, layers=[2, 2, 2, 2] +resnet34: block=BasicBlock, layers=[3, 4, 6, 3] +resnet50: block=Bottleneck, layers=[3, 4, 6, 3] +resnet101: block=Bottleneck, layers=[3, 4, 23, 3] +resnet152: block=Bottleneck, layers=[3, 8, 36, 3] +""" + + +def resnet50mid(num_classes, loss='softmax', pretrained=True, **kwargs): + model = ResNetMid( + num_classes=num_classes, + loss=loss, + block=Bottleneck, + layers=[3, 4, 6, 3], + last_stride=2, + fc_dims=[1024], + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['resnet50']) + return model diff --git a/feeder/trackers/strongsort/deep/models/senet.py b/feeder/trackers/strongsort/deep/models/senet.py new file mode 100644 index 0000000..baaf9b0 --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/senet.py @@ -0,0 +1,688 @@ +from __future__ import division, absolute_import +import math +from collections import OrderedDict +import torch.nn as nn +from torch.utils import model_zoo + +__all__ = [ + 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', + 'se_resnext50_32x4d', 'se_resnext101_32x4d', 'se_resnet50_fc512' +] +""" +Code imported from https://github.com/Cadene/pretrained-models.pytorch +""" + +pretrained_settings = { + 'senet154': { + 'imagenet': { + 'url': + 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + }, + 'se_resnet50': { + 'imagenet': { + 'url': + 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + }, + 'se_resnet101': { + 'imagenet': { + 'url': + 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + }, + 'se_resnet152': { + 'imagenet': { + 'url': + 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + }, + 'se_resnext50_32x4d': { + 'imagenet': { + 'url': + 
'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + }, + 'se_resnext101_32x4d': { + 'imagenet': { + 'url': + 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', + 'input_space': 'RGB', + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + }, +} + + +class SEModule(nn.Module): + + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Conv2d( + channels, channels // reduction, kernel_size=1, padding=0 + ) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Conv2d( + channels // reduction, channels, kernel_size=1, padding=0 + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class Bottleneck(nn.Module): + """ + Base class for bottlenecks that implements `forward()` method. + """ + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = self.se_module(out) + residual + out = self.relu(out) + + return out + + +class SEBottleneck(Bottleneck): + """ + Bottleneck for SENet154. + """ + expansion = 4 + + def __init__( + self, inplanes, planes, groups, reduction, stride=1, downsample=None + ): + super(SEBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes * 2) + self.conv2 = nn.Conv2d( + planes * 2, + planes * 4, + kernel_size=3, + stride=stride, + padding=1, + groups=groups, + bias=False + ) + self.bn2 = nn.BatchNorm2d(planes * 4) + self.conv3 = nn.Conv2d( + planes * 4, planes * 4, kernel_size=1, bias=False + ) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNetBottleneck(Bottleneck): + """ + ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe + implementation and uses `stride=stride` in `conv1` and not in `conv2` + (the latter is used in the torchvision implementation of ResNet). 
+ """ + expansion = 4 + + def __init__( + self, inplanes, planes, groups, reduction, stride=1, downsample=None + ): + super(SEResNetBottleneck, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=1, bias=False, stride=stride + ) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + padding=1, + groups=groups, + bias=False + ) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNeXtBottleneck(Bottleneck): + """ResNeXt bottleneck type C with a Squeeze-and-Excitation module""" + expansion = 4 + + def __init__( + self, + inplanes, + planes, + groups, + reduction, + stride=1, + downsample=None, + base_width=4 + ): + super(SEResNeXtBottleneck, self).__init__() + width = int(math.floor(planes * (base_width/64.)) * groups) + self.conv1 = nn.Conv2d( + inplanes, width, kernel_size=1, bias=False, stride=1 + ) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = nn.Conv2d( + width, + width, + kernel_size=3, + stride=stride, + padding=1, + groups=groups, + bias=False + ) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SENet(nn.Module): + """Squeeze-and-excitation network. + + Reference: + Hu et al. Squeeze-and-Excitation Networks. CVPR 2018. + + Public keys: + - ``senet154``: SENet154. + - ``se_resnet50``: ResNet50 + SE. + - ``se_resnet101``: ResNet101 + SE. + - ``se_resnet152``: ResNet152 + SE. + - ``se_resnext50_32x4d``: ResNeXt50 (groups=32, width=4) + SE. + - ``se_resnext101_32x4d``: ResNeXt101 (groups=32, width=4) + SE. + - ``se_resnet50_fc512``: (ResNet50 + SE) + FC. + """ + + def __init__( + self, + num_classes, + loss, + block, + layers, + groups, + reduction, + dropout_p=0.2, + inplanes=128, + input_3x3=True, + downsample_kernel_size=3, + downsample_padding=1, + last_stride=2, + fc_dims=None, + **kwargs + ): + """ + Parameters + ---------- + block (nn.Module): Bottleneck class. + - For SENet154: SEBottleneck + - For SE-ResNet models: SEResNetBottleneck + - For SE-ResNeXt models: SEResNeXtBottleneck + layers (list of ints): Number of residual blocks for 4 layers of the + network (layer1...layer4). + groups (int): Number of groups for the 3x3 convolution in each + bottleneck block. + - For SENet154: 64 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 32 + reduction (int): Reduction ratio for Squeeze-and-Excitation modules. + - For all models: 16 + dropout_p (float or None): Drop probability for the Dropout layer. + If `None` the Dropout layer is not used. + - For SENet154: 0.2 + - For SE-ResNet models: None + - For SE-ResNeXt models: None + inplanes (int): Number of input channels for layer1. + - For SENet154: 128 + - For SE-ResNet models: 64 + - For SE-ResNeXt models: 64 + input_3x3 (bool): If `True`, use three 3x3 convolutions instead of + a single 7x7 convolution in layer0. + - For SENet154: True + - For SE-ResNet models: False + - For SE-ResNeXt models: False + downsample_kernel_size (int): Kernel size for downsampling convolutions + in layer2, layer3 and layer4. 
+ - For SENet154: 3 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 1 + downsample_padding (int): Padding for downsampling convolutions in + layer2, layer3 and layer4. + - For SENet154: 1 + - For SE-ResNet models: 0 + - For SE-ResNeXt models: 0 + num_classes (int): Number of outputs in `classifier` layer. + """ + super(SENet, self).__init__() + self.inplanes = inplanes + self.loss = loss + + if input_3x3: + layer0_modules = [ + ( + 'conv1', + nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False) + ), + ('bn1', nn.BatchNorm2d(64)), + ('relu1', nn.ReLU(inplace=True)), + ( + 'conv2', + nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False) + ), + ('bn2', nn.BatchNorm2d(64)), + ('relu2', nn.ReLU(inplace=True)), + ( + 'conv3', + nn.Conv2d( + 64, inplanes, 3, stride=1, padding=1, bias=False + ) + ), + ('bn3', nn.BatchNorm2d(inplanes)), + ('relu3', nn.ReLU(inplace=True)), + ] + else: + layer0_modules = [ + ( + 'conv1', + nn.Conv2d( + 3, + inplanes, + kernel_size=7, + stride=2, + padding=3, + bias=False + ) + ), + ('bn1', nn.BatchNorm2d(inplanes)), + ('relu1', nn.ReLU(inplace=True)), + ] + # To preserve compatibility with Caffe weights `ceil_mode=True` + # is used instead of `padding=1`. + layer0_modules.append( + ('pool', nn.MaxPool2d(3, stride=2, ceil_mode=True)) + ) + self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) + self.layer1 = self._make_layer( + block, + planes=64, + blocks=layers[0], + groups=groups, + reduction=reduction, + downsample_kernel_size=1, + downsample_padding=0 + ) + self.layer2 = self._make_layer( + block, + planes=128, + blocks=layers[1], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.layer3 = self._make_layer( + block, + planes=256, + blocks=layers[2], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.layer4 = self._make_layer( + block, + planes=512, + blocks=layers[3], + stride=last_stride, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = self._construct_fc_layer( + fc_dims, 512 * block.expansion, dropout_p + ) + self.classifier = nn.Linear(self.feature_dim, num_classes) + + def _make_layer( + self, + block, + planes, + blocks, + groups, + reduction, + stride=1, + downsample_kernel_size=1, + downsample_padding=0 + ): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=downsample_kernel_size, + stride=stride, + padding=downsample_padding, + bias=False + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, planes, groups, reduction, stride, downsample + ) + ) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups, reduction)) + + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + """ + Construct fully connected layer + + - fc_dims (list or tuple): dimensions of fc layers, if None, + no fc layers are constructed + - input_dim (int): input dimension + - dropout_p (float): dropout probability, if None, dropout is unused + """ + if fc_dims is None: + self.feature_dim = input_dim + return None + + assert 
isinstance( + fc_dims, (list, tuple) + ), 'fc_dims must be either list or tuple, but got {}'.format( + type(fc_dims) + ) + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def featuremaps(self, x): + x = self.layer0(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def forward(self, x): + f = self.featuremaps(x) + v = self.global_avgpool(f) + v = v.view(v.size(0), -1) + + if self.fc is not None: + v = self.fc(v) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. + """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def senet154(num_classes, loss='softmax', pretrained=True, **kwargs): + model = SENet( + num_classes=num_classes, + loss=loss, + block=SEBottleneck, + layers=[3, 8, 36, 3], + groups=64, + reduction=16, + dropout_p=0.2, + last_stride=2, + fc_dims=None, + **kwargs + ) + if pretrained: + model_url = pretrained_settings['senet154']['imagenet']['url'] + init_pretrained_weights(model, model_url) + return model + + +def se_resnet50(num_classes, loss='softmax', pretrained=True, **kwargs): + model = SENet( + num_classes=num_classes, + loss=loss, + block=SEResNetBottleneck, + layers=[3, 4, 6, 3], + groups=1, + reduction=16, + dropout_p=None, + inplanes=64, + input_3x3=False, + downsample_kernel_size=1, + downsample_padding=0, + last_stride=2, + fc_dims=None, + **kwargs + ) + if pretrained: + model_url = pretrained_settings['se_resnet50']['imagenet']['url'] + init_pretrained_weights(model, model_url) + return model + + +def se_resnet50_fc512(num_classes, loss='softmax', pretrained=True, **kwargs): + model = SENet( + num_classes=num_classes, + loss=loss, + block=SEResNetBottleneck, + layers=[3, 4, 6, 3], + groups=1, + reduction=16, + dropout_p=None, + inplanes=64, + input_3x3=False, + downsample_kernel_size=1, + downsample_padding=0, + last_stride=1, + fc_dims=[512], + **kwargs + ) + if pretrained: + model_url = pretrained_settings['se_resnet50']['imagenet']['url'] + init_pretrained_weights(model, model_url) + return model + + +def se_resnet101(num_classes, loss='softmax', pretrained=True, **kwargs): + model = SENet( + num_classes=num_classes, + loss=loss, + block=SEResNetBottleneck, + layers=[3, 4, 23, 3], + groups=1, + reduction=16, + dropout_p=None, + inplanes=64, + input_3x3=False, + downsample_kernel_size=1, + downsample_padding=0, + last_stride=2, + fc_dims=None, + **kwargs + ) + if pretrained: + model_url = pretrained_settings['se_resnet101']['imagenet']['url'] + init_pretrained_weights(model, model_url) + return model + + +def se_resnet152(num_classes, loss='softmax', pretrained=True, **kwargs): + model = SENet( + num_classes=num_classes, + loss=loss, + block=SEResNetBottleneck, + layers=[3, 
8, 36, 3], + groups=1, + reduction=16, + dropout_p=None, + inplanes=64, + input_3x3=False, + downsample_kernel_size=1, + downsample_padding=0, + last_stride=2, + fc_dims=None, + **kwargs + ) + if pretrained: + model_url = pretrained_settings['se_resnet152']['imagenet']['url'] + init_pretrained_weights(model, model_url) + return model + + +def se_resnext50_32x4d(num_classes, loss='softmax', pretrained=True, **kwargs): + model = SENet( + num_classes=num_classes, + loss=loss, + block=SEResNeXtBottleneck, + layers=[3, 4, 6, 3], + groups=32, + reduction=16, + dropout_p=None, + inplanes=64, + input_3x3=False, + downsample_kernel_size=1, + downsample_padding=0, + last_stride=2, + fc_dims=None, + **kwargs + ) + if pretrained: + model_url = pretrained_settings['se_resnext50_32x4d']['imagenet']['url' + ] + init_pretrained_weights(model, model_url) + return model + + +def se_resnext101_32x4d( + num_classes, loss='softmax', pretrained=True, **kwargs +): + model = SENet( + num_classes=num_classes, + loss=loss, + block=SEResNeXtBottleneck, + layers=[3, 4, 23, 3], + groups=32, + reduction=16, + dropout_p=None, + inplanes=64, + input_3x3=False, + downsample_kernel_size=1, + downsample_padding=0, + last_stride=2, + fc_dims=None, + **kwargs + ) + if pretrained: + model_url = pretrained_settings['se_resnext101_32x4d']['imagenet'][ + 'url'] + init_pretrained_weights(model, model_url) + return model diff --git a/feeder/trackers/strongsort/deep/models/shufflenet.py b/feeder/trackers/strongsort/deep/models/shufflenet.py new file mode 100644 index 0000000..bc4d34f --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/shufflenet.py @@ -0,0 +1,198 @@ +from __future__ import division, absolute_import +import torch +import torch.utils.model_zoo as model_zoo +from torch import nn +from torch.nn import functional as F + +__all__ = ['shufflenet'] + +model_urls = { + # training epoch = 90, top1 = 61.8 + 'imagenet': + 'https://mega.nz/#!RDpUlQCY!tr_5xBEkelzDjveIYBBcGcovNCOrgfiJO9kiidz9fZM', +} + + +class ChannelShuffle(nn.Module): + + def __init__(self, num_groups): + super(ChannelShuffle, self).__init__() + self.g = num_groups + + def forward(self, x): + b, c, h, w = x.size() + n = c // self.g + # reshape + x = x.view(b, self.g, n, h, w) + # transpose + x = x.permute(0, 2, 1, 3, 4).contiguous() + # flatten + x = x.view(b, c, h, w) + return x + + +class Bottleneck(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + stride, + num_groups, + group_conv1x1=True + ): + super(Bottleneck, self).__init__() + assert stride in [1, 2], 'Warning: stride must be either 1 or 2' + self.stride = stride + mid_channels = out_channels // 4 + if stride == 2: + out_channels -= in_channels + # group conv is not applied to first conv1x1 at stage 2 + num_groups_conv1x1 = num_groups if group_conv1x1 else 1 + self.conv1 = nn.Conv2d( + in_channels, + mid_channels, + 1, + groups=num_groups_conv1x1, + bias=False + ) + self.bn1 = nn.BatchNorm2d(mid_channels) + self.shuffle1 = ChannelShuffle(num_groups) + self.conv2 = nn.Conv2d( + mid_channels, + mid_channels, + 3, + stride=stride, + padding=1, + groups=mid_channels, + bias=False + ) + self.bn2 = nn.BatchNorm2d(mid_channels) + self.conv3 = nn.Conv2d( + mid_channels, out_channels, 1, groups=num_groups, bias=False + ) + self.bn3 = nn.BatchNorm2d(out_channels) + if stride == 2: + self.shortcut = nn.AvgPool2d(3, stride=2, padding=1) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.shuffle1(out) + out = self.bn2(self.conv2(out)) + out = 
self.bn3(self.conv3(out)) + if self.stride == 2: + res = self.shortcut(x) + out = F.relu(torch.cat([res, out], 1)) + else: + out = F.relu(x + out) + return out + + +# configuration of (num_groups: #out_channels) based on Table 1 in the paper +cfg = { + 1: [144, 288, 576], + 2: [200, 400, 800], + 3: [240, 480, 960], + 4: [272, 544, 1088], + 8: [384, 768, 1536], +} + + +class ShuffleNet(nn.Module): + """ShuffleNet. + + Reference: + Zhang et al. ShuffleNet: An Extremely Efficient Convolutional Neural + Network for Mobile Devices. CVPR 2018. + + Public keys: + - ``shufflenet``: ShuffleNet (groups=3). + """ + + def __init__(self, num_classes, loss='softmax', num_groups=3, **kwargs): + super(ShuffleNet, self).__init__() + self.loss = loss + + self.conv1 = nn.Sequential( + nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(24), + nn.ReLU(), + nn.MaxPool2d(3, stride=2, padding=1), + ) + + self.stage2 = nn.Sequential( + Bottleneck( + 24, cfg[num_groups][0], 2, num_groups, group_conv1x1=False + ), + Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups), + Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups), + Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups), + ) + + self.stage3 = nn.Sequential( + Bottleneck(cfg[num_groups][0], cfg[num_groups][1], 2, num_groups), + Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), + Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), + Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), + Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), + Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), + Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), + Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), + ) + + self.stage4 = nn.Sequential( + Bottleneck(cfg[num_groups][1], cfg[num_groups][2], 2, num_groups), + Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups), + Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups), + Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups), + ) + + self.classifier = nn.Linear(cfg[num_groups][2], num_classes) + self.feat_dim = cfg[num_groups][2] + + def forward(self, x): + x = self.conv1(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.stage4(x) + x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1) + + if not self.training: + return x + + y = self.classifier(x) + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, x + else: + raise KeyError('Unsupported loss: {}'.format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. 
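+
+    Note (added, illustrative): the ImageNet weights for this model cannot be
+    fetched through ``model_zoo`` (see the warning in ``shufflenet()`` below),
+    so a manually downloaded checkpoint has to be loaded with the same
+    name/size filtering. The file name here is hypothetical::
+
+        import torch
+
+        model = shufflenet(num_classes=751, pretrained=False)
+        checkpoint = torch.load('shufflenet_imagenet.pth', map_location='cpu')
+        state = checkpoint.get('state_dict', checkpoint)  # some checkpoints nest the weights
+        model_dict = model.state_dict()
+        state = {
+            k: v
+            for k, v in state.items()
+            if k in model_dict and model_dict[k].size() == v.size()
+        }
+        model_dict.update(state)
+        model.load_state_dict(model_dict)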
+ """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def shufflenet(num_classes, loss='softmax', pretrained=True, **kwargs): + model = ShuffleNet(num_classes, loss, **kwargs) + if pretrained: + # init_pretrained_weights(model, model_urls['imagenet']) + import warnings + warnings.warn( + 'The imagenet pretrained weights need to be manually downloaded from {}' + .format(model_urls['imagenet']) + ) + return model diff --git a/feeder/trackers/strongsort/deep/models/shufflenetv2.py b/feeder/trackers/strongsort/deep/models/shufflenetv2.py new file mode 100644 index 0000000..3ff879e --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/shufflenetv2.py @@ -0,0 +1,262 @@ +""" +Code source: https://github.com/pytorch/vision +""" +from __future__ import division, absolute_import +import torch +import torch.utils.model_zoo as model_zoo +from torch import nn + +__all__ = [ + 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', + 'shufflenet_v2_x2_0' +] + +model_urls = { + 'shufflenetv2_x0.5': + 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', + 'shufflenetv2_x1.0': + 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', + 'shufflenetv2_x1.5': None, + 'shufflenetv2_x2.0': None, +} + + +def channel_shuffle(x, groups): + batchsize, num_channels, height, width = x.data.size() + channels_per_group = num_channels // groups + + # reshape + x = x.view(batchsize, groups, channels_per_group, height, width) + + x = torch.transpose(x, 1, 2).contiguous() + + # flatten + x = x.view(batchsize, -1, height, width) + + return x + + +class InvertedResidual(nn.Module): + + def __init__(self, inp, oup, stride): + super(InvertedResidual, self).__init__() + + if not (1 <= stride <= 3): + raise ValueError('illegal stride value') + self.stride = stride + + branch_features = oup // 2 + assert (self.stride != 1) or (inp == branch_features << 1) + + if self.stride > 1: + self.branch1 = nn.Sequential( + self.depthwise_conv( + inp, inp, kernel_size=3, stride=self.stride, padding=1 + ), + nn.BatchNorm2d(inp), + nn.Conv2d( + inp, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False + ), + nn.BatchNorm2d(branch_features), + nn.ReLU(inplace=True), + ) + + self.branch2 = nn.Sequential( + nn.Conv2d( + inp if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False + ), + nn.BatchNorm2d(branch_features), + nn.ReLU(inplace=True), + self.depthwise_conv( + branch_features, + branch_features, + kernel_size=3, + stride=self.stride, + padding=1 + ), + nn.BatchNorm2d(branch_features), + nn.Conv2d( + branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False + ), + nn.BatchNorm2d(branch_features), + nn.ReLU(inplace=True), + ) + + @staticmethod + def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): + return nn.Conv2d( + i, o, kernel_size, stride, padding, bias=bias, groups=i + ) + + def forward(self, x): + if self.stride == 1: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + else: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + + out = channel_shuffle(out, 2) + + return out + + +class ShuffleNetV2(nn.Module): + """ShuffleNetV2. + + Reference: + Ma et al. 
ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design. ECCV 2018. + + Public keys: + - ``shufflenet_v2_x0_5``: ShuffleNetV2 x0.5. + - ``shufflenet_v2_x1_0``: ShuffleNetV2 x1.0. + - ``shufflenet_v2_x1_5``: ShuffleNetV2 x1.5. + - ``shufflenet_v2_x2_0``: ShuffleNetV2 x2.0. + """ + + def __init__( + self, num_classes, loss, stages_repeats, stages_out_channels, **kwargs + ): + super(ShuffleNetV2, self).__init__() + self.loss = loss + + if len(stages_repeats) != 3: + raise ValueError( + 'expected stages_repeats as list of 3 positive ints' + ) + if len(stages_out_channels) != 5: + raise ValueError( + 'expected stages_out_channels as list of 5 positive ints' + ) + self._stage_out_channels = stages_out_channels + + input_channels = 3 + output_channels = self._stage_out_channels[0] + self.conv1 = nn.Sequential( + nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), + nn.BatchNorm2d(output_channels), + nn.ReLU(inplace=True), + ) + input_channels = output_channels + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] + for name, repeats, output_channels in zip( + stage_names, stages_repeats, self._stage_out_channels[1:] + ): + seq = [InvertedResidual(input_channels, output_channels, 2)] + for i in range(repeats - 1): + seq.append( + InvertedResidual(output_channels, output_channels, 1) + ) + setattr(self, name, nn.Sequential(*seq)) + input_channels = output_channels + + output_channels = self._stage_out_channels[-1] + self.conv5 = nn.Sequential( + nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(output_channels), + nn.ReLU(inplace=True), + ) + self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + self.classifier = nn.Linear(output_channels, num_classes) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.maxpool(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.stage4(x) + x = self.conv5(x) + return x + + def forward(self, x): + f = self.featuremaps(x) + v = self.global_avgpool(f) + v = v.view(v.size(0), -1) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. 
+ """ + if model_url is None: + import warnings + warnings.warn( + 'ImageNet pretrained weights are unavailable for this model' + ) + return + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def shufflenet_v2_x0_5(num_classes, loss='softmax', pretrained=True, **kwargs): + model = ShuffleNetV2( + num_classes, loss, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['shufflenetv2_x0.5']) + return model + + +def shufflenet_v2_x1_0(num_classes, loss='softmax', pretrained=True, **kwargs): + model = ShuffleNetV2( + num_classes, loss, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['shufflenetv2_x1.0']) + return model + + +def shufflenet_v2_x1_5(num_classes, loss='softmax', pretrained=True, **kwargs): + model = ShuffleNetV2( + num_classes, loss, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['shufflenetv2_x1.5']) + return model + + +def shufflenet_v2_x2_0(num_classes, loss='softmax', pretrained=True, **kwargs): + model = ShuffleNetV2( + num_classes, loss, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['shufflenetv2_x2.0']) + return model diff --git a/feeder/trackers/strongsort/deep/models/squeezenet.py b/feeder/trackers/strongsort/deep/models/squeezenet.py new file mode 100644 index 0000000..83e8dc9 --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/squeezenet.py @@ -0,0 +1,236 @@ +""" +Code source: https://github.com/pytorch/vision +""" +from __future__ import division, absolute_import +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + +__all__ = ['squeezenet1_0', 'squeezenet1_1', 'squeezenet1_0_fc512'] + +model_urls = { + 'squeezenet1_0': + 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth', + 'squeezenet1_1': + 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth', +} + + +class Fire(nn.Module): + + def __init__( + self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes + ): + super(Fire, self).__init__() + self.inplanes = inplanes + self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) + self.squeeze_activation = nn.ReLU(inplace=True) + self.expand1x1 = nn.Conv2d( + squeeze_planes, expand1x1_planes, kernel_size=1 + ) + self.expand1x1_activation = nn.ReLU(inplace=True) + self.expand3x3 = nn.Conv2d( + squeeze_planes, expand3x3_planes, kernel_size=3, padding=1 + ) + self.expand3x3_activation = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.squeeze_activation(self.squeeze(x)) + return torch.cat( + [ + self.expand1x1_activation(self.expand1x1(x)), + self.expand3x3_activation(self.expand3x3(x)) + ], 1 + ) + + +class SqueezeNet(nn.Module): + """SqueezeNet. + + Reference: + Iandola et al. SqueezeNet: AlexNet-level accuracy with 50x fewer parameters + and< 0.5 MB model size. arXiv:1602.07360. + + Public keys: + - ``squeezenet1_0``: SqueezeNet (version=1.0). + - ``squeezenet1_1``: SqueezeNet (version=1.1). + - ``squeezenet1_0_fc512``: SqueezeNet (version=1.0) + FC. 
+ """ + + def __init__( + self, + num_classes, + loss, + version=1.0, + fc_dims=None, + dropout_p=None, + **kwargs + ): + super(SqueezeNet, self).__init__() + self.loss = loss + self.feature_dim = 512 + + if version not in [1.0, 1.1]: + raise ValueError( + 'Unsupported SqueezeNet version {version}:' + '1.0 or 1.1 expected'.format(version=version) + ) + + if version == 1.0: + self.features = nn.Sequential( + nn.Conv2d(3, 96, kernel_size=7, stride=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(96, 16, 64, 64), + Fire(128, 16, 64, 64), + Fire(128, 32, 128, 128), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(256, 32, 128, 128), + Fire(256, 48, 192, 192), + Fire(384, 48, 192, 192), + Fire(384, 64, 256, 256), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(512, 64, 256, 256), + ) + else: + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, stride=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(64, 16, 64, 64), + Fire(128, 16, 64, 64), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(128, 32, 128, 128), + Fire(256, 32, 128, 128), + nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + Fire(256, 48, 192, 192), + Fire(384, 48, 192, 192), + Fire(384, 64, 256, 256), + Fire(512, 64, 256, 256), + ) + + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = self._construct_fc_layer(fc_dims, 512, dropout_p) + self.classifier = nn.Linear(self.feature_dim, num_classes) + + self._init_params() + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + """Constructs fully connected layer + + Args: + fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed + input_dim (int): input dimension + dropout_p (float): dropout probability, if None, dropout is unused + """ + if fc_dims is None: + self.feature_dim = input_dim + return None + + assert isinstance( + fc_dims, (list, tuple) + ), 'fc_dims must be either list or tuple, but got {}'.format( + type(fc_dims) + ) + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + f = self.features(x) + v = self.global_avgpool(f) + v = v.view(v.size(0), -1) + + if self.fc is not None: + v = self.fc(v) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError('Unsupported loss: {}'.format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. 
+ """ + pretrain_dict = model_zoo.load_url(model_url, map_location=None) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def squeezenet1_0(num_classes, loss='softmax', pretrained=True, **kwargs): + model = SqueezeNet( + num_classes, loss, version=1.0, fc_dims=None, dropout_p=None, **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['squeezenet1_0']) + return model + + +def squeezenet1_0_fc512( + num_classes, loss='softmax', pretrained=True, **kwargs +): + model = SqueezeNet( + num_classes, + loss, + version=1.0, + fc_dims=[512], + dropout_p=None, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['squeezenet1_0']) + return model + + +def squeezenet1_1(num_classes, loss='softmax', pretrained=True, **kwargs): + model = SqueezeNet( + num_classes, loss, version=1.1, fc_dims=None, dropout_p=None, **kwargs + ) + if pretrained: + init_pretrained_weights(model, model_urls['squeezenet1_1']) + return model diff --git a/feeder/trackers/strongsort/deep/models/xception.py b/feeder/trackers/strongsort/deep/models/xception.py new file mode 100644 index 0000000..43db4ab --- /dev/null +++ b/feeder/trackers/strongsort/deep/models/xception.py @@ -0,0 +1,344 @@ +from __future__ import division, absolute_import +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo + +__all__ = ['xception'] + +pretrained_settings = { + 'xception': { + 'imagenet': { + 'url': + 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth', + 'input_space': 'RGB', + 'input_size': [3, 299, 299], + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1000, + 'scale': + 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 + } + } +} + + +class SeparableConv2d(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=False + ): + super(SeparableConv2d, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels, + in_channels, + kernel_size, + stride, + padding, + dilation, + groups=in_channels, + bias=bias + ) + self.pointwise = nn.Conv2d( + in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias + ) + + def forward(self, x): + x = self.conv1(x) + x = self.pointwise(x) + return x + + +class Block(nn.Module): + + def __init__( + self, + in_filters, + out_filters, + reps, + strides=1, + start_with_relu=True, + grow_first=True + ): + super(Block, self).__init__() + + if out_filters != in_filters or strides != 1: + self.skip = nn.Conv2d( + in_filters, out_filters, 1, stride=strides, bias=False + ) + self.skipbn = nn.BatchNorm2d(out_filters) + else: + self.skip = None + + self.relu = nn.ReLU(inplace=True) + rep = [] + + filters = in_filters + if grow_first: + rep.append(self.relu) + rep.append( + SeparableConv2d( + in_filters, + out_filters, + 3, + stride=1, + padding=1, + bias=False + ) + ) + rep.append(nn.BatchNorm2d(out_filters)) + filters = out_filters + + for i in range(reps - 1): + rep.append(self.relu) + rep.append( + SeparableConv2d( + filters, filters, 3, stride=1, padding=1, bias=False + ) + ) + rep.append(nn.BatchNorm2d(filters)) + + if not grow_first: + rep.append(self.relu) + rep.append( + SeparableConv2d( + in_filters, + out_filters, + 3, + stride=1, + padding=1, + bias=False + 
) + ) + rep.append(nn.BatchNorm2d(out_filters)) + + if not start_with_relu: + rep = rep[1:] + else: + rep[0] = nn.ReLU(inplace=False) + + if strides != 1: + rep.append(nn.MaxPool2d(3, strides, 1)) + self.rep = nn.Sequential(*rep) + + def forward(self, inp): + x = self.rep(inp) + + if self.skip is not None: + skip = self.skip(inp) + skip = self.skipbn(skip) + else: + skip = inp + + x += skip + return x + + +class Xception(nn.Module): + """Xception. + + Reference: + Chollet. Xception: Deep Learning with Depthwise + Separable Convolutions. CVPR 2017. + + Public keys: + - ``xception``: Xception. + """ + + def __init__( + self, num_classes, loss, fc_dims=None, dropout_p=None, **kwargs + ): + super(Xception, self).__init__() + self.loss = loss + + self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False) + self.bn1 = nn.BatchNorm2d(32) + + self.conv2 = nn.Conv2d(32, 64, 3, bias=False) + self.bn2 = nn.BatchNorm2d(64) + + self.block1 = Block( + 64, 128, 2, 2, start_with_relu=False, grow_first=True + ) + self.block2 = Block( + 128, 256, 2, 2, start_with_relu=True, grow_first=True + ) + self.block3 = Block( + 256, 728, 2, 2, start_with_relu=True, grow_first=True + ) + + self.block4 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True + ) + self.block5 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True + ) + self.block6 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True + ) + self.block7 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True + ) + + self.block8 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True + ) + self.block9 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True + ) + self.block10 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True + ) + self.block11 = Block( + 728, 728, 3, 1, start_with_relu=True, grow_first=True + ) + + self.block12 = Block( + 728, 1024, 2, 2, start_with_relu=True, grow_first=False + ) + + self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) + self.bn3 = nn.BatchNorm2d(1536) + + self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) + self.bn4 = nn.BatchNorm2d(2048) + + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.feature_dim = 2048 + self.fc = self._construct_fc_layer(fc_dims, 2048, dropout_p) + self.classifier = nn.Linear(self.feature_dim, num_classes) + + self._init_params() + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + """Constructs fully connected layer. 
+ + Args: + fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed + input_dim (int): input dimension + dropout_p (float): dropout probability, if None, dropout is unused + """ + if fc_dims is None: + self.feature_dim = input_dim + return None + + assert isinstance( + fc_dims, (list, tuple) + ), 'fc_dims must be either list or tuple, but got {}'.format( + type(fc_dims) + ) + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def featuremaps(self, input): + x = self.conv1(input) + x = self.bn1(x) + x = F.relu(x, inplace=True) + + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x, inplace=True) + + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + x = self.block12(x) + + x = self.conv3(x) + x = self.bn3(x) + x = F.relu(x, inplace=True) + + x = self.conv4(x) + x = self.bn4(x) + x = F.relu(x, inplace=True) + return x + + def forward(self, x): + f = self.featuremaps(x) + v = self.global_avgpool(f) + v = v.view(v.size(0), -1) + + if self.fc is not None: + v = self.fc(v) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == 'softmax': + return y + elif self.loss == 'triplet': + return y, v + else: + raise KeyError('Unsupported loss: {}'.format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initialize models with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. 
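+
+    Note (added, illustrative): ``pretrained_settings`` above also records the
+    preprocessing these ImageNet weights expect (resize the short side to
+    roughly 299 / 0.8975 = 333, then center-crop to 299x299). A minimal sketch
+    with torchvision, which this module does not itself import::
+
+        import torchvision.transforms as T
+
+        cfg = pretrained_settings['xception']['imagenet']
+        transform = T.Compose([
+            T.Resize(int(cfg['input_size'][1] / cfg['scale'])),  # ~333
+            T.CenterCrop(cfg['input_size'][1]),                  # 299
+            T.ToTensor(),
+            T.Normalize(mean=cfg['mean'], std=cfg['std']),
+        ])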
+ """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def xception(num_classes, loss='softmax', pretrained=True, **kwargs): + model = Xception(num_classes, loss, fc_dims=None, dropout_p=None, **kwargs) + if pretrained: + model_url = pretrained_settings['xception']['imagenet']['url'] + init_pretrained_weights(model, model_url) + return model diff --git a/feeder/trackers/strongsort/deep/reid_model_factory.py b/feeder/trackers/strongsort/deep/reid_model_factory.py new file mode 100644 index 0000000..ed0542d --- /dev/null +++ b/feeder/trackers/strongsort/deep/reid_model_factory.py @@ -0,0 +1,215 @@ +import torch +from collections import OrderedDict + + + +__model_types = [ + 'resnet50', 'mlfn', 'hacnn', 'mobilenetv2_x1_0', 'mobilenetv2_x1_4', + 'osnet_x1_0', 'osnet_x0_75', 'osnet_x0_5', 'osnet_x0_25', + 'osnet_ibn_x1_0', 'osnet_ain_x1_0'] + +__trained_urls = { + + # market1501 models ######################################################## + 'resnet50_market1501.pt': + 'https://drive.google.com/uc?id=1dUUZ4rHDWohmsQXCRe2C_HbYkzz94iBV', + 'resnet50_dukemtmcreid.pt': + 'https://drive.google.com/uc?id=17ymnLglnc64NRvGOitY3BqMRS9UWd1wg', + 'resnet50_msmt17.pt': + 'https://drive.google.com/uc?id=1ep7RypVDOthCRIAqDnn4_N-UhkkFHJsj', + + 'resnet50_fc512_market1501.pt': + 'https://drive.google.com/uc?id=1kv8l5laX_YCdIGVCetjlNdzKIA3NvsSt', + 'resnet50_fc512_dukemtmcreid.pt': + 'https://drive.google.com/uc?id=13QN8Mp3XH81GK4BPGXobKHKyTGH50Rtx', + 'resnet50_fc512_msmt17.pt': + 'https://drive.google.com/uc?id=1fDJLcz4O5wxNSUvImIIjoaIF9u1Rwaud', + + 'mlfn_market1501.pt': + 'https://drive.google.com/uc?id=1wXcvhA_b1kpDfrt9s2Pma-MHxtj9pmvS', + 'mlfn_dukemtmcreid.pt': + 'https://drive.google.com/uc?id=1rExgrTNb0VCIcOnXfMsbwSUW1h2L1Bum', + 'mlfn_msmt17.pt': + 'https://drive.google.com/uc?id=18JzsZlJb3Wm7irCbZbZ07TN4IFKvR6p-', + + 'hacnn_market1501.pt': + 'https://drive.google.com/uc?id=1LRKIQduThwGxMDQMiVkTScBwR7WidmYF', + 'hacnn_dukemtmcreid.pt': + 'https://drive.google.com/uc?id=1zNm6tP4ozFUCUQ7Sv1Z98EAJWXJEhtYH', + 'hacnn_msmt17.pt': + 'https://drive.google.com/uc?id=1MsKRtPM5WJ3_Tk2xC0aGOO7pM3VaFDNZ', + + 'mobilenetv2_x1_0_market1501.pt': + 'https://drive.google.com/uc?id=18DgHC2ZJkjekVoqBWszD8_Xiikz-fewp', + 'mobilenetv2_x1_0_dukemtmcreid.pt': + 'https://drive.google.com/uc?id=1q1WU2FETRJ3BXcpVtfJUuqq4z3psetds', + 'mobilenetv2_x1_0_msmt17.pt': + 'https://drive.google.com/uc?id=1j50Hv14NOUAg7ZeB3frzfX-WYLi7SrhZ', + + 'mobilenetv2_x1_4_market1501.pt': + 'https://drive.google.com/uc?id=1t6JCqphJG-fwwPVkRLmGGyEBhGOf2GO5', + 'mobilenetv2_x1_4_dukemtmcreid.pt': + 'https://drive.google.com/uc?id=12uD5FeVqLg9-AFDju2L7SQxjmPb4zpBN', + 'mobilenetv2_x1_4_msmt17.pt': + 'https://drive.google.com/uc?id=1ZY5P2Zgm-3RbDpbXM0kIBMPvspeNIbXz', + + 'osnet_x1_0_market1501.pt': + 'https://drive.google.com/uc?id=1vduhq5DpN2q1g4fYEZfPI17MJeh9qyrA', + 'osnet_x1_0_dukemtmcreid.pt': + 'https://drive.google.com/uc?id=1QZO_4sNf4hdOKKKzKc-TZU9WW1v6zQbq', + 'osnet_x1_0_msmt17.pt': + 'https://drive.google.com/uc?id=112EMUfBPYeYg70w-syK6V6Mx8-Qb9Q1M', + + 'osnet_x0_75_market1501.pt': + 'https://drive.google.com/uc?id=1ozRaDSQw_EQ8_93OUmjDbvLXw9TnfPer', + 'osnet_x0_75_dukemtmcreid.pt': + 'https://drive.google.com/uc?id=1IE3KRaTPp4OUa6PGTFL_d5_KQSJbP0Or', + 'osnet_x0_75_msmt17.pt': + 
'https://drive.google.com/uc?id=1QEGO6WnJ-BmUzVPd3q9NoaO_GsPNlmWc', + + 'osnet_x0_5_market1501.pt': + 'https://drive.google.com/uc?id=1PLB9rgqrUM7blWrg4QlprCuPT7ILYGKT', + 'osnet_x0_5_dukemtmcreid.pt': + 'https://drive.google.com/uc?id=1KoUVqmiST175hnkALg9XuTi1oYpqcyTu', + 'osnet_x0_5_msmt17.pt': + 'https://drive.google.com/uc?id=1UT3AxIaDvS2PdxzZmbkLmjtiqq7AIKCv', + + 'osnet_x0_25_market1501.pt': + 'https://drive.google.com/uc?id=1z1UghYvOTtjx7kEoRfmqSMu-z62J6MAj', + 'osnet_x0_25_dukemtmcreid.pt': + 'https://drive.google.com/uc?id=1eumrtiXT4NOspjyEV4j8cHmlOaaCGk5l', + 'osnet_x0_25_msmt17.pt': + 'https://drive.google.com/uc?id=1sSwXSUlj4_tHZequ_iZ8w_Jh0VaRQMqF', + + ####### market1501 models ################################################## + 'resnet50_msmt17.pt': + 'https://drive.google.com/uc?id=1yiBteqgIZoOeywE8AhGmEQl7FTVwrQmf', + 'osnet_x1_0_msmt17.pt': + 'https://drive.google.com/uc?id=1IosIFlLiulGIjwW3H8uMRmx3MzPwf86x', + 'osnet_x0_75_msmt17.pt': + 'https://drive.google.com/uc?id=1fhjSS_7SUGCioIf2SWXaRGPqIY9j7-uw', + + 'osnet_x0_5_msmt17.pt': + 'https://drive.google.com/uc?id=1DHgmb6XV4fwG3n-CnCM0zdL9nMsZ9_RF', + 'osnet_x0_25_msmt17.pt': + 'https://drive.google.com/uc?id=1Kkx2zW89jq_NETu4u42CFZTMVD5Hwm6e', + 'osnet_ibn_x1_0_msmt17.pt': + 'https://drive.google.com/uc?id=1q3Sj2ii34NlfxA4LvmHdWO_75NDRmECJ', + 'osnet_ain_x1_0_msmt17.pt': + 'https://drive.google.com/uc?id=1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal', +} + + +def show_downloadeable_models(): + print('\nAvailable .pt ReID models for automatic download') + print(list(__trained_urls.keys())) + + +def get_model_url(model): + if model.name in __trained_urls: + return __trained_urls[model.name] + else: + None + + +def is_model_in_model_types(model): + if model.name in __model_types: + return True + else: + return False + + +def get_model_name(model): + for x in __model_types: + if x in model.name: + return x + return None + + +def download_url(url, dst): + """Downloads file from a url to a destination. + + Args: + url (str): url to download file. + dst (str): destination path. + """ + from six.moves import urllib + print('* url="{}"'.format(url)) + print('* destination="{}"'.format(dst)) + + def _reporthook(count, block_size, total_size): + global start_time + if count == 0: + start_time = time.time() + return + duration = time.time() - start_time + progress_size = int(count * block_size) + speed = int(progress_size / (1024*duration)) + percent = int(count * block_size * 100 / total_size) + sys.stdout.write( + '\r...%d%%, %d MB, %d KB/s, %d seconds passed' % + (percent, progress_size / (1024*1024), speed, duration) + ) + sys.stdout.flush() + + urllib.request.urlretrieve(url, dst, _reporthook) + sys.stdout.write('\n') + + +def load_pretrained_weights(model, weight_path): + r"""Loads pretrianed weights to model. + + Features:: + - Incompatible layers (unmatched in name or size) will be ignored. + - Can automatically deal with keys containing "module.". + + Args: + model (nn.Module): network model. + weight_path (str): path to pretrained weights. 
+ + Examples:: + >>> from torchreid.utils import load_pretrained_weights + >>> weight_path = 'log/my_model/model-best.pth.tar' + >>> load_pretrained_weights(model, weight_path) + """ + checkpoint = torch.load(weight_path) + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + model_dict = model.state_dict() + new_state_dict = OrderedDict() + matched_layers, discarded_layers = [], [] + + for k, v in state_dict.items(): + if k.startswith('module.'): + k = k[7:] # discard module. + + if k in model_dict and model_dict[k].size() == v.size(): + new_state_dict[k] = v + matched_layers.append(k) + else: + discarded_layers.append(k) + + model_dict.update(new_state_dict) + model.load_state_dict(model_dict) + + if len(matched_layers) == 0: + warnings.warn( + 'The pretrained weights "{}" cannot be loaded, ' + 'please check the key names manually ' + '(** ignored and continue **)'.format(weight_path) + ) + else: + print( + 'Successfully loaded pretrained weights from "{}"'. + format(weight_path) + ) + if len(discarded_layers) > 0: + print( + '** The following layers are discarded ' + 'due to unmatched keys or layer size: {}'. + format(discarded_layers) + ) + diff --git a/feeder/trackers/strongsort/reid_multibackend.py b/feeder/trackers/strongsort/reid_multibackend.py new file mode 100644 index 0000000..58d2fbb --- /dev/null +++ b/feeder/trackers/strongsort/reid_multibackend.py @@ -0,0 +1,237 @@ +import torch.nn as nn +import torch +from pathlib import Path +import numpy as np +from itertools import islice +import torchvision.transforms as transforms +import cv2 +import sys +import torchvision.transforms as T +from collections import OrderedDict, namedtuple +import gdown +from os.path import exists as file_exists + + +from ultralytics.yolo.utils.checks import check_requirements, check_version +from ultralytics.yolo.utils import LOGGER +from trackers.strongsort.deep.reid_model_factory import (show_downloadeable_models, get_model_url, get_model_name, + download_url, load_pretrained_weights) +from trackers.strongsort.deep.models import build_model + + +def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''): + # Check file(s) for acceptable suffix + if file and suffix: + if isinstance(suffix, str): + suffix = [suffix] + for f in file if isinstance(file, (list, tuple)) else [file]: + s = Path(f).suffix.lower() # file suffix + if len(s): + assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}" + + +class ReIDDetectMultiBackend(nn.Module): + # ReID models MultiBackend class for python inference on various backends + def __init__(self, weights='osnet_x0_25_msmt17.pt', device=torch.device('cpu'), fp16=False): + super().__init__() + + w = weights[0] if isinstance(weights, list) else weights + self.pt, self.jit, self.onnx, self.xml, self.engine, self.tflite = self.model_type(w) # get backend + self.fp16 = fp16 + self.fp16 &= self.pt or self.jit or self.engine # FP16 + + # Build transform functions + self.device = device + self.image_size=(256, 128) + self.pixel_mean=[0.485, 0.456, 0.406] + self.pixel_std=[0.229, 0.224, 0.225] + self.transforms = [] + self.transforms += [T.Resize(self.image_size)] + self.transforms += [T.ToTensor()] + self.transforms += [T.Normalize(mean=self.pixel_mean, std=self.pixel_std)] + self.preprocess = T.Compose(self.transforms) + self.to_pil = T.ToPILImage() + + model_name = get_model_name(w) + + if w.suffix == '.pt': + model_url = get_model_url(w) + if not file_exists(w) and model_url is not None: + 
gdown.download(model_url, str(w), quiet=False) + elif file_exists(w): + pass + else: + print(f'No URL associated to the chosen StrongSORT weights ({w}). Choose between:') + show_downloadeable_models() + exit() + + # Build model + self.model = build_model( + model_name, + num_classes=1, + pretrained=not (w and w.is_file()), + use_gpu=device + ) + + if self.pt: # PyTorch + # populate model arch with weights + if w and w.is_file() and w.suffix == '.pt': + load_pretrained_weights(self.model, w) + + self.model.to(device).eval() + self.model.half() if self.fp16 else self.model.float() + elif self.jit: + LOGGER.info(f'Loading {w} for TorchScript inference...') + self.model = torch.jit.load(w) + self.model.half() if self.fp16 else self.model.float() + elif self.onnx: # ONNX Runtime + LOGGER.info(f'Loading {w} for ONNX Runtime inference...') + cuda = torch.cuda.is_available() and device.type != 'cpu' + #check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime')) + import onnxruntime + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider'] + self.session = onnxruntime.InferenceSession(str(w), providers=providers) + elif self.engine: # TensorRT + LOGGER.info(f'Loading {w} for TensorRT inference...') + import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download + check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0 + if device.type == 'cpu': + device = torch.device('cuda:0') + Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) + logger = trt.Logger(trt.Logger.INFO) + with open(w, 'rb') as f, trt.Runtime(logger) as runtime: + self.model_ = runtime.deserialize_cuda_engine(f.read()) + self.context = self.model_.create_execution_context() + self.bindings = OrderedDict() + self.fp16 = False # default updated below + dynamic = False + for index in range(self.model_.num_bindings): + name = self.model_.get_binding_name(index) + dtype = trt.nptype(self.model_.get_binding_dtype(index)) + if self.model_.binding_is_input(index): + if -1 in tuple(self.model_.get_binding_shape(index)): # dynamic + dynamic = True + self.context.set_binding_shape(index, tuple(self.model_.get_profile_shape(0, index)[2])) + if dtype == np.float16: + self.fp16 = True + shape = tuple(self.context.get_binding_shape(index)) + im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) + self.bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) + self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items()) + batch_size = self.bindings['images'].shape[0] # if dynamic, this is instead max batch size + elif self.xml: # OpenVINO + LOGGER.info(f'Loading {w} for OpenVINO inference...') + check_requirements(('openvino',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/ + from openvino.runtime import Core, Layout, get_batch + ie = Core() + if not Path(w).is_file(): # if not *.xml + w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir + network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin')) + if network.get_parameters()[0].get_layout().empty: + network.get_parameters()[0].set_layout(Layout("NCWH")) + batch_dim = get_batch(network) + if batch_dim.is_static: + batch_size = batch_dim.get_length() + self.executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2 + self.output_layer = next(iter(self.executable_network.outputs)) + + elif self.tflite: + LOGGER.info(f'Loading {w} for TensorFlow Lite 
inference...') + try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu + from tflite_runtime.interpreter import Interpreter, load_delegate + except ImportError: + import tensorflow as tf + Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate, + self.interpreter = tf.lite.Interpreter(model_path=w) + self.interpreter.allocate_tensors() + # Get input and output tensors. + self.input_details = self.interpreter.get_input_details() + self.output_details = self.interpreter.get_output_details() + + # Test model on random input data. + input_data = np.array(np.random.random_sample((1,256,128,3)), dtype=np.float32) + self.interpreter.set_tensor(self.input_details[0]['index'], input_data) + + self.interpreter.invoke() + + # The function `get_tensor()` returns a copy of the tensor data. + output_data = self.interpreter.get_tensor(self.output_details[0]['index']) + else: + print('This model framework is not supported yet!') + exit() + + + @staticmethod + def model_type(p='path/to/model.pt'): + # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx + from trackers.reid_export import export_formats + sf = list(export_formats().Suffix) # export suffixes + check_suffix(p, sf) # checks + types = [s in Path(p).name for s in sf] + return types + + def _preprocess(self, im_batch): + + images = [] + for element in im_batch: + image = self.to_pil(element) + image = self.preprocess(image) + images.append(image) + + images = torch.stack(images, dim=0) + images = images.to(self.device) + + return images + + + def forward(self, im_batch): + + # preprocess batch + im_batch = self._preprocess(im_batch) + + # batch to half + if self.fp16 and im_batch.dtype != torch.float16: + im_batch = im_batch.half() + + # batch processing + features = [] + if self.pt: + features = self.model(im_batch) + elif self.jit: # TorchScript + features = self.model(im_batch) + elif self.onnx: # ONNX Runtime + im_batch = im_batch.cpu().numpy() # torch to numpy + features = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im_batch})[0] + elif self.engine: # TensorRT + if True and im_batch.shape != self.bindings['images'].shape: + i_in, i_out = (self.model_.get_binding_index(x) for x in ('images', 'output')) + self.context.set_binding_shape(i_in, im_batch.shape) # reshape if dynamic + self.bindings['images'] = self.bindings['images']._replace(shape=im_batch.shape) + self.bindings['output'].data.resize_(tuple(self.context.get_binding_shape(i_out))) + s = self.bindings['images'].shape + assert im_batch.shape == s, f"input size {im_batch.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}" + self.binding_addrs['images'] = int(im_batch.data_ptr()) + self.context.execute_v2(list(self.binding_addrs.values())) + features = self.bindings['output'].data + elif self.xml: # OpenVINO + im_batch = im_batch.cpu().numpy() # FP32 + features = self.executable_network([im_batch])[self.output_layer] + else: + print('Framework not supported at the moment, we are working on it...') + exit() + + if isinstance(features, (list, tuple)): + return self.from_numpy(features[0]) if len(features) == 1 else [self.from_numpy(x) for x in features] + else: + return self.from_numpy(features) + + def from_numpy(self, x): + return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x + + def warmup(self, imgsz=[(256, 128, 3)]): + # Warmup model by running inference once + warmup_types = self.pt, self.jit, 
self.onnx, self.engine, self.tflite + if any(warmup_types) and self.device.type != 'cpu': + im = [np.empty(*imgsz).astype(np.uint8)] # input + for _ in range(2 if self.jit else 1): # + self.forward(im) # warmup \ No newline at end of file diff --git a/feeder/trackers/strongsort/results/output_04.gif b/feeder/trackers/strongsort/results/output_04.gif new file mode 100644 index 0000000..d379b76 Binary files /dev/null and b/feeder/trackers/strongsort/results/output_04.gif differ diff --git a/feeder/trackers/strongsort/results/output_th025.gif b/feeder/trackers/strongsort/results/output_th025.gif new file mode 100644 index 0000000..fb5e7c1 Binary files /dev/null and b/feeder/trackers/strongsort/results/output_th025.gif differ diff --git a/feeder/trackers/strongsort/results/track_all_1280_025conf.gif b/feeder/trackers/strongsort/results/track_all_1280_025conf.gif new file mode 100644 index 0000000..f0b5718 Binary files /dev/null and b/feeder/trackers/strongsort/results/track_all_1280_025conf.gif differ diff --git a/feeder/trackers/strongsort/results/track_all_seg_1280_025conf.gif b/feeder/trackers/strongsort/results/track_all_seg_1280_025conf.gif new file mode 100644 index 0000000..d58476d Binary files /dev/null and b/feeder/trackers/strongsort/results/track_all_seg_1280_025conf.gif differ diff --git a/feeder/trackers/strongsort/results/track_pedestrians_1280_05conf.gif b/feeder/trackers/strongsort/results/track_pedestrians_1280_05conf.gif new file mode 100644 index 0000000..81ea91b Binary files /dev/null and b/feeder/trackers/strongsort/results/track_pedestrians_1280_05conf.gif differ diff --git a/feeder/trackers/strongsort/sort/__init__.py b/feeder/trackers/strongsort/sort/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feeder/trackers/strongsort/sort/detection.py b/feeder/trackers/strongsort/sort/detection.py new file mode 100644 index 0000000..1fb05f8 --- /dev/null +++ b/feeder/trackers/strongsort/sort/detection.py @@ -0,0 +1,58 @@ +# vim: expandtab:ts=4:sw=4 +import numpy as np + + +class Detection(object): + """ + This class represents a bounding box detection in a single image. + + Parameters + ---------- + tlwh : array_like + Bounding box in format `(x, y, w, h)`. + confidence : float + Detector confidence score. + feature : array_like + A feature vector that describes the object contained in this image. + + Attributes + ---------- + tlwh : ndarray + Bounding box in format `(top left x, top left y, width, height)`. + confidence : ndarray + Detector confidence score. + feature : ndarray | NoneType + A feature vector that describes the object contained in this image. + + """ + + def __init__(self, tlwh, confidence, feature): + self.tlwh = np.asarray(tlwh, dtype=np.float32) + self.confidence = float(confidence) + self.feature = np.asarray(feature.cpu(), dtype=np.float32) + + def to_tlbr(self): + """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., + `(top left, bottom right)`. + """ + ret = self.tlwh.copy() + ret[2:] += ret[:2] + return ret + + def to_xyah(self): + """Convert bounding box to format `(center x, center y, aspect ratio, + height)`, where the aspect ratio is `width / height`. + """ + ret = self.tlwh.copy() + ret[:2] += ret[2:] / 2 + ret[2] /= ret[3] + return ret + +def to_xyah_ext(bbox): + """Convert bounding box to format `(center x, center y, aspect ratio, + height)`, where the aspect ratio is `width / height`. 
+ """ + ret = bbox.copy() + ret[:2] += ret[2:] / 2 + ret[2] /= ret[3] + return ret diff --git a/feeder/trackers/strongsort/sort/iou_matching.py b/feeder/trackers/strongsort/sort/iou_matching.py new file mode 100644 index 0000000..62d5a3f --- /dev/null +++ b/feeder/trackers/strongsort/sort/iou_matching.py @@ -0,0 +1,82 @@ +# vim: expandtab:ts=4:sw=4 +from __future__ import absolute_import +import numpy as np +from . import linear_assignment + + +def iou(bbox, candidates): + """Computer intersection over union. + + Parameters + ---------- + bbox : ndarray + A bounding box in format `(top left x, top left y, width, height)`. + candidates : ndarray + A matrix of candidate bounding boxes (one per row) in the same format + as `bbox`. + + Returns + ------- + ndarray + The intersection over union in [0, 1] between the `bbox` and each + candidate. A higher score means a larger fraction of the `bbox` is + occluded by the candidate. + + """ + bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:] + candidates_tl = candidates[:, :2] + candidates_br = candidates[:, :2] + candidates[:, 2:] + + tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis], + np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]] + br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis], + np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]] + wh = np.maximum(0., br - tl) + + area_intersection = wh.prod(axis=1) + area_bbox = bbox[2:].prod() + area_candidates = candidates[:, 2:].prod(axis=1) + return area_intersection / (area_bbox + area_candidates - area_intersection) + + +def iou_cost(tracks, detections, track_indices=None, + detection_indices=None): + """An intersection over union distance metric. + + Parameters + ---------- + tracks : List[deep_sort.track.Track] + A list of tracks. + detections : List[deep_sort.detection.Detection] + A list of detections. + track_indices : Optional[List[int]] + A list of indices to tracks that should be matched. Defaults to + all `tracks`. + detection_indices : Optional[List[int]] + A list of indices to detections that should be matched. Defaults + to all `detections`. + + Returns + ------- + ndarray + Returns a cost matrix of shape + len(track_indices), len(detection_indices) where entry (i, j) is + `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`. + + """ + if track_indices is None: + track_indices = np.arange(len(tracks)) + if detection_indices is None: + detection_indices = np.arange(len(detections)) + + cost_matrix = np.zeros((len(track_indices), len(detection_indices))) + for row, track_idx in enumerate(track_indices): + if tracks[track_idx].time_since_update > 1: + cost_matrix[row, :] = linear_assignment.INFTY_COST + continue + + bbox = tracks[track_idx].to_tlwh() + candidates = np.asarray( + [detections[i].tlwh for i in detection_indices]) + cost_matrix[row, :] = 1. - iou(bbox, candidates) + return cost_matrix diff --git a/feeder/trackers/strongsort/sort/kalman_filter.py b/feeder/trackers/strongsort/sort/kalman_filter.py new file mode 100644 index 0000000..87c48d7 --- /dev/null +++ b/feeder/trackers/strongsort/sort/kalman_filter.py @@ -0,0 +1,214 @@ +# vim: expandtab:ts=4:sw=4 +import numpy as np +import scipy.linalg +""" +Table for the 0.95 quantile of the chi-square distribution with N degrees of +freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv +function and used as Mahalanobis gating threshold. 
+""" +chi2inv95 = { + 1: 3.8415, + 2: 5.9915, + 3: 7.8147, + 4: 9.4877, + 5: 11.070, + 6: 12.592, + 7: 14.067, + 8: 15.507, + 9: 16.919} + + +class KalmanFilter(object): + """ + A simple Kalman filter for tracking bounding boxes in image space. + The 8-dimensional state space + x, y, a, h, vx, vy, va, vh + contains the bounding box center position (x, y), aspect ratio a, height h, + and their respective velocities. + Object motion follows a constant velocity model. The bounding box location + (x, y, a, h) is taken as direct observation of the state space (linear + observation model). + """ + + def __init__(self): + ndim, dt = 4, 1. + + # Create Kalman filter model matrices. + self._motion_mat = np.eye(2 * ndim, 2 * ndim) + for i in range(ndim): + self._motion_mat[i, ndim + i] = dt + + self._update_mat = np.eye(ndim, 2 * ndim) + + # Motion and observation uncertainty are chosen relative to the current + # state estimate. These weights control the amount of uncertainty in + # the model. This is a bit hacky. + self._std_weight_position = 1. / 20 + self._std_weight_velocity = 1. / 160 + + def initiate(self, measurement): + """Create track from unassociated measurement. + Parameters + ---------- + measurement : ndarray + Bounding box coordinates (x, y, a, h) with center position (x, y), + aspect ratio a, and height h. + Returns + ------- + (ndarray, ndarray) + Returns the mean vector (8 dimensional) and covariance matrix (8x8 + dimensional) of the new track. Unobserved velocities are initialized + to 0 mean. + """ + mean_pos = measurement + mean_vel = np.zeros_like(mean_pos) + mean = np.r_[mean_pos, mean_vel] + + std = [ + 2 * self._std_weight_position * measurement[0], # the center point x + 2 * self._std_weight_position * measurement[1], # the center point y + 1 * measurement[2], # the ratio of width/height + 2 * self._std_weight_position * measurement[3], # the height + 10 * self._std_weight_velocity * measurement[0], + 10 * self._std_weight_velocity * measurement[1], + 0.1 * measurement[2], + 10 * self._std_weight_velocity * measurement[3]] + covariance = np.diag(np.square(std)) + return mean, covariance + + def predict(self, mean, covariance): + """Run Kalman filter prediction step. + Parameters + ---------- + mean : ndarray + The 8 dimensional mean vector of the object state at the previous + time step. + covariance : ndarray + The 8x8 dimensional covariance matrix of the object state at the + previous time step. + Returns + ------- + (ndarray, ndarray) + Returns the mean vector and covariance matrix of the predicted + state. Unobserved velocities are initialized to 0 mean. + """ + std_pos = [ + self._std_weight_position * mean[0], + self._std_weight_position * mean[1], + 1 * mean[2], + self._std_weight_position * mean[3]] + std_vel = [ + self._std_weight_velocity * mean[0], + self._std_weight_velocity * mean[1], + 0.1 * mean[2], + self._std_weight_velocity * mean[3]] + motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) + + mean = np.dot(self._motion_mat, mean) + covariance = np.linalg.multi_dot(( + self._motion_mat, covariance, self._motion_mat.T)) + motion_cov + + return mean, covariance + + def project(self, mean, covariance, confidence=.0): + """Project state distribution to measurement space. + Parameters + ---------- + mean : ndarray + The state's mean vector (8 dimensional array). + covariance : ndarray + The state's covariance matrix (8x8 dimensional). 
+ confidence: (dyh) 检测框置信度 + Returns + ------- + (ndarray, ndarray) + Returns the projected mean and covariance matrix of the given state + estimate. + """ + std = [ + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-1, + self._std_weight_position * mean[3]] + + + std = [(1 - confidence) * x for x in std] + + innovation_cov = np.diag(np.square(std)) + + mean = np.dot(self._update_mat, mean) + covariance = np.linalg.multi_dot(( + self._update_mat, covariance, self._update_mat.T)) + return mean, covariance + innovation_cov + + def update(self, mean, covariance, measurement, confidence=.0): + """Run Kalman filter correction step. + Parameters + ---------- + mean : ndarray + The predicted state's mean vector (8 dimensional). + covariance : ndarray + The state's covariance matrix (8x8 dimensional). + measurement : ndarray + The 4 dimensional measurement vector (x, y, a, h), where (x, y) + is the center position, a the aspect ratio, and h the height of the + bounding box. + confidence: (dyh)检测框置信度 + Returns + ------- + (ndarray, ndarray) + Returns the measurement-corrected state distribution. + """ + projected_mean, projected_cov = self.project(mean, covariance, confidence) + + chol_factor, lower = scipy.linalg.cho_factor( + projected_cov, lower=True, check_finite=False) + kalman_gain = scipy.linalg.cho_solve( + (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, + check_finite=False).T + innovation = measurement - projected_mean + + new_mean = mean + np.dot(innovation, kalman_gain.T) + new_covariance = covariance - np.linalg.multi_dot(( + kalman_gain, projected_cov, kalman_gain.T)) + return new_mean, new_covariance + + def gating_distance(self, mean, covariance, measurements, + only_position=False): + """Compute gating distance between state distribution and measurements. + A suitable distance threshold can be obtained from `chi2inv95`. If + `only_position` is False, the chi-square distribution has 4 degrees of + freedom, otherwise 2. + Parameters + ---------- + mean : ndarray + Mean vector over the state distribution (8 dimensional). + covariance : ndarray + Covariance of the state distribution (8x8 dimensional). + measurements : ndarray + An Nx4 dimensional matrix of N measurements, each in + format (x, y, a, h) where (x, y) is the bounding box center + position, a the aspect ratio, and h the height. + only_position : Optional[bool] + If True, distance computation is done with respect to the bounding + box center position only. + Returns + ------- + ndarray + Returns an array of length N, where the i-th element contains the + squared Mahalanobis distance between (mean, covariance) and + `measurements[i]`. + """ + mean, covariance = self.project(mean, covariance) + + if only_position: + mean, covariance = mean[:2], covariance[:2, :2] + measurements = measurements[:, :2] + + cholesky_factor = np.linalg.cholesky(covariance) + d = measurements - mean + z = scipy.linalg.solve_triangular( + cholesky_factor, d.T, lower=True, check_finite=False, + overwrite_b=True) + squared_maha = np.sum(z * z, axis=0) + return squared_maha \ No newline at end of file diff --git a/feeder/trackers/strongsort/sort/linear_assignment.py b/feeder/trackers/strongsort/sort/linear_assignment.py new file mode 100644 index 0000000..9ab92c5 --- /dev/null +++ b/feeder/trackers/strongsort/sort/linear_assignment.py @@ -0,0 +1,174 @@ +# vim: expandtab:ts=4:sw=4 +from __future__ import absolute_import +import numpy as np +from scipy.optimize import linear_sum_assignment +from . 
import kalman_filter + + +INFTY_COST = 1e+5 + + +def min_cost_matching( + distance_metric, max_distance, tracks, detections, track_indices=None, + detection_indices=None): + """Solve linear assignment problem. + Parameters + ---------- + distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray + The distance metric is given a list of tracks and detections as well as + a list of N track indices and M detection indices. The metric should + return the NxM dimensional cost matrix, where element (i, j) is the + association cost between the i-th track in the given track indices and + the j-th detection in the given detection_indices. + max_distance : float + Gating threshold. Associations with cost larger than this value are + disregarded. + tracks : List[track.Track] + A list of predicted tracks at the current time step. + detections : List[detection.Detection] + A list of detections at the current time step. + track_indices : List[int] + List of track indices that maps rows in `cost_matrix` to tracks in + `tracks` (see description above). + detection_indices : List[int] + List of detection indices that maps columns in `cost_matrix` to + detections in `detections` (see description above). + Returns + ------- + (List[(int, int)], List[int], List[int]) + Returns a tuple with the following three entries: + * A list of matched track and detection indices. + * A list of unmatched track indices. + * A list of unmatched detection indices. + """ + if track_indices is None: + track_indices = np.arange(len(tracks)) + if detection_indices is None: + detection_indices = np.arange(len(detections)) + + if len(detection_indices) == 0 or len(track_indices) == 0: + return [], track_indices, detection_indices # Nothing to match. + + cost_matrix = distance_metric( + tracks, detections, track_indices, detection_indices) + cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5 + row_indices, col_indices = linear_sum_assignment(cost_matrix) + + matches, unmatched_tracks, unmatched_detections = [], [], [] + for col, detection_idx in enumerate(detection_indices): + if col not in col_indices: + unmatched_detections.append(detection_idx) + for row, track_idx in enumerate(track_indices): + if row not in row_indices: + unmatched_tracks.append(track_idx) + for row, col in zip(row_indices, col_indices): + track_idx = track_indices[row] + detection_idx = detection_indices[col] + if cost_matrix[row, col] > max_distance: + unmatched_tracks.append(track_idx) + unmatched_detections.append(detection_idx) + else: + matches.append((track_idx, detection_idx)) + return matches, unmatched_tracks, unmatched_detections + + +def matching_cascade( + distance_metric, max_distance, cascade_depth, tracks, detections, + track_indices=None, detection_indices=None): + """Run matching cascade. + Parameters + ---------- + distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray + The distance metric is given a list of tracks and detections as well as + a list of N track indices and M detection indices. The metric should + return the NxM dimensional cost matrix, where element (i, j) is the + association cost between the i-th track in the given track indices and + the j-th detection in the given detection indices. + max_distance : float + Gating threshold. Associations with cost larger than this value are + disregarded. + cascade_depth: int + The cascade depth, should be se to the maximum track age. 
+ tracks : List[track.Track] + A list of predicted tracks at the current time step. + detections : List[detection.Detection] + A list of detections at the current time step. + track_indices : Optional[List[int]] + List of track indices that maps rows in `cost_matrix` to tracks in + `tracks` (see description above). Defaults to all tracks. + detection_indices : Optional[List[int]] + List of detection indices that maps columns in `cost_matrix` to + detections in `detections` (see description above). Defaults to all + detections. + Returns + ------- + (List[(int, int)], List[int], List[int]) + Returns a tuple with the following three entries: + * A list of matched track and detection indices. + * A list of unmatched track indices. + * A list of unmatched detection indices. + """ + if track_indices is None: + track_indices = list(range(len(tracks))) + if detection_indices is None: + detection_indices = list(range(len(detections))) + + unmatched_detections = detection_indices + matches = [] + track_indices_l = [ + k for k in track_indices + # if tracks[k].time_since_update == 1 + level + ] + matches_l, _, unmatched_detections = \ + min_cost_matching( + distance_metric, max_distance, tracks, detections, + track_indices_l, unmatched_detections) + matches += matches_l + unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches)) + return matches, unmatched_tracks, unmatched_detections + + +def gate_cost_matrix( + cost_matrix, tracks, detections, track_indices, detection_indices, mc_lambda, + gated_cost=INFTY_COST, only_position=False): + """Invalidate infeasible entries in cost matrix based on the state + distributions obtained by Kalman filtering. + Parameters + ---------- + kf : The Kalman filter. + cost_matrix : ndarray + The NxM dimensional cost matrix, where N is the number of track indices + and M is the number of detection indices, such that entry (i, j) is the + association cost between `tracks[track_indices[i]]` and + `detections[detection_indices[j]]`. + tracks : List[track.Track] + A list of predicted tracks at the current time step. + detections : List[detection.Detection] + A list of detections at the current time step. + track_indices : List[int] + List of track indices that maps rows in `cost_matrix` to tracks in + `tracks` (see description above). + detection_indices : List[int] + List of detection indices that maps columns in `cost_matrix` to + detections in `detections` (see description above). + gated_cost : Optional[float] + Entries in the cost matrix corresponding to infeasible associations are + set this value. Defaults to a very large value. + only_position : Optional[bool] + If True, only the x, y position of the state distribution is considered + during gating. Defaults to False. + Returns + ------- + ndarray + Returns the modified cost matrix. 
+ """ + gating_dim = 2 if only_position else 4 + gating_threshold = kalman_filter.chi2inv95[gating_dim] + measurements = np.asarray( + [detections[i].to_xyah() for i in detection_indices]) + for row, track_idx in enumerate(track_indices): + track = tracks[track_idx] + gating_distance = track.kf.gating_distance(track.mean, track.covariance, measurements, only_position) + cost_matrix[row, gating_distance > gating_threshold] = gated_cost + cost_matrix[row] = mc_lambda * cost_matrix[row] + (1 - mc_lambda) * gating_distance + return cost_matrix diff --git a/feeder/trackers/strongsort/sort/nn_matching.py b/feeder/trackers/strongsort/sort/nn_matching.py new file mode 100644 index 0000000..154f854 --- /dev/null +++ b/feeder/trackers/strongsort/sort/nn_matching.py @@ -0,0 +1,162 @@ +# vim: expandtab:ts=4:sw=4 +import numpy as np +import sys +import torch + + +def _pdist(a, b): + """Compute pair-wise squared distance between points in `a` and `b`. + Parameters + ---------- + a : array_like + An NxM matrix of N samples of dimensionality M. + b : array_like + An LxM matrix of L samples of dimensionality M. + Returns + ------- + ndarray + Returns a matrix of size len(a), len(b) such that eleement (i, j) + contains the squared distance between `a[i]` and `b[j]`. + """ + a, b = np.asarray(a), np.asarray(b) + if len(a) == 0 or len(b) == 0: + return np.zeros((len(a), len(b))) + a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1) + r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :] + r2 = np.clip(r2, 0., float(np.inf)) + return r2 + + +def _cosine_distance(a, b, data_is_normalized=False): + """Compute pair-wise cosine distance between points in `a` and `b`. + Parameters + ---------- + a : array_like + An NxM matrix of N samples of dimensionality M. + b : array_like + An LxM matrix of L samples of dimensionality M. + data_is_normalized : Optional[bool] + If True, assumes rows in a and b are unit length vectors. + Otherwise, a and b are explicitly normalized to lenght 1. + Returns + ------- + ndarray + Returns a matrix of size len(a), len(b) such that eleement (i, j) + contains the squared distance between `a[i]` and `b[j]`. + """ + if not data_is_normalized: + a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True) + b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True) + return 1. - np.dot(a, b.T) + + +def _nn_euclidean_distance(x, y): + """ Helper function for nearest neighbor distance metric (Euclidean). + Parameters + ---------- + x : ndarray + A matrix of N row-vectors (sample points). + y : ndarray + A matrix of M row-vectors (query points). + Returns + ------- + ndarray + A vector of length M that contains for each entry in `y` the + smallest Euclidean distance to a sample in `x`. + """ + # x_ = torch.from_numpy(np.asarray(x) / np.linalg.norm(x, axis=1, keepdims=True)) + # y_ = torch.from_numpy(np.asarray(y) / np.linalg.norm(y, axis=1, keepdims=True)) + distances = distances = _pdist(x, y) + return np.maximum(0.0, torch.min(distances, axis=0)[0].numpy()) + + +def _nn_cosine_distance(x, y): + """ Helper function for nearest neighbor distance metric (cosine). + Parameters + ---------- + x : ndarray + A matrix of N row-vectors (sample points). + y : ndarray + A matrix of M row-vectors (query points). + Returns + ------- + ndarray + A vector of length M that contains for each entry in `y` the + smallest cosine distance to a sample in `x`. 
+ """ + x_ = torch.from_numpy(np.asarray(x)) + y_ = torch.from_numpy(np.asarray(y)) + distances = _cosine_distance(x_, y_) + distances = distances + return distances.min(axis=0) + + +class NearestNeighborDistanceMetric(object): + """ + A nearest neighbor distance metric that, for each target, returns + the closest distance to any sample that has been observed so far. + Parameters + ---------- + metric : str + Either "euclidean" or "cosine". + matching_threshold: float + The matching threshold. Samples with larger distance are considered an + invalid match. + budget : Optional[int] + If not None, fix samples per class to at most this number. Removes + the oldest samples when the budget is reached. + Attributes + ---------- + samples : Dict[int -> List[ndarray]] + A dictionary that maps from target identities to the list of samples + that have been observed so far. + """ + + def __init__(self, metric, matching_threshold, budget=None): + if metric == "euclidean": + self._metric = _nn_euclidean_distance + elif metric == "cosine": + self._metric = _nn_cosine_distance + else: + raise ValueError( + "Invalid metric; must be either 'euclidean' or 'cosine'") + self.matching_threshold = matching_threshold + self.budget = budget + self.samples = {} + + def partial_fit(self, features, targets, active_targets): + """Update the distance metric with new data. + Parameters + ---------- + features : ndarray + An NxM matrix of N features of dimensionality M. + targets : ndarray + An integer array of associated target identities. + active_targets : List[int] + A list of targets that are currently present in the scene. + """ + for feature, target in zip(features, targets): + self.samples.setdefault(target, []).append(feature) + if self.budget is not None: + self.samples[target] = self.samples[target][-self.budget:] + self.samples = {k: self.samples[k] for k in active_targets} + + def distance(self, features, targets): + """Compute distance between features and targets. + Parameters + ---------- + features : ndarray + An NxM matrix of N features of dimensionality M. + targets : List[int] + A list of targets to match the given `features` against. + Returns + ------- + ndarray + Returns a cost matrix of shape len(targets), len(features), where + element (i, j) contains the closest squared distance between + `targets[i]` and `features[j]`. + """ + cost_matrix = np.zeros((len(targets), len(features))) + for i, target in enumerate(targets): + cost_matrix[i, :] = self._metric(self.samples[target], features) + return cost_matrix \ No newline at end of file diff --git a/feeder/trackers/strongsort/sort/preprocessing.py b/feeder/trackers/strongsort/sort/preprocessing.py new file mode 100644 index 0000000..5493b12 --- /dev/null +++ b/feeder/trackers/strongsort/sort/preprocessing.py @@ -0,0 +1,73 @@ +# vim: expandtab:ts=4:sw=4 +import numpy as np +import cv2 + + +def non_max_suppression(boxes, max_bbox_overlap, scores=None): + """Suppress overlapping detections. + + Original code from [1]_ has been adapted to include confidence score. + + .. [1] http://www.pyimagesearch.com/2015/02/16/ + faster-non-maximum-suppression-python/ + + Examples + -------- + + >>> boxes = [d.roi for d in detections] + >>> scores = [d.confidence for d in detections] + >>> indices = non_max_suppression(boxes, max_bbox_overlap, scores) + >>> detections = [detections[i] for i in indices] + + Parameters + ---------- + boxes : ndarray + Array of ROIs (x, y, width, height). 
+ max_bbox_overlap : float + ROIs that overlap more than this values are suppressed. + scores : Optional[array_like] + Detector confidence score. + + Returns + ------- + List[int] + Returns indices of detections that have survived non-maxima suppression. + + """ + if len(boxes) == 0: + return [] + + boxes = boxes.astype(np.float) + pick = [] + + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + boxes[:, 0] + y2 = boxes[:, 3] + boxes[:, 1] + + area = (x2 - x1 + 1) * (y2 - y1 + 1) + if scores is not None: + idxs = np.argsort(scores) + else: + idxs = np.argsort(y2) + + while len(idxs) > 0: + last = len(idxs) - 1 + i = idxs[last] + pick.append(i) + + xx1 = np.maximum(x1[i], x1[idxs[:last]]) + yy1 = np.maximum(y1[i], y1[idxs[:last]]) + xx2 = np.minimum(x2[i], x2[idxs[:last]]) + yy2 = np.minimum(y2[i], y2[idxs[:last]]) + + w = np.maximum(0, xx2 - xx1 + 1) + h = np.maximum(0, yy2 - yy1 + 1) + + overlap = (w * h) / area[idxs[:last]] + + idxs = np.delete( + idxs, np.concatenate( + ([last], np.where(overlap > max_bbox_overlap)[0]))) + + return pick diff --git a/feeder/trackers/strongsort/sort/track.py b/feeder/trackers/strongsort/sort/track.py new file mode 100644 index 0000000..bb6773f --- /dev/null +++ b/feeder/trackers/strongsort/sort/track.py @@ -0,0 +1,317 @@ +# vim: expandtab:ts=4:sw=4 +import cv2 +import numpy as np +from trackers.strongsort.sort.kalman_filter import KalmanFilter +from collections import deque + + +class TrackState: + """ + Enumeration type for the single target track state. Newly created tracks are + classified as `tentative` until enough evidence has been collected. Then, + the track state is changed to `confirmed`. Tracks that are no longer alive + are classified as `deleted` to mark them for removal from the set of active + tracks. + + """ + + Tentative = 1 + Confirmed = 2 + Deleted = 3 + + +class Track: + """ + A single target track with state space `(x, y, a, h)` and associated + velocities, where `(x, y)` is the center of the bounding box, `a` is the + aspect ratio and `h` is the height. + + Parameters + ---------- + mean : ndarray + Mean vector of the initial state distribution. + covariance : ndarray + Covariance matrix of the initial state distribution. + track_id : int + A unique track identifier. + n_init : int + Number of consecutive detections before the track is confirmed. The + track state is set to `Deleted` if a miss occurs within the first + `n_init` frames. + max_age : int + The maximum number of consecutive misses before the track state is + set to `Deleted`. + feature : Optional[ndarray] + Feature vector of the detection this track originates from. If not None, + this feature is added to the `features` cache. + + Attributes + ---------- + mean : ndarray + Mean vector of the initial state distribution. + covariance : ndarray + Covariance matrix of the initial state distribution. + track_id : int + A unique track identifier. + hits : int + Total number of measurement updates. + age : int + Total number of frames since first occurance. + time_since_update : int + Total number of frames since last measurement update. + state : TrackState + The current track state. + features : List[ndarray] + A cache of features. On each measurement update, the associated feature + vector is added to this list. 
+ + """ + + def __init__(self, detection, track_id, class_id, conf, n_init, max_age, ema_alpha, + feature=None): + self.track_id = track_id + self.class_id = int(class_id) + self.hits = 1 + self.age = 1 + self.time_since_update = 0 + self.max_num_updates_wo_assignment = 7 + self.updates_wo_assignment = 0 + self.ema_alpha = ema_alpha + + self.state = TrackState.Tentative + self.features = [] + if feature is not None: + feature /= np.linalg.norm(feature) + self.features.append(feature) + + self.conf = conf + self._n_init = n_init + self._max_age = max_age + + self.kf = KalmanFilter() + self.mean, self.covariance = self.kf.initiate(detection) + + # Initializing trajectory queue + self.q = deque(maxlen=25) + + def to_tlwh(self): + """Get current position in bounding box format `(top left x, top left y, + width, height)`. + + Returns + ------- + ndarray + The bounding box. + + """ + ret = self.mean[:4].copy() + ret[2] *= ret[3] + ret[:2] -= ret[2:] / 2 + return ret + + def to_tlbr(self): + """Get kf estimated current position in bounding box format `(min x, miny, max x, + max y)`. + + Returns + ------- + ndarray + The predicted kf bounding box. + + """ + ret = self.to_tlwh() + ret[2:] = ret[:2] + ret[2:] + return ret + + + def ECC(self, src, dst, warp_mode = cv2.MOTION_EUCLIDEAN, eps = 1e-5, + max_iter = 100, scale = 0.1, align = False): + """Compute the warp matrix from src to dst. + Parameters + ---------- + src : ndarray + An NxM matrix of source img(BGR or Gray), it must be the same format as dst. + dst : ndarray + An NxM matrix of target img(BGR or Gray). + warp_mode: flags of opencv + translation: cv2.MOTION_TRANSLATION + rotated and shifted: cv2.MOTION_EUCLIDEAN + affine(shift,rotated,shear): cv2.MOTION_AFFINE + homography(3d): cv2.MOTION_HOMOGRAPHY + eps: float + the threshold of the increment in the correlation coefficient between two iterations + max_iter: int + the number of iterations. + scale: float or [int, int] + scale_ratio: float + scale_size: [W, H] + align: bool + whether to warp affine or perspective transforms to the source image + Returns + ------- + warp matrix : ndarray + Returns the warp matrix from src to dst. 
+ if motion models is homography, the warp matrix will be 3x3, otherwise 2x3 + src_aligned: ndarray + aligned source image of gray + """ + + # BGR2GRAY + if src.ndim == 3: + # Convert images to grayscale + src = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY) + dst = cv2.cvtColor(dst, cv2.COLOR_BGR2GRAY) + + # make the imgs smaller to speed up + if scale is not None: + if isinstance(scale, float) or isinstance(scale, int): + if scale != 1: + src_r = cv2.resize(src, (0, 0), fx = scale, fy = scale,interpolation = cv2.INTER_LINEAR) + dst_r = cv2.resize(dst, (0, 0), fx = scale, fy = scale,interpolation = cv2.INTER_LINEAR) + scale = [scale, scale] + else: + src_r, dst_r = src, dst + scale = None + else: + if scale[0] != src.shape[1] and scale[1] != src.shape[0]: + src_r = cv2.resize(src, (scale[0], scale[1]), interpolation = cv2.INTER_LINEAR) + dst_r = cv2.resize(dst, (scale[0], scale[1]), interpolation=cv2.INTER_LINEAR) + scale = [scale[0] / src.shape[1], scale[1] / src.shape[0]] + else: + src_r, dst_r = src, dst + scale = None + else: + src_r, dst_r = src, dst + + # Define 2x3 or 3x3 matrices and initialize the matrix to identity + if warp_mode == cv2.MOTION_HOMOGRAPHY : + warp_matrix = np.eye(3, 3, dtype=np.float32) + else : + warp_matrix = np.eye(2, 3, dtype=np.float32) + + # Define termination criteria + criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, max_iter, eps) + + # Run the ECC algorithm. The results are stored in warp_matrix. + try: + (cc, warp_matrix) = cv2.findTransformECC (src_r, dst_r, warp_matrix, warp_mode, criteria, None, 1) + except cv2.error as e: + print('ecc transform failed') + return None, None + + if scale is not None: + warp_matrix[0, 2] = warp_matrix[0, 2] / scale[0] + warp_matrix[1, 2] = warp_matrix[1, 2] / scale[1] + + if align: + sz = src.shape + if warp_mode == cv2.MOTION_HOMOGRAPHY: + # Use warpPerspective for Homography + src_aligned = cv2.warpPerspective(src, warp_matrix, (sz[1],sz[0]), flags=cv2.INTER_LINEAR) + else : + # Use warpAffine for Translation, Euclidean and Affine + src_aligned = cv2.warpAffine(src, warp_matrix, (sz[1],sz[0]), flags=cv2.INTER_LINEAR) + return warp_matrix, src_aligned + else: + return warp_matrix, None + + + def get_matrix(self, matrix): + eye = np.eye(3) + dist = np.linalg.norm(eye - matrix) + if dist < 100: + return matrix + else: + return eye + + def camera_update(self, previous_frame, next_frame): + warp_matrix, src_aligned = self.ECC(previous_frame, next_frame) + if warp_matrix is None and src_aligned is None: + return + [a,b] = warp_matrix + warp_matrix=np.array([a,b,[0,0,1]]) + warp_matrix = warp_matrix.tolist() + matrix = self.get_matrix(warp_matrix) + + x1, y1, x2, y2 = self.to_tlbr() + x1_, y1_, _ = matrix @ np.array([x1, y1, 1]).T + x2_, y2_, _ = matrix @ np.array([x2, y2, 1]).T + w, h = x2_ - x1_, y2_ - y1_ + cx, cy = x1_ + w / 2, y1_ + h / 2 + self.mean[:4] = [cx, cy, w / h, h] + + + def increment_age(self): + self.age += 1 + self.time_since_update += 1 + + def predict(self, kf): + """Propagate the state distribution to the current time step using a + Kalman filter prediction step. + + Parameters + ---------- + kf : kalman_filter.KalmanFilter + The Kalman filter. 
+ + """ + self.mean, self.covariance = self.kf.predict(self.mean, self.covariance) + self.age += 1 + self.time_since_update += 1 + + def update_kf(self, bbox, confidence=0.5): + self.updates_wo_assignment = self.updates_wo_assignment + 1 + self.mean, self.covariance = self.kf.update(self.mean, self.covariance, bbox, confidence) + tlbr = self.to_tlbr() + x_c = int((tlbr[0] + tlbr[2]) / 2) + y_c = int((tlbr[1] + tlbr[3]) / 2) + self.q.append(('predupdate', (x_c, y_c))) + + def update(self, detection, class_id, conf): + """Perform Kalman filter measurement update step and update the feature + cache. + Parameters + ---------- + detection : Detection + The associated detection. + """ + self.conf = conf + self.class_id = class_id.int() + self.mean, self.covariance = self.kf.update(self.mean, self.covariance, detection.to_xyah(), detection.confidence) + + feature = detection.feature / np.linalg.norm(detection.feature) + + smooth_feat = self.ema_alpha * self.features[-1] + (1 - self.ema_alpha) * feature + smooth_feat /= np.linalg.norm(smooth_feat) + self.features = [smooth_feat] + + self.hits += 1 + self.time_since_update = 0 + if self.state == TrackState.Tentative and self.hits >= self._n_init: + self.state = TrackState.Confirmed + + tlbr = self.to_tlbr() + x_c = int((tlbr[0] + tlbr[2]) / 2) + y_c = int((tlbr[1] + tlbr[3]) / 2) + self.q.append(('observationupdate', (x_c, y_c))) + + def mark_missed(self): + """Mark this track as missed (no association at the current time step). + """ + if self.state == TrackState.Tentative: + self.state = TrackState.Deleted + elif self.time_since_update > self._max_age: + self.state = TrackState.Deleted + + def is_tentative(self): + """Returns True if this track is tentative (unconfirmed). + """ + return self.state == TrackState.Tentative + + def is_confirmed(self): + """Returns True if this track is confirmed.""" + return self.state == TrackState.Confirmed + + def is_deleted(self): + """Returns True if this track is dead and should be deleted.""" + return self.state == TrackState.Deleted diff --git a/feeder/trackers/strongsort/sort/tracker.py b/feeder/trackers/strongsort/sort/tracker.py new file mode 100644 index 0000000..d889277 --- /dev/null +++ b/feeder/trackers/strongsort/sort/tracker.py @@ -0,0 +1,192 @@ +# vim: expandtab:ts=4:sw=4 +from __future__ import absolute_import +import numpy as np +from . import kalman_filter +from . import linear_assignment +from . import iou_matching +from . import detection +from .track import Track + + +class Tracker: + """ + This is the multi-target tracker. + Parameters + ---------- + metric : nn_matching.NearestNeighborDistanceMetric + A distance metric for measurement-to-track association. + max_age : int + Maximum number of missed misses before a track is deleted. + n_init : int + Number of consecutive detections before the track is confirmed. The + track state is set to `Deleted` if a miss occurs within the first + `n_init` frames. + Attributes + ---------- + metric : nn_matching.NearestNeighborDistanceMetric + The distance metric used for measurement to track association. + max_age : int + Maximum number of missed misses before a track is deleted. + n_init : int + Number of frames that a track remains in initialization phase. + kf : kalman_filter.KalmanFilter + A Kalman filter to filter target trajectories in image space. + tracks : List[Track] + The list of active tracks at the current time step. 
+ """ + GATING_THRESHOLD = np.sqrt(kalman_filter.chi2inv95[4]) + + def __init__(self, metric, max_iou_dist=0.9, max_age=30, max_unmatched_preds=7, n_init=3, _lambda=0, ema_alpha=0.9, mc_lambda=0.995): + self.metric = metric + self.max_iou_dist = max_iou_dist + self.max_age = max_age + self.n_init = n_init + self._lambda = _lambda + self.ema_alpha = ema_alpha + self.mc_lambda = mc_lambda + self.max_unmatched_preds = max_unmatched_preds + + self.kf = kalman_filter.KalmanFilter() + self.tracks = [] + self._next_id = 1 + + def predict(self): + """Propagate track state distributions one time step forward. + + This function should be called once every time step, before `update`. + """ + for track in self.tracks: + track.predict(self.kf) + + def increment_ages(self): + for track in self.tracks: + track.increment_age() + track.mark_missed() + + def camera_update(self, previous_img, current_img): + for track in self.tracks: + track.camera_update(previous_img, current_img) + + def pred_n_update_all_tracks(self): + """Perform predictions and updates for all tracks by its own predicted state. + + """ + self.predict() + for t in self.tracks: + if self.max_unmatched_preds != 0 and t.updates_wo_assignment < t.max_num_updates_wo_assignment: + bbox = t.to_tlwh() + t.update_kf(detection.to_xyah_ext(bbox)) + + def update(self, detections, classes, confidences): + """Perform measurement update and track management. + + Parameters + ---------- + detections : List[deep_sort.detection.Detection] + A list of detections at the current time step. + + """ + # Run matching cascade. + matches, unmatched_tracks, unmatched_detections = \ + self._match(detections) + + # Update track set. + for track_idx, detection_idx in matches: + self.tracks[track_idx].update( + detections[detection_idx], classes[detection_idx], confidences[detection_idx]) + for track_idx in unmatched_tracks: + self.tracks[track_idx].mark_missed() + if self.max_unmatched_preds != 0 and self.tracks[track_idx].updates_wo_assignment < self.tracks[track_idx].max_num_updates_wo_assignment: + bbox = self.tracks[track_idx].to_tlwh() + self.tracks[track_idx].update_kf(detection.to_xyah_ext(bbox)) + for detection_idx in unmatched_detections: + self._initiate_track(detections[detection_idx], classes[detection_idx].item(), confidences[detection_idx].item()) + self.tracks = [t for t in self.tracks if not t.is_deleted()] + + # Update distance metric. + active_targets = [t.track_id for t in self.tracks if t.is_confirmed()] + features, targets = [], [] + for track in self.tracks: + if not track.is_confirmed(): + continue + features += track.features + targets += [track.track_id for _ in track.features] + self.metric.partial_fit(np.asarray(features), np.asarray(targets), active_targets) + + def _full_cost_metric(self, tracks, dets, track_indices, detection_indices): + """ + This implements the full lambda-based cost-metric. However, in doing so, it disregards + the possibility to gate the position only which is provided by + linear_assignment.gate_cost_matrix(). Instead, I gate by everything. + Note that the Mahalanobis distance is itself an unnormalised metric. Given the cosine + distance being normalised, we employ a quick and dirty normalisation based on the + threshold: that is, we divide the positional-cost by the gating threshold, thus ensuring + that the valid values range 0-1. + Note also that the authors work with the squared distance. I also sqrt this, so that it + is more intuitive in terms of values. 
+ """ + # Compute First the Position-based Cost Matrix + pos_cost = np.empty([len(track_indices), len(detection_indices)]) + msrs = np.asarray([dets[i].to_xyah() for i in detection_indices]) + for row, track_idx in enumerate(track_indices): + pos_cost[row, :] = np.sqrt( + self.kf.gating_distance( + tracks[track_idx].mean, tracks[track_idx].covariance, msrs, False + ) + ) / self.GATING_THRESHOLD + pos_gate = pos_cost > 1.0 + # Now Compute the Appearance-based Cost Matrix + app_cost = self.metric.distance( + np.array([dets[i].feature for i in detection_indices]), + np.array([tracks[i].track_id for i in track_indices]), + ) + app_gate = app_cost > self.metric.matching_threshold + # Now combine and threshold + cost_matrix = self._lambda * pos_cost + (1 - self._lambda) * app_cost + cost_matrix[np.logical_or(pos_gate, app_gate)] = linear_assignment.INFTY_COST + # Return Matrix + return cost_matrix + + def _match(self, detections): + + def gated_metric(tracks, dets, track_indices, detection_indices): + features = np.array([dets[i].feature for i in detection_indices]) + targets = np.array([tracks[i].track_id for i in track_indices]) + cost_matrix = self.metric.distance(features, targets) + cost_matrix = linear_assignment.gate_cost_matrix(cost_matrix, tracks, dets, track_indices, detection_indices, self.mc_lambda) + + return cost_matrix + + # Split track set into confirmed and unconfirmed tracks. + confirmed_tracks = [ + i for i, t in enumerate(self.tracks) if t.is_confirmed()] + unconfirmed_tracks = [ + i for i, t in enumerate(self.tracks) if not t.is_confirmed()] + + # Associate confirmed tracks using appearance features. + matches_a, unmatched_tracks_a, unmatched_detections = \ + linear_assignment.matching_cascade( + gated_metric, self.metric.matching_threshold, self.max_age, + self.tracks, detections, confirmed_tracks) + + # Associate remaining tracks together with unconfirmed tracks using IOU. 
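+        # Only tracks missed for exactly one frame (plus still-unconfirmed
+        # tracks) get this second, IoU-based chance; tracks missed for longer
+        # are left for the next appearance-based cascade.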
+ iou_track_candidates = unconfirmed_tracks + [ + k for k in unmatched_tracks_a if + self.tracks[k].time_since_update == 1] + unmatched_tracks_a = [ + k for k in unmatched_tracks_a if + self.tracks[k].time_since_update != 1] + matches_b, unmatched_tracks_b, unmatched_detections = \ + linear_assignment.min_cost_matching( + iou_matching.iou_cost, self.max_iou_dist, self.tracks, + detections, iou_track_candidates, unmatched_detections) + + matches = matches_a + matches_b + unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b)) + return matches, unmatched_tracks, unmatched_detections + + def _initiate_track(self, detection, class_id, conf): + self.tracks.append(Track( + detection.to_xyah(), self._next_id, class_id, conf, self.n_init, self.max_age, self.ema_alpha, + detection.feature)) + self._next_id += 1 diff --git a/feeder/trackers/strongsort/strong_sort.py b/feeder/trackers/strongsort/strong_sort.py new file mode 100644 index 0000000..352d2c1 --- /dev/null +++ b/feeder/trackers/strongsort/strong_sort.py @@ -0,0 +1,151 @@ +import numpy as np +import torch +import sys +import cv2 +import gdown +from os.path import exists as file_exists, join +import torchvision.transforms as transforms + +from sort.nn_matching import NearestNeighborDistanceMetric +from sort.detection import Detection +from sort.tracker import Tracker + +from reid_multibackend import ReIDDetectMultiBackend + +from ultralytics.yolo.utils.ops import xyxy2xywh + + +class StrongSORT(object): + def __init__(self, + model_weights, + device, + fp16, + max_dist=0.2, + max_iou_dist=0.7, + max_age=70, + max_unmatched_preds=7, + n_init=3, + nn_budget=100, + mc_lambda=0.995, + ema_alpha=0.9 + ): + + self.model = ReIDDetectMultiBackend(weights=model_weights, device=device, fp16=fp16) + + self.max_dist = max_dist + metric = NearestNeighborDistanceMetric( + "cosine", self.max_dist, nn_budget) + self.tracker = Tracker( + metric, max_iou_dist=max_iou_dist, max_age=max_age, n_init=n_init, max_unmatched_preds=max_unmatched_preds, mc_lambda=mc_lambda, ema_alpha=ema_alpha) + + def update(self, dets, ori_img): + + xyxys = dets[:, 0:4] + confs = dets[:, 4] + clss = dets[:, 5] + + classes = clss.numpy() + xywhs = xyxy2xywh(xyxys.numpy()) + confs = confs.numpy() + self.height, self.width = ori_img.shape[:2] + + # generate detections + features = self._get_features(xywhs, ori_img) + bbox_tlwh = self._xywh_to_tlwh(xywhs) + detections = [Detection(bbox_tlwh[i], conf, features[i]) for i, conf in enumerate( + confs)] + + # run on non-maximum supression + boxes = np.array([d.tlwh for d in detections]) + scores = np.array([d.confidence for d in detections]) + + # update tracker + self.tracker.predict() + self.tracker.update(detections, clss, confs) + + # output bbox identities + outputs = [] + for track in self.tracker.tracks: + if not track.is_confirmed() or track.time_since_update > 1: + continue + + box = track.to_tlwh() + x1, y1, x2, y2 = self._tlwh_to_xyxy(box) + + track_id = track.track_id + class_id = track.class_id + conf = track.conf + queue = track.q + outputs.append(np.array([x1, y1, x2, y2, track_id, class_id, conf, queue], dtype=object)) + if len(outputs) > 0: + outputs = np.stack(outputs, axis=0) + return outputs + + """ + TODO: + Convert bbox from xc_yc_w_h to xtl_ytl_w_h + Thanks JieChen91@github.com for reporting this bug! 
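+    e.g. a box centred at (100, 100) with w=40, h=20 becomes top-left (80, 90), w=40, h=20.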
+ """ + @staticmethod + def _xywh_to_tlwh(bbox_xywh): + if isinstance(bbox_xywh, np.ndarray): + bbox_tlwh = bbox_xywh.copy() + elif isinstance(bbox_xywh, torch.Tensor): + bbox_tlwh = bbox_xywh.clone() + bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2] / 2. + bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3] / 2. + return bbox_tlwh + + def _xywh_to_xyxy(self, bbox_xywh): + x, y, w, h = bbox_xywh + x1 = max(int(x - w / 2), 0) + x2 = min(int(x + w / 2), self.width - 1) + y1 = max(int(y - h / 2), 0) + y2 = min(int(y + h / 2), self.height - 1) + return x1, y1, x2, y2 + + def _tlwh_to_xyxy(self, bbox_tlwh): + """ + TODO: + Convert bbox from xtl_ytl_w_h to xc_yc_w_h + Thanks JieChen91@github.com for reporting this bug! + """ + x, y, w, h = bbox_tlwh + x1 = max(int(x), 0) + x2 = min(int(x+w), self.width - 1) + y1 = max(int(y), 0) + y2 = min(int(y+h), self.height - 1) + return x1, y1, x2, y2 + + def increment_ages(self): + self.tracker.increment_ages() + + def _xyxy_to_tlwh(self, bbox_xyxy): + x1, y1, x2, y2 = bbox_xyxy + + t = x1 + l = y1 + w = int(x2 - x1) + h = int(y2 - y1) + return t, l, w, h + + def _get_features(self, bbox_xywh, ori_img): + im_crops = [] + for box in bbox_xywh: + x1, y1, x2, y2 = self._xywh_to_xyxy(box) + im = ori_img[y1:y2, x1:x2] + im_crops.append(im) + if im_crops: + features = self.model(im_crops) + else: + features = np.array([]) + return features + + def trajectory(self, im0, q, color): + # Add rectangle to image (PIL-only) + for i, p in enumerate(q): + thickness = int(np.sqrt(float (i + 1)) * 1.5) + if p[0] == 'observationupdate': + cv2.circle(im0, p[1], 2, color=color, thickness=thickness) + else: + cv2.circle(im0, p[1], 2, color=(255,255,255), thickness=thickness) diff --git a/feeder/trackers/strongsort/utils/__init__.py b/feeder/trackers/strongsort/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/feeder/trackers/strongsort/utils/asserts.py b/feeder/trackers/strongsort/utils/asserts.py new file mode 100644 index 0000000..59a73cc --- /dev/null +++ b/feeder/trackers/strongsort/utils/asserts.py @@ -0,0 +1,13 @@ +from os import environ + + +def assert_in(file, files_to_check): + if file not in files_to_check: + raise AssertionError("{} does not exist in the list".format(str(file))) + return True + + +def assert_in_env(check_list: list): + for item in check_list: + assert_in(item, environ.keys()) + return True diff --git a/feeder/trackers/strongsort/utils/draw.py b/feeder/trackers/strongsort/utils/draw.py new file mode 100644 index 0000000..bc7cb53 --- /dev/null +++ b/feeder/trackers/strongsort/utils/draw.py @@ -0,0 +1,36 @@ +import numpy as np +import cv2 + +palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1) + + +def compute_color_for_labels(label): + """ + Simple function that adds fixed color depending on the class + """ + color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette] + return tuple(color) + + +def draw_boxes(img, bbox, identities=None, offset=(0,0)): + for i,box in enumerate(bbox): + x1,y1,x2,y2 = [int(i) for i in box] + x1 += offset[0] + x2 += offset[0] + y1 += offset[1] + y2 += offset[1] + # box text and bar + id = int(identities[i]) if identities is not None else 0 + color = compute_color_for_labels(id) + label = '{}{:d}'.format("", id) + t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2 , 2)[0] + cv2.rectangle(img,(x1, y1),(x2,y2),color,3) + cv2.rectangle(img,(x1, y1),(x1+t_size[0]+3,y1+t_size[1]+4), color,-1) + cv2.putText(img,label,(x1,y1+t_size[1]+4), cv2.FONT_HERSHEY_PLAIN, 2, [255,255,255], 2) + 
return img + + + +if __name__ == '__main__': + for i in range(82): + print(compute_color_for_labels(i)) diff --git a/feeder/trackers/strongsort/utils/evaluation.py b/feeder/trackers/strongsort/utils/evaluation.py new file mode 100644 index 0000000..1001794 --- /dev/null +++ b/feeder/trackers/strongsort/utils/evaluation.py @@ -0,0 +1,103 @@ +import os +import numpy as np +import copy +import motmetrics as mm +mm.lap.default_solver = 'lap' +from utils.io import read_results, unzip_objs + + +class Evaluator(object): + + def __init__(self, data_root, seq_name, data_type): + self.data_root = data_root + self.seq_name = seq_name + self.data_type = data_type + + self.load_annotations() + self.reset_accumulator() + + def load_annotations(self): + assert self.data_type == 'mot' + + gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', 'gt.txt') + self.gt_frame_dict = read_results(gt_filename, self.data_type, is_gt=True) + self.gt_ignore_frame_dict = read_results(gt_filename, self.data_type, is_ignore=True) + + def reset_accumulator(self): + self.acc = mm.MOTAccumulator(auto_id=True) + + def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False): + # results + trk_tlwhs = np.copy(trk_tlwhs) + trk_ids = np.copy(trk_ids) + + # gts + gt_objs = self.gt_frame_dict.get(frame_id, []) + gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2] + + # ignore boxes + ignore_objs = self.gt_ignore_frame_dict.get(frame_id, []) + ignore_tlwhs = unzip_objs(ignore_objs)[0] + + + # remove ignored results + keep = np.ones(len(trk_tlwhs), dtype=bool) + iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5) + if len(iou_distance) > 0: + match_is, match_js = mm.lap.linear_sum_assignment(iou_distance) + match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js]) + match_ious = iou_distance[match_is, match_js] + + match_js = np.asarray(match_js, dtype=int) + match_js = match_js[np.logical_not(np.isnan(match_ious))] + keep[match_js] = False + trk_tlwhs = trk_tlwhs[keep] + trk_ids = trk_ids[keep] + + # get distance matrix + iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5) + + # acc + self.acc.update(gt_ids, trk_ids, iou_distance) + + if rtn_events and iou_distance.size > 0 and hasattr(self.acc, 'last_mot_events'): + events = self.acc.last_mot_events # only supported by https://github.com/longcw/py-motmetrics + else: + events = None + return events + + def eval_file(self, filename): + self.reset_accumulator() + + result_frame_dict = read_results(filename, self.data_type, is_gt=False) + frames = sorted(list(set(self.gt_frame_dict.keys()) | set(result_frame_dict.keys()))) + for frame_id in frames: + trk_objs = result_frame_dict.get(frame_id, []) + trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2] + self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False) + + return self.acc + + @staticmethod + def get_summary(accs, names, metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1', 'precision', 'recall')): + names = copy.deepcopy(names) + if metrics is None: + metrics = mm.metrics.motchallenge_metrics + metrics = copy.deepcopy(metrics) + + mh = mm.metrics.create() + summary = mh.compute_many( + accs, + metrics=metrics, + names=names, + generate_overall=True + ) + + return summary + + @staticmethod + def save_summary(summary, filename): + import pandas as pd + writer = pd.ExcelWriter(filename) + summary.to_excel(writer) + writer.save() diff --git a/feeder/trackers/strongsort/utils/io.py b/feeder/trackers/strongsort/utils/io.py new file mode 100644 index 
0000000..2dc9afd --- /dev/null +++ b/feeder/trackers/strongsort/utils/io.py @@ -0,0 +1,133 @@ +import os +from typing import Dict +import numpy as np + +# from utils.log import get_logger + + +def write_results(filename, results, data_type): + if data_type == 'mot': + save_format = '{frame},{id},{x1},{y1},{w},{h},-1,-1,-1,-1\n' + elif data_type == 'kitti': + save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n' + else: + raise ValueError(data_type) + + with open(filename, 'w') as f: + for frame_id, tlwhs, track_ids in results: + if data_type == 'kitti': + frame_id -= 1 + for tlwh, track_id in zip(tlwhs, track_ids): + if track_id < 0: + continue + x1, y1, w, h = tlwh + x2, y2 = x1 + w, y1 + h + line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h) + f.write(line) + + +# def write_results(filename, results_dict: Dict, data_type: str): +# if not filename: +# return +# path = os.path.dirname(filename) +# if not os.path.exists(path): +# os.makedirs(path) + +# if data_type in ('mot', 'mcmot', 'lab'): +# save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n' +# elif data_type == 'kitti': +# save_format = '{frame} {id} pedestrian -1 -1 -10 {x1} {y1} {x2} {y2} -1 -1 -1 -1000 -1000 -1000 -10 {score}\n' +# else: +# raise ValueError(data_type) + +# with open(filename, 'w') as f: +# for frame_id, frame_data in results_dict.items(): +# if data_type == 'kitti': +# frame_id -= 1 +# for tlwh, track_id in frame_data: +# if track_id < 0: +# continue +# x1, y1, w, h = tlwh +# x2, y2 = x1 + w, y1 + h +# line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h, score=1.0) +# f.write(line) +# logger.info('Save results to {}'.format(filename)) + + +def read_results(filename, data_type: str, is_gt=False, is_ignore=False): + if data_type in ('mot', 'lab'): + read_fun = read_mot_results + else: + raise ValueError('Unknown data type: {}'.format(data_type)) + + return read_fun(filename, is_gt, is_ignore) + + +""" +labels={'ped', ... % 1 +'person_on_vhcl', ... % 2 +'car', ... % 3 +'bicycle', ... % 4 +'mbike', ... % 5 +'non_mot_vhcl', ... % 6 +'static_person', ... % 7 +'distractor', ... % 8 +'occluder', ... % 9 +'occluder_on_grnd', ... %10 +'occluder_full', ... % 11 +'reflection', ... % 12 +'crowd' ... 
% 13 +}; +""" + + +def read_mot_results(filename, is_gt, is_ignore): + valid_labels = {1} + ignore_labels = {2, 7, 8, 12} + results_dict = dict() + if os.path.isfile(filename): + with open(filename, 'r') as f: + for line in f.readlines(): + linelist = line.split(',') + if len(linelist) < 7: + continue + fid = int(linelist[0]) + if fid < 1: + continue + results_dict.setdefault(fid, list()) + + if is_gt: + if 'MOT16-' in filename or 'MOT17-' in filename: + label = int(float(linelist[7])) + mark = int(float(linelist[6])) + if mark == 0 or label not in valid_labels: + continue + score = 1 + elif is_ignore: + if 'MOT16-' in filename or 'MOT17-' in filename: + label = int(float(linelist[7])) + vis_ratio = float(linelist[8]) + if label not in ignore_labels and vis_ratio >= 0: + continue + else: + continue + score = 1 + else: + score = float(linelist[6]) + + tlwh = tuple(map(float, linelist[2:6])) + target_id = int(linelist[1]) + + results_dict[fid].append((tlwh, target_id, score)) + + return results_dict + + +def unzip_objs(objs): + if len(objs) > 0: + tlwhs, ids, scores = zip(*objs) + else: + tlwhs, ids, scores = [], [], [] + tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4) + + return tlwhs, ids, scores \ No newline at end of file diff --git a/feeder/trackers/strongsort/utils/json_logger.py b/feeder/trackers/strongsort/utils/json_logger.py new file mode 100644 index 0000000..0afd0b4 --- /dev/null +++ b/feeder/trackers/strongsort/utils/json_logger.py @@ -0,0 +1,383 @@ +""" +References: + https://medium.com/analytics-vidhya/creating-a-custom-logging-mechanism-for-real-time-object-detection-using-tdd-4ca2cfcd0a2f +""" +import json +from os import makedirs +from os.path import exists, join +from datetime import datetime + + +class JsonMeta(object): + HOURS = 3 + MINUTES = 59 + SECONDS = 59 + PATH_TO_SAVE = 'LOGS' + DEFAULT_FILE_NAME = 'remaining' + + +class BaseJsonLogger(object): + """ + This is the base class that returns __dict__ of its own + it also returns the dicts of objects in the attributes that are list instances + + """ + + def dic(self): + # returns dicts of objects + out = {} + for k, v in self.__dict__.items(): + if hasattr(v, 'dic'): + out[k] = v.dic() + elif isinstance(v, list): + out[k] = self.list(v) + else: + out[k] = v + return out + + @staticmethod + def list(values): + # applies the dic method on items in the list + return [v.dic() if hasattr(v, 'dic') else v for v in values] + + +class Label(BaseJsonLogger): + """ + For each bounding box there are various categories with confidences. Label class keeps track of that information. + """ + + def __init__(self, category: str, confidence: float): + self.category = category + self.confidence = confidence + + +class Bbox(BaseJsonLogger): + """ + This module stores the information for each frame and use them in JsonParser + Attributes: + labels (list): List of label module. + top (int): + left (int): + width (int): + height (int): + + Args: + bbox_id (float): + top (int): + left (int): + width (int): + height (int): + + References: + Check Label module for better understanding. + + + """ + + def __init__(self, bbox_id, top, left, width, height): + self.labels = [] + self.bbox_id = bbox_id + self.top = top + self.left = left + self.width = width + self.height = height + + def add_label(self, category, confidence): + # adds category and confidence only if top_k is not exceeded. 
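+        # (the top_k limit itself is enforced by the caller, e.g.
+        # BboxToJsonLogger.add_label_to_bbox checks labels_full() first;
+        # this method appends unconditionally)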
+        self.labels.append(Label(category, confidence))
+
+    def labels_full(self, value):
+        return len(self.labels) == value
+
+
+class Frame(BaseJsonLogger):
+    """
+    This module stores the information for each frame and uses it in JsonParser
+    Attributes:
+        timestamp (float): The elapsed time of the captured frame
+        frame_id (int): The frame number of the captured video
+        bboxes (list of Bbox objects): Stores the list of bbox objects.
+
+    References:
+        Check Bbox class for better information
+
+    Args:
+        timestamp (float):
+        frame_id (int):
+
+    """
+
+    def __init__(self, frame_id: int, timestamp: float = None):
+        self.frame_id = frame_id
+        self.timestamp = timestamp
+        self.bboxes = []
+
+    def add_bbox(self, bbox_id: int, top: int, left: int, width: int, height: int):
+        bboxes_ids = [bbox.bbox_id for bbox in self.bboxes]
+        if bbox_id not in bboxes_ids:
+            self.bboxes.append(Bbox(bbox_id, top, left, width, height))
+        else:
+            raise ValueError("Frame with id: {} already has a Bbox with id: {}".format(self.frame_id, bbox_id))
+
+    def add_label_to_bbox(self, bbox_id: int, category: str, confidence: float):
+        bboxes = {bbox.bbox_id: bbox for bbox in self.bboxes}
+        if bbox_id in bboxes.keys():
+            res = bboxes.get(bbox_id)
+            res.add_label(category, confidence)
+        else:
+            raise ValueError('the bbox with id: {} does not exist!'.format(bbox_id))
+
+
+class BboxToJsonLogger(BaseJsonLogger):
+    """
+    This module is designed to automate the task of logging jsons. An example json is used
+    to show the contents of the json file briefly.
+    Example:
+          {
+            "video_details": {
+                "frame_width": 1920,
+                "frame_height": 1080,
+                "frame_rate": 20,
+                "video_name": "/home/gpu/codes/MSD/pedestrian_2/project/public/camera1.avi"
+            },
+            "frames": [
+                {
+                  "frame_id": 329,
+                  "timestamp": 3365.1254,
+                  "bboxes": [
+                    {
+                      "labels": [
+                        {
+                          "category": "pedestrian",
+                          "confidence": 0.9
+                        }
+                      ],
+                      "bbox_id": 0,
+                      "top": 1257,
+                      "left": 138,
+                      "width": 68,
+                      "height": 109
+                    }
+                  ]
+                }],
+
+    Attributes:
+        frames (dict): It's a dictionary that maps each frame_id to json attributes.
+        video_details (dict): information about video file.
+        top_k_labels (int): shows the allowed number of labels
+        start_time (datetime object): we use it to automate the json output by time.
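+
+    Minimal usage sketch (values are illustrative only):
+        logger = BboxToJsonLogger(top_k_labels=1)
+        logger.add_video_details(frame_width=1920, frame_height=1080,
+                                 frame_rate=20, video_name='camera1.avi')
+        logger.add_frame(frame_id=1, timestamp=0.05)
+        logger.add_bbox_to_frame(frame_id=1, bbox_id=0, top=1257, left=138,
+                                 width=68, height=109)
+        logger.add_label_to_bbox(frame_id=1, bbox_id=0, category='pedestrian',
+                                 confidence=0.9)
+        logger.json_output('camera1')   # writes camera1.json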
+ + Args: + top_k_labels (int): shows the allowed number of labels + + """ + + def __init__(self, top_k_labels: int = 1): + self.frames = {} + self.video_details = self.video_details = dict(frame_width=None, frame_height=None, frame_rate=None, + video_name=None) + self.top_k_labels = top_k_labels + self.start_time = datetime.now() + + def set_top_k(self, value): + self.top_k_labels = value + + def frame_exists(self, frame_id: int) -> bool: + """ + Args: + frame_id (int): + + Returns: + bool: true if frame_id is recognized + """ + return frame_id in self.frames.keys() + + def add_frame(self, frame_id: int, timestamp: float = None) -> None: + """ + Args: + frame_id (int): + timestamp (float): opencv captured frame time property + + Raises: + ValueError: if frame_id would not exist in class frames attribute + + Returns: + None + + """ + if not self.frame_exists(frame_id): + self.frames[frame_id] = Frame(frame_id, timestamp) + else: + raise ValueError("Frame id: {} already exists".format(frame_id)) + + def bbox_exists(self, frame_id: int, bbox_id: int) -> bool: + """ + Args: + frame_id: + bbox_id: + + Returns: + bool: if bbox exists in frame bboxes list + """ + bboxes = [] + if self.frame_exists(frame_id=frame_id): + bboxes = [bbox.bbox_id for bbox in self.frames[frame_id].bboxes] + return bbox_id in bboxes + + def find_bbox(self, frame_id: int, bbox_id: int): + """ + + Args: + frame_id: + bbox_id: + + Returns: + bbox_id (int): + + Raises: + ValueError: if bbox_id does not exist in the bbox list of specific frame. + """ + if not self.bbox_exists(frame_id, bbox_id): + raise ValueError("frame with id: {} does not contain bbox with id: {}".format(frame_id, bbox_id)) + bboxes = {bbox.bbox_id: bbox for bbox in self.frames[frame_id].bboxes} + return bboxes.get(bbox_id) + + def add_bbox_to_frame(self, frame_id: int, bbox_id: int, top: int, left: int, width: int, height: int) -> None: + """ + + Args: + frame_id (int): + bbox_id (int): + top (int): + left (int): + width (int): + height (int): + + Returns: + None + + Raises: + ValueError: if bbox_id already exist in frame information with frame_id + ValueError: if frame_id does not exist in frames attribute + """ + if self.frame_exists(frame_id): + frame = self.frames[frame_id] + if not self.bbox_exists(frame_id, bbox_id): + frame.add_bbox(bbox_id, top, left, width, height) + else: + raise ValueError( + "frame with frame_id: {} already contains the bbox with id: {} ".format(frame_id, bbox_id)) + else: + raise ValueError("frame with frame_id: {} does not exist".format(frame_id)) + + def add_label_to_bbox(self, frame_id: int, bbox_id: int, category: str, confidence: float): + """ + Args: + frame_id: + bbox_id: + category: + confidence: the confidence value returned from yolo detection + + Returns: + None + + Raises: + ValueError: if labels quota (top_k_labels) exceeds. 
+ """ + bbox = self.find_bbox(frame_id, bbox_id) + if not bbox.labels_full(self.top_k_labels): + bbox.add_label(category, confidence) + else: + raise ValueError("labels in frame_id: {}, bbox_id: {} is fulled".format(frame_id, bbox_id)) + + def add_video_details(self, frame_width: int = None, frame_height: int = None, frame_rate: int = None, + video_name: str = None): + self.video_details['frame_width'] = frame_width + self.video_details['frame_height'] = frame_height + self.video_details['frame_rate'] = frame_rate + self.video_details['video_name'] = video_name + + def output(self): + output = {'video_details': self.video_details} + result = list(self.frames.values()) + output['frames'] = [item.dic() for item in result] + return output + + def json_output(self, output_name): + """ + Args: + output_name: + + Returns: + None + + Notes: + It creates the json output with `output_name` name. + """ + if not output_name.endswith('.json'): + output_name += '.json' + with open(output_name, 'w') as file: + json.dump(self.output(), file) + file.close() + + def set_start(self): + self.start_time = datetime.now() + + def schedule_output_by_time(self, output_dir=JsonMeta.PATH_TO_SAVE, hours: int = 0, minutes: int = 0, + seconds: int = 60) -> None: + """ + Notes: + Creates folder and then periodically stores the jsons on that address. + + Args: + output_dir (str): the directory where output files will be stored + hours (int): + minutes (int): + seconds (int): + + Returns: + None + + """ + end = datetime.now() + interval = 0 + interval += abs(min([hours, JsonMeta.HOURS]) * 3600) + interval += abs(min([minutes, JsonMeta.MINUTES]) * 60) + interval += abs(min([seconds, JsonMeta.SECONDS])) + diff = (end - self.start_time).seconds + + if diff > interval: + output_name = self.start_time.strftime('%Y-%m-%d %H-%M-%S') + '.json' + if not exists(output_dir): + makedirs(output_dir) + output = join(output_dir, output_name) + self.json_output(output_name=output) + self.frames = {} + self.start_time = datetime.now() + + def schedule_output_by_frames(self, frames_quota, frame_counter, output_dir=JsonMeta.PATH_TO_SAVE): + """ + saves as the number of frames quota increases higher. + :param frames_quota: + :param frame_counter: + :param output_dir: + :return: + """ + pass + + def flush(self, output_dir): + """ + Notes: + We use this function to output jsons whenever possible. + like the time that we exit the while loop of opencv. 
+ + Args: + output_dir: + + Returns: + None + + """ + filename = self.start_time.strftime('%Y-%m-%d %H-%M-%S') + '-remaining.json' + output = join(output_dir, filename) + self.json_output(output_name=output) diff --git a/feeder/trackers/strongsort/utils/log.py b/feeder/trackers/strongsort/utils/log.py new file mode 100644 index 0000000..0d48757 --- /dev/null +++ b/feeder/trackers/strongsort/utils/log.py @@ -0,0 +1,17 @@ +import logging + + +def get_logger(name='root'): + formatter = logging.Formatter( + # fmt='%(asctime)s [%(levelname)s]: %(filename)s(%(funcName)s:%(lineno)s) >> %(message)s') + fmt='%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') + + handler = logging.StreamHandler() + handler.setFormatter(formatter) + + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + logger.addHandler(handler) + return logger + + diff --git a/feeder/trackers/strongsort/utils/parser.py b/feeder/trackers/strongsort/utils/parser.py new file mode 100644 index 0000000..c29ed84 --- /dev/null +++ b/feeder/trackers/strongsort/utils/parser.py @@ -0,0 +1,41 @@ +import os +import yaml +from easydict import EasyDict as edict + + +class YamlParser(edict): + """ + This is yaml parser based on EasyDict. + """ + + def __init__(self, cfg_dict=None, config_file=None): + if cfg_dict is None: + cfg_dict = {} + + if config_file is not None: + assert(os.path.isfile(config_file)) + with open(config_file, 'r') as fo: + yaml_ = yaml.load(fo.read(), Loader=yaml.FullLoader) + cfg_dict.update(yaml_) + + super(YamlParser, self).__init__(cfg_dict) + + def merge_from_file(self, config_file): + with open(config_file, 'r') as fo: + yaml_ = yaml.load(fo.read(), Loader=yaml.FullLoader) + self.update(yaml_) + + def merge_from_dict(self, config_dict): + self.update(config_dict) + + +def get_config(config_file=None): + return YamlParser(config_file=config_file) + + +if __name__ == "__main__": + cfg = YamlParser(config_file="../configs/yolov3.yaml") + cfg.merge_from_file("../configs/strong_sort.yaml") + + import ipdb + ipdb.set_trace() diff --git a/feeder/trackers/strongsort/utils/tools.py b/feeder/trackers/strongsort/utils/tools.py new file mode 100644 index 0000000..965fb69 --- /dev/null +++ b/feeder/trackers/strongsort/utils/tools.py @@ -0,0 +1,39 @@ +from functools import wraps +from time import time + + +def is_video(ext: str): + """ + Returns true if ext exists in + allowed_exts for video files. + + Args: + ext: + + Returns: + + """ + + allowed_exts = ('.mp4', '.webm', '.ogg', '.avi', '.wmv', '.mkv', '.3gp') + return any((ext.endswith(x) for x in allowed_exts)) + + +def tik_tok(func): + """ + keep track of time for each process. 
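+    Illustrative use: decorate any function with @tik_tok and its wall-clock
+    time and FPS are printed after every call.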
+ Args: + func: + + Returns: + + """ + @wraps(func) + def _time_it(*args, **kwargs): + start = time() + try: + return func(*args, **kwargs) + finally: + end_ = time() + print("time: {:.03f}s, fps: {:.03f}".format(end_ - start, 1 / (end_ - start))) + + return _time_it diff --git a/feeder/video/sample.mp4 b/feeder/video/sample.mp4 new file mode 100644 index 0000000..d89256d Binary files /dev/null and b/feeder/video/sample.mp4 differ diff --git a/feeder/video/sample.mp4-strongsort.log b/feeder/video/sample.mp4-strongsort.log new file mode 100644 index 0000000..2126c5e --- /dev/null +++ b/feeder/video/sample.mp4-strongsort.log @@ -0,0 +1,612 @@ +{"bbox": [1208, 574, 1312, 640], "id": 1, "cls": 2, "conf": 0.7392573952674866, "frame_idx": 2, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1206, 573, 1311, 639], "id": 1, "cls": 2, "conf": 0.7638279795646667, "frame_idx": 3, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1205, 573, 1310, 640], "id": 1, "cls": 2, "conf": 0.745888352394104, "frame_idx": 4, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1205, 572, 1310, 640], "id": 1, "cls": 2, "conf": 0.7273551821708679, "frame_idx": 5, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1204, 572, 1310, 641], "id": 1, "cls": 2, "conf": 0.7593294382095337, "frame_idx": 6, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1203, 571, 1309, 641], "id": 1, "cls": 2, "conf": 0.7566904425621033, "frame_idx": 7, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1202, 570, 1309, 642], "id": 1, "cls": 2, "conf": 0.7727674245834351, "frame_idx": 8, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1202, 570, 1308, 642], "id": 1, "cls": 2, "conf": 0.7940199375152588, "frame_idx": 9, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1200, 570, 1308, 642], "id": 1, "cls": 2, "conf": 0.7740529179573059, "frame_idx": 10, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1200, 570, 1308, 642], "id": 1, "cls": 2, "conf": 0.7652700543403625, "frame_idx": 11, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1201, 571, 1307, 642], "id": 1, "cls": 2, "conf": 0.8012721538543701, "frame_idx": 12, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1200, 570, 1309, 642], "id": 1, "cls": 2, "conf": 0.7976530194282532, "frame_idx": 13, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1199, 569, 1311, 643], "id": 1, "cls": 2, "conf": 0.812846302986145, "frame_idx": 14, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1198, 570, 1310, 643], "id": 1, "cls": 2, "conf": 0.8232163190841675, "frame_idx": 15, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1194, 569, 1309, 644], "id": 1, "cls": 2, "conf": 0.8198840022087097, "frame_idx": 16, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1195, 569, 1306, 643], "id": 1, "cls": 2, "conf": 0.7693840861320496, "frame_idx": 17, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1193, 569, 1305, 645], "id": 1, "cls": 2, "conf": 0.7881284356117249, "frame_idx": 18, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1192, 570, 1305, 645], "id": 1, "cls": 2, "conf": 0.8157638311386108, "frame_idx": 19, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1192, 570, 1305, 644], "id": 1, "cls": 2, "conf": 0.8246914744377136, "frame_idx": 20, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1190, 569, 1305, 645], "id": 1, "cls": 2, "conf": 0.828994631767273, 
"frame_idx": 21, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1190, 569, 1304, 644], "id": 1, "cls": 2, "conf": 0.8013927936553955, "frame_idx": 22, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1190, 568, 1303, 644], "id": 1, "cls": 2, "conf": 0.8276790380477905, "frame_idx": 23, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1188, 568, 1304, 645], "id": 1, "cls": 2, "conf": 0.8594380021095276, "frame_idx": 24, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1186, 568, 1304, 645], "id": 1, "cls": 2, "conf": 0.8706213235855103, "frame_idx": 25, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1187, 568, 1303, 644], "id": 1, "cls": 2, "conf": 0.8731331825256348, "frame_idx": 26, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1182, 568, 1303, 645], "id": 1, "cls": 2, "conf": 0.87749844789505, "frame_idx": 27, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1182, 569, 1302, 645], "id": 1, "cls": 2, "conf": 0.8746338486671448, "frame_idx": 28, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1181, 568, 1303, 646], "id": 1, "cls": 2, "conf": 0.8688514828681946, "frame_idx": 29, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1180, 569, 1301, 646], "id": 1, "cls": 2, "conf": 0.8689095973968506, "frame_idx": 30, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1179, 568, 1302, 647], "id": 1, "cls": 2, "conf": 0.8720865249633789, "frame_idx": 31, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1178, 568, 1301, 647], "id": 1, "cls": 2, "conf": 0.8609508275985718, "frame_idx": 32, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1177, 568, 1300, 647], "id": 1, "cls": 2, "conf": 0.8541733026504517, "frame_idx": 33, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1178, 569, 1299, 648], "id": 1, "cls": 2, "conf": 0.8305150270462036, "frame_idx": 34, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1177, 569, 1297, 647], "id": 1, "cls": 2, "conf": 0.8163544535636902, "frame_idx": 35, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1175, 568, 1298, 648], "id": 1, "cls": 2, "conf": 0.8103095293045044, "frame_idx": 36, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1174, 568, 1297, 648], "id": 1, "cls": 2, "conf": 0.8175411820411682, "frame_idx": 37, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1171, 569, 1297, 648], "id": 1, "cls": 2, "conf": 0.8210935592651367, "frame_idx": 38, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1171, 568, 1295, 648], "id": 1, "cls": 2, "conf": 0.8320956826210022, "frame_idx": 39, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1167, 568, 1294, 649], "id": 1, "cls": 2, "conf": 0.7790266275405884, "frame_idx": 40, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1166, 568, 1293, 648], "id": 1, "cls": 2, "conf": 0.7791686058044434, "frame_idx": 41, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1166, 568, 1292, 648], "id": 1, "cls": 2, "conf": 0.7617875933647156, "frame_idx": 42, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1164, 567, 1293, 649], "id": 1, "cls": 2, "conf": 0.7618439793586731, "frame_idx": 43, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1162, 567, 1293, 649], "id": 1, "cls": 2, "conf": 0.7654961347579956, "frame_idx": 44, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1161, 567, 1292, 649], "id": 1, "cls": 2, "conf": 
0.7552655935287476, "frame_idx": 45, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1160, 568, 1290, 649], "id": 1, "cls": 2, "conf": 0.7659391164779663, "frame_idx": 46, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1158, 570, 1289, 650], "id": 1, "cls": 2, "conf": 0.7770782709121704, "frame_idx": 47, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1156, 569, 1290, 651], "id": 1, "cls": 2, "conf": 0.776265025138855, "frame_idx": 48, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1156, 568, 1289, 649], "id": 1, "cls": 2, "conf": 0.7784299850463867, "frame_idx": 49, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1153, 567, 1289, 650], "id": 1, "cls": 2, "conf": 0.7925119400024414, "frame_idx": 50, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1154, 568, 1290, 651], "id": 1, "cls": 2, "conf": 0.7904253005981445, "frame_idx": 51, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1152, 569, 1291, 651], "id": 1, "cls": 2, "conf": 0.7655163407325745, "frame_idx": 52, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1151, 569, 1291, 651], "id": 1, "cls": 2, "conf": 0.7518490552902222, "frame_idx": 53, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1149, 569, 1289, 652], "id": 1, "cls": 2, "conf": 0.7494193911552429, "frame_idx": 54, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1147, 570, 1289, 654], "id": 1, "cls": 2, "conf": 0.7891559600830078, "frame_idx": 55, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1145, 570, 1289, 655], "id": 1, "cls": 2, "conf": 0.7939369082450867, "frame_idx": 56, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1142, 569, 1289, 656], "id": 1, "cls": 2, "conf": 0.8129497170448303, "frame_idx": 57, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1141, 570, 1287, 656], "id": 1, "cls": 2, "conf": 0.8340080380439758, "frame_idx": 58, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1140, 569, 1288, 657], "id": 1, "cls": 2, "conf": 0.8393167853355408, "frame_idx": 59, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1141, 570, 1287, 657], "id": 1, "cls": 2, "conf": 0.8389145135879517, "frame_idx": 60, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1139, 569, 1285, 658], "id": 1, "cls": 2, "conf": 0.8342702388763428, "frame_idx": 61, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1138, 570, 1284, 658], "id": 1, "cls": 2, "conf": 0.8394166827201843, "frame_idx": 62, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1135, 569, 1284, 658], "id": 1, "cls": 2, "conf": 0.8471781611442566, "frame_idx": 63, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1131, 568, 1281, 659], "id": 1, "cls": 2, "conf": 0.8232806921005249, "frame_idx": 64, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1129, 568, 1279, 660], "id": 1, "cls": 2, "conf": 0.865515410900116, "frame_idx": 65, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1128, 569, 1282, 661], "id": 1, "cls": 2, "conf": 0.8378810882568359, "frame_idx": 66, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1126, 569, 1282, 661], "id": 1, "cls": 2, "conf": 0.8417340517044067, "frame_idx": 67, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1126, 569, 1281, 661], "id": 1, "cls": 2, "conf": 0.8533654808998108, "frame_idx": 68, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1125, 569, 1281, 660], "id": 
1, "cls": 2, "conf": 0.8475178480148315, "frame_idx": 69, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1123, 569, 1280, 661], "id": 1, "cls": 2, "conf": 0.8625006675720215, "frame_idx": 70, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1120, 568, 1278, 662], "id": 1, "cls": 2, "conf": 0.8567495346069336, "frame_idx": 71, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1120, 569, 1276, 663], "id": 1, "cls": 2, "conf": 0.8443597555160522, "frame_idx": 72, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1118, 568, 1276, 663], "id": 1, "cls": 2, "conf": 0.8420413732528687, "frame_idx": 73, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1115, 567, 1276, 663], "id": 1, "cls": 2, "conf": 0.8549453020095825, "frame_idx": 74, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1115, 567, 1275, 664], "id": 1, "cls": 2, "conf": 0.8429552316665649, "frame_idx": 75, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1112, 567, 1273, 665], "id": 1, "cls": 2, "conf": 0.8485922813415527, "frame_idx": 76, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1111, 567, 1273, 666], "id": 1, "cls": 2, "conf": 0.8699796199798584, "frame_idx": 77, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1109, 565, 1273, 666], "id": 1, "cls": 2, "conf": 0.8823856115341187, "frame_idx": 78, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1107, 564, 1274, 667], "id": 1, "cls": 2, "conf": 0.8547831177711487, "frame_idx": 79, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1106, 565, 1271, 667], "id": 1, "cls": 2, "conf": 0.8556330800056458, "frame_idx": 80, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1105, 564, 1271, 667], "id": 1, "cls": 2, "conf": 0.8522816896438599, "frame_idx": 81, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1103, 562, 1271, 668], "id": 1, "cls": 2, "conf": 0.8402776718139648, "frame_idx": 82, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1098, 561, 1272, 669], "id": 1, "cls": 2, "conf": 0.849938154220581, "frame_idx": 83, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1095, 561, 1272, 669], "id": 1, "cls": 2, "conf": 0.8956634998321533, "frame_idx": 84, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1092, 561, 1272, 670], "id": 1, "cls": 2, "conf": 0.9015648365020752, "frame_idx": 85, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1093, 562, 1271, 670], "id": 1, "cls": 2, "conf": 0.8583961725234985, "frame_idx": 86, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1091, 562, 1271, 672], "id": 1, "cls": 2, "conf": 0.8442841172218323, "frame_idx": 87, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1089, 562, 1270, 672], "id": 1, "cls": 2, "conf": 0.8542094230651855, "frame_idx": 88, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1085, 560, 1267, 672], "id": 1, "cls": 2, "conf": 0.8753722310066223, "frame_idx": 89, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1081, 559, 1266, 673], "id": 1, "cls": 2, "conf": 0.8686020970344543, "frame_idx": 90, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1079, 558, 1266, 673], "id": 1, "cls": 2, "conf": 0.8676679134368896, "frame_idx": 91, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1075, 558, 1265, 674], "id": 1, "cls": 2, "conf": 0.8485567569732666, "frame_idx": 92, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1074, 
558, 1264, 674], "id": 1, "cls": 2, "conf": 0.8431268334388733, "frame_idx": 93, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1074, 557, 1264, 674], "id": 1, "cls": 2, "conf": 0.8517748713493347, "frame_idx": 94, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1070, 559, 1262, 675], "id": 1, "cls": 2, "conf": 0.8630310297012329, "frame_idx": 95, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1068, 559, 1260, 676], "id": 1, "cls": 2, "conf": 0.8517524003982544, "frame_idx": 96, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1065, 557, 1260, 676], "id": 1, "cls": 2, "conf": 0.8309876918792725, "frame_idx": 97, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1062, 558, 1257, 676], "id": 1, "cls": 2, "conf": 0.820047914981842, "frame_idx": 98, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1058, 558, 1258, 680], "id": 1, "cls": 2, "conf": 0.8312326073646545, "frame_idx": 99, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1057, 557, 1255, 681], "id": 1, "cls": 2, "conf": 0.84773850440979, "frame_idx": 100, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1055, 558, 1253, 682], "id": 1, "cls": 2, "conf": 0.8278942108154297, "frame_idx": 101, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1052, 557, 1254, 682], "id": 1, "cls": 2, "conf": 0.8419964909553528, "frame_idx": 102, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1048, 554, 1253, 682], "id": 1, "cls": 2, "conf": 0.8698597550392151, "frame_idx": 103, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1045, 553, 1251, 683], "id": 1, "cls": 2, "conf": 0.8451534509658813, "frame_idx": 104, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1041, 553, 1250, 685], "id": 1, "cls": 2, "conf": 0.8478474617004395, "frame_idx": 105, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1037, 552, 1250, 685], "id": 1, "cls": 2, "conf": 0.8371977210044861, "frame_idx": 106, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1034, 552, 1249, 686], "id": 1, "cls": 2, "conf": 0.8587230443954468, "frame_idx": 107, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1032, 552, 1246, 687], "id": 1, "cls": 2, "conf": 0.8486429452896118, "frame_idx": 108, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1026, 552, 1246, 688], "id": 1, "cls": 2, "conf": 0.8577057123184204, "frame_idx": 109, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1024, 551, 1244, 687], "id": 1, "cls": 2, "conf": 0.847007155418396, "frame_idx": 110, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1020, 551, 1244, 689], "id": 1, "cls": 2, "conf": 0.8531818985939026, "frame_idx": 111, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1014, 550, 1245, 691], "id": 1, "cls": 2, "conf": 0.8777499794960022, "frame_idx": 112, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1011, 550, 1242, 692], "id": 1, "cls": 2, "conf": 0.8970717787742615, "frame_idx": 113, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1009, 550, 1241, 694], "id": 1, "cls": 2, "conf": 0.8887585401535034, "frame_idx": 114, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1007, 549, 1239, 695], "id": 1, "cls": 2, "conf": 0.8952226638793945, "frame_idx": 115, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1002, 549, 1240, 698], "id": 1, "cls": 2, "conf": 0.9019944667816162, "frame_idx": 116, "source": "video/sample.mp4", 
"class_name": "car"} +{"bbox": [1000, 550, 1237, 699], "id": 1, "cls": 2, "conf": 0.8975278735160828, "frame_idx": 117, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [993, 549, 1237, 700], "id": 1, "cls": 2, "conf": 0.9004268646240234, "frame_idx": 118, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [988, 550, 1233, 701], "id": 1, "cls": 2, "conf": 0.8971960544586182, "frame_idx": 119, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [986, 549, 1231, 702], "id": 1, "cls": 2, "conf": 0.8989416360855103, "frame_idx": 120, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [980, 548, 1229, 704], "id": 1, "cls": 2, "conf": 0.889881432056427, "frame_idx": 121, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [975, 548, 1228, 708], "id": 1, "cls": 2, "conf": 0.8943332433700562, "frame_idx": 122, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [971, 548, 1228, 710], "id": 1, "cls": 2, "conf": 0.898472785949707, "frame_idx": 123, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [967, 547, 1226, 712], "id": 1, "cls": 2, "conf": 0.8931097388267517, "frame_idx": 124, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [963, 546, 1225, 713], "id": 1, "cls": 2, "conf": 0.8915606141090393, "frame_idx": 125, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [959, 546, 1223, 715], "id": 1, "cls": 2, "conf": 0.8841129541397095, "frame_idx": 126, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [955, 546, 1223, 717], "id": 1, "cls": 2, "conf": 0.850002646446228, "frame_idx": 127, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [950, 545, 1221, 718], "id": 1, "cls": 2, "conf": 0.8723787069320679, "frame_idx": 128, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [942, 544, 1220, 719], "id": 1, "cls": 2, "conf": 0.8795301914215088, "frame_idx": 129, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [940, 544, 1217, 720], "id": 1, "cls": 2, "conf": 0.8854840993881226, "frame_idx": 130, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [932, 543, 1217, 722], "id": 1, "cls": 2, "conf": 0.8812260031700134, "frame_idx": 131, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [927, 544, 1217, 725], "id": 1, "cls": 2, "conf": 0.8683909773826599, "frame_idx": 132, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [918, 543, 1216, 727], "id": 1, "cls": 2, "conf": 0.853493869304657, "frame_idx": 133, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [914, 543, 1214, 728], "id": 1, "cls": 2, "conf": 0.8531240224838257, "frame_idx": 134, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [908, 543, 1213, 730], "id": 1, "cls": 2, "conf": 0.8651628494262695, "frame_idx": 135, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [902, 542, 1209, 732], "id": 1, "cls": 2, "conf": 0.8718039989471436, "frame_idx": 136, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [894, 541, 1208, 735], "id": 1, "cls": 2, "conf": 0.848781943321228, "frame_idx": 137, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [888, 541, 1206, 736], "id": 1, "cls": 2, "conf": 0.8739963173866272, "frame_idx": 138, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [881, 541, 1204, 737], "id": 1, "cls": 2, "conf": 0.8722886443138123, "frame_idx": 139, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [872, 539, 1203, 738], "id": 1, "cls": 2, "conf": 0.8997212052345276, "frame_idx": 140, "source": 
"video/sample.mp4", "class_name": "car"} +{"bbox": [866, 539, 1200, 739], "id": 1, "cls": 2, "conf": 0.8821484446525574, "frame_idx": 141, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [860, 538, 1198, 744], "id": 1, "cls": 2, "conf": 0.8928354978561401, "frame_idx": 142, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [852, 536, 1197, 746], "id": 1, "cls": 2, "conf": 0.8943573832511902, "frame_idx": 143, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [843, 537, 1195, 748], "id": 1, "cls": 2, "conf": 0.8848525285720825, "frame_idx": 144, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [835, 536, 1194, 749], "id": 1, "cls": 2, "conf": 0.8749076724052429, "frame_idx": 145, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [826, 536, 1190, 751], "id": 1, "cls": 2, "conf": 0.8655844330787659, "frame_idx": 146, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [818, 538, 1186, 757], "id": 1, "cls": 2, "conf": 0.8978791236877441, "frame_idx": 147, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [810, 536, 1184, 759], "id": 1, "cls": 2, "conf": 0.9050822257995605, "frame_idx": 148, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [801, 533, 1181, 758], "id": 1, "cls": 2, "conf": 0.9211980104446411, "frame_idx": 149, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [791, 532, 1180, 762], "id": 1, "cls": 2, "conf": 0.9195648431777954, "frame_idx": 150, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [781, 530, 1177, 770], "id": 1, "cls": 2, "conf": 0.9223189353942871, "frame_idx": 151, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [769, 530, 1177, 772], "id": 1, "cls": 2, "conf": 0.9049766063690186, "frame_idx": 152, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [760, 528, 1175, 772], "id": 1, "cls": 2, "conf": 0.9004610776901245, "frame_idx": 153, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [749, 528, 1174, 776], "id": 1, "cls": 2, "conf": 0.9073677062988281, "frame_idx": 154, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [738, 526, 1171, 783], "id": 1, "cls": 2, "conf": 0.9120516777038574, "frame_idx": 155, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1254, 566, 1426, 643], "id": 2, "cls": 2, "conf": 0.702964186668396, "frame_idx": 155, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [725, 526, 1170, 785], "id": 1, "cls": 2, "conf": 0.9064223766326904, "frame_idx": 156, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1253, 568, 1422, 643], "id": 2, "cls": 2, "conf": 0.7038942575454712, "frame_idx": 156, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [712, 527, 1165, 789], "id": 1, "cls": 2, "conf": 0.9063256978988647, "frame_idx": 157, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1252, 568, 1421, 643], "id": 2, "cls": 2, "conf": 0.7038942575454712, "frame_idx": 157, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [699, 524, 1160, 793], "id": 1, "cls": 2, "conf": 0.8908406496047974, "frame_idx": 158, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [685, 524, 1159, 795], "id": 1, "cls": 2, "conf": 0.8844937682151794, "frame_idx": 159, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [673, 525, 1156, 799], "id": 1, "cls": 2, "conf": 0.8897193670272827, "frame_idx": 160, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [659, 524, 1152, 802], "id": 1, "cls": 2, "conf": 0.905559241771698, 
"frame_idx": 161, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [644, 522, 1149, 809], "id": 1, "cls": 2, "conf": 0.89296555519104, "frame_idx": 162, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [628, 522, 1146, 820], "id": 1, "cls": 2, "conf": 0.8848194479942322, "frame_idx": 163, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1242, 567, 1420, 642], "id": 2, "cls": 2, "conf": 0.717244029045105, "frame_idx": 163, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [611, 519, 1145, 821], "id": 1, "cls": 2, "conf": 0.9121138453483582, "frame_idx": 164, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1242, 568, 1418, 643], "id": 2, "cls": 2, "conf": 0.733672559261322, "frame_idx": 164, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [594, 520, 1141, 827], "id": 1, "cls": 2, "conf": 0.890241801738739, "frame_idx": 165, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1245, 569, 1416, 642], "id": 2, "cls": 2, "conf": 0.7150111794471741, "frame_idx": 165, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [574, 519, 1136, 832], "id": 1, "cls": 2, "conf": 0.9198168516159058, "frame_idx": 166, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1244, 569, 1415, 642], "id": 2, "cls": 2, "conf": 0.7150111794471741, "frame_idx": 166, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [555, 518, 1133, 839], "id": 1, "cls": 2, "conf": 0.9146777987480164, "frame_idx": 167, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [537, 515, 1129, 845], "id": 1, "cls": 2, "conf": 0.9021809101104736, "frame_idx": 168, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [516, 513, 1127, 854], "id": 1, "cls": 2, "conf": 0.9111503958702087, "frame_idx": 169, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [495, 510, 1126, 863], "id": 1, "cls": 2, "conf": 0.9124228954315186, "frame_idx": 170, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [471, 512, 1121, 872], "id": 1, "cls": 2, "conf": 0.9291900396347046, "frame_idx": 171, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [447, 509, 1116, 875], "id": 1, "cls": 2, "conf": 0.8657183051109314, "frame_idx": 172, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [423, 506, 1111, 881], "id": 1, "cls": 2, "conf": 0.8687337636947632, "frame_idx": 173, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [393, 505, 1105, 893], "id": 1, "cls": 2, "conf": 0.9182578921318054, "frame_idx": 174, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [368, 503, 1101, 899], "id": 1, "cls": 2, "conf": 0.9256529808044434, "frame_idx": 175, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [340, 502, 1096, 912], "id": 1, "cls": 2, "conf": 0.9282132983207703, "frame_idx": 176, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [303, 500, 1091, 924], "id": 1, "cls": 2, "conf": 0.9329380989074707, "frame_idx": 177, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [274, 499, 1087, 937], "id": 1, "cls": 2, "conf": 0.9455896019935608, "frame_idx": 178, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [233, 498, 1083, 946], "id": 1, "cls": 2, "conf": 0.9385244846343994, "frame_idx": 179, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [194, 496, 1077, 960], "id": 1, "cls": 2, "conf": 0.9393031001091003, "frame_idx": 180, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [153, 495, 1076, 972], "id": 1, "cls": 2, "conf": 
0.9307792782783508, "frame_idx": 181, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [110, 492, 1067, 988], "id": 1, "cls": 2, "conf": 0.9395390748977661, "frame_idx": 182, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [57, 493, 1060, 1008], "id": 1, "cls": 2, "conf": 0.9405025243759155, "frame_idx": 183, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [4, 492, 1053, 1029], "id": 1, "cls": 2, "conf": 0.9425285458564758, "frame_idx": 184, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 490, 1047, 1043], "id": 1, "cls": 2, "conf": 0.9343565106391907, "frame_idx": 185, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 488, 1043, 1061], "id": 1, "cls": 2, "conf": 0.9273869395256042, "frame_idx": 186, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 484, 1035, 1071], "id": 1, "cls": 2, "conf": 0.9321094751358032, "frame_idx": 187, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 475, 1030, 1071], "id": 1, "cls": 2, "conf": 0.9317752122879028, "frame_idx": 188, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 473, 1025, 1073], "id": 1, "cls": 2, "conf": 0.9486481547355652, "frame_idx": 189, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1210, 567, 1396, 640], "id": 2, "cls": 2, "conf": 0.7311104536056519, "frame_idx": 189, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 472, 1016, 1073], "id": 1, "cls": 2, "conf": 0.952238917350769, "frame_idx": 190, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1211, 569, 1397, 642], "id": 2, "cls": 2, "conf": 0.7499367594718933, "frame_idx": 190, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 463, 1008, 1070], "id": 1, "cls": 2, "conf": 0.9457194209098816, "frame_idx": 191, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1219, 570, 1396, 641], "id": 2, "cls": 2, "conf": 0.7276124954223633, "frame_idx": 191, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 454, 1001, 1071], "id": 1, "cls": 2, "conf": 0.9511743187904358, "frame_idx": 192, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1218, 570, 1396, 641], "id": 2, "cls": 2, "conf": 0.7206576466560364, "frame_idx": 192, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 450, 994, 1069], "id": 1, "cls": 2, "conf": 0.9420279264450073, "frame_idx": 193, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1214, 570, 1395, 642], "id": 2, "cls": 2, "conf": 0.7134021520614624, "frame_idx": 193, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 446, 985, 1067], "id": 1, "cls": 2, "conf": 0.9500812292098999, "frame_idx": 194, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1215, 570, 1393, 642], "id": 2, "cls": 2, "conf": 0.7069892287254333, "frame_idx": 194, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 442, 976, 1066], "id": 1, "cls": 2, "conf": 0.9406448006629944, "frame_idx": 195, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1203, 568, 1391, 642], "id": 2, "cls": 2, "conf": 0.7376792430877686, "frame_idx": 195, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 441, 968, 1069], "id": 1, "cls": 2, "conf": 0.9537635445594788, "frame_idx": 196, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1202, 567, 1391, 642], "id": 2, "cls": 2, "conf": 0.7550773024559021, "frame_idx": 196, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 440, 960, 1069], "id": 1, "cls": 
2, "conf": 0.9586692452430725, "frame_idx": 197, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1200, 566, 1392, 642], "id": 2, "cls": 2, "conf": 0.7765669822692871, "frame_idx": 197, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 431, 950, 1069], "id": 1, "cls": 2, "conf": 0.9550426006317139, "frame_idx": 198, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1198, 565, 1393, 643], "id": 2, "cls": 2, "conf": 0.7722377777099609, "frame_idx": 198, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 424, 938, 1065], "id": 1, "cls": 2, "conf": 0.9508339762687683, "frame_idx": 199, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1196, 565, 1392, 643], "id": 2, "cls": 2, "conf": 0.751980185508728, "frame_idx": 199, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 419, 927, 1065], "id": 1, "cls": 2, "conf": 0.9454301595687866, "frame_idx": 200, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1196, 566, 1392, 643], "id": 2, "cls": 2, "conf": 0.7461082935333252, "frame_idx": 200, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 413, 916, 1065], "id": 1, "cls": 2, "conf": 0.957693874835968, "frame_idx": 201, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1194, 565, 1392, 644], "id": 2, "cls": 2, "conf": 0.7643528580665588, "frame_idx": 201, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 407, 905, 1065], "id": 1, "cls": 2, "conf": 0.945280134677887, "frame_idx": 202, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1196, 565, 1392, 644], "id": 2, "cls": 2, "conf": 0.7613423466682434, "frame_idx": 202, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 409, 890, 1065], "id": 1, "cls": 2, "conf": 0.9535142183303833, "frame_idx": 203, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1194, 565, 1391, 644], "id": 2, "cls": 2, "conf": 0.7633638978004456, "frame_idx": 203, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 400, 875, 1065], "id": 1, "cls": 2, "conf": 0.9448526501655579, "frame_idx": 204, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1192, 565, 1391, 644], "id": 2, "cls": 2, "conf": 0.7550344467163086, "frame_idx": 204, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 395, 863, 1064], "id": 1, "cls": 2, "conf": 0.9526091814041138, "frame_idx": 205, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1192, 565, 1390, 644], "id": 2, "cls": 2, "conf": 0.7387273907661438, "frame_idx": 205, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 391, 851, 1062], "id": 1, "cls": 2, "conf": 0.9561181664466858, "frame_idx": 206, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1191, 565, 1390, 644], "id": 2, "cls": 2, "conf": 0.7227319478988647, "frame_idx": 206, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 385, 830, 1059], "id": 1, "cls": 2, "conf": 0.9433083534240723, "frame_idx": 207, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1189, 565, 1388, 644], "id": 2, "cls": 2, "conf": 0.703997015953064, "frame_idx": 207, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 369, 812, 1064], "id": 1, "cls": 2, "conf": 0.9332630634307861, "frame_idx": 208, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1192, 566, 1387, 644], "id": 2, "cls": 2, "conf": 0.7098210453987122, "frame_idx": 208, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 364, 792, 1067], "id": 1, 
"cls": 2, "conf": 0.945813775062561, "frame_idx": 209, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1189, 565, 1388, 644], "id": 2, "cls": 2, "conf": 0.7005091905593872, "frame_idx": 209, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 354, 774, 1068], "id": 1, "cls": 2, "conf": 0.9388237595558167, "frame_idx": 210, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1187, 565, 1385, 643], "id": 2, "cls": 2, "conf": 0.7079640030860901, "frame_idx": 210, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 351, 755, 1070], "id": 1, "cls": 2, "conf": 0.9397347569465637, "frame_idx": 211, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1185, 564, 1385, 644], "id": 2, "cls": 2, "conf": 0.7079640030860901, "frame_idx": 211, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 350, 729, 1068], "id": 1, "cls": 2, "conf": 0.949310839176178, "frame_idx": 212, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1183, 564, 1381, 643], "id": 2, "cls": 2, "conf": 0.7306272983551025, "frame_idx": 212, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 350, 703, 1068], "id": 1, "cls": 2, "conf": 0.9424352645874023, "frame_idx": 213, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1183, 564, 1383, 643], "id": 2, "cls": 2, "conf": 0.7504119873046875, "frame_idx": 213, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 350, 679, 1066], "id": 1, "cls": 2, "conf": 0.9429755806922913, "frame_idx": 214, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1181, 565, 1377, 644], "id": 2, "cls": 2, "conf": 0.7851810455322266, "frame_idx": 214, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 354, 650, 1069], "id": 1, "cls": 2, "conf": 0.9048929214477539, "frame_idx": 215, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1181, 565, 1378, 643], "id": 2, "cls": 2, "conf": 0.7938785552978516, "frame_idx": 215, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 378, 620, 1070], "id": 1, "cls": 2, "conf": 0.9180529713630676, "frame_idx": 216, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1182, 566, 1376, 643], "id": 2, "cls": 2, "conf": 0.7817256450653076, "frame_idx": 216, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 395, 588, 1069], "id": 1, "cls": 2, "conf": 0.9412034749984741, "frame_idx": 217, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1182, 565, 1374, 644], "id": 2, "cls": 2, "conf": 0.8047704100608826, "frame_idx": 217, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 427, 551, 1071], "id": 1, "cls": 2, "conf": 0.9319164752960205, "frame_idx": 218, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1182, 565, 1375, 643], "id": 2, "cls": 2, "conf": 0.7836374640464783, "frame_idx": 218, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 453, 510, 1072], "id": 1, "cls": 2, "conf": 0.9232752919197083, "frame_idx": 219, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1181, 566, 1371, 642], "id": 2, "cls": 2, "conf": 0.8103419542312622, "frame_idx": 219, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 503, 467, 1071], "id": 1, "cls": 2, "conf": 0.904760479927063, "frame_idx": 220, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1179, 566, 1371, 642], "id": 2, "cls": 2, "conf": 0.8125634789466858, "frame_idx": 220, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 549, 418, 1070], 
"id": 1, "cls": 2, "conf": 0.9279927611351013, "frame_idx": 221, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1179, 566, 1376, 642], "id": 2, "cls": 2, "conf": 0.8272838592529297, "frame_idx": 221, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 583, 363, 1068], "id": 1, "cls": 2, "conf": 0.9242643117904663, "frame_idx": 222, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1178, 565, 1374, 642], "id": 2, "cls": 2, "conf": 0.8221709132194519, "frame_idx": 222, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 593, 303, 1068], "id": 1, "cls": 2, "conf": 0.9143214821815491, "frame_idx": 223, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1177, 565, 1375, 644], "id": 2, "cls": 2, "conf": 0.8016420602798462, "frame_idx": 223, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 600, 238, 1069], "id": 1, "cls": 2, "conf": 0.8708683252334595, "frame_idx": 224, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1177, 565, 1376, 644], "id": 2, "cls": 2, "conf": 0.7917031645774841, "frame_idx": 224, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 616, 197, 1069], "id": 1, "cls": 2, "conf": 0.8708683252334595, "frame_idx": 225, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1178, 565, 1376, 643], "id": 2, "cls": 2, "conf": 0.78056401014328, "frame_idx": 225, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1177, 564, 1377, 644], "id": 2, "cls": 2, "conf": 0.7785735130310059, "frame_idx": 226, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1176, 565, 1370, 644], "id": 2, "cls": 2, "conf": 0.7929512858390808, "frame_idx": 227, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1174, 564, 1371, 645], "id": 2, "cls": 2, "conf": 0.8178865909576416, "frame_idx": 228, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1173, 564, 1371, 645], "id": 2, "cls": 2, "conf": 0.8109760284423828, "frame_idx": 229, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1173, 565, 1370, 645], "id": 2, "cls": 2, "conf": 0.7563623189926147, "frame_idx": 230, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1174, 565, 1370, 645], "id": 2, "cls": 2, "conf": 0.7083349227905273, "frame_idx": 231, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1173, 565, 1368, 645], "id": 2, "cls": 2, "conf": 0.7430815100669861, "frame_idx": 232, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1172, 564, 1359, 643], "id": 2, "cls": 2, "conf": 0.7816348075866699, "frame_idx": 233, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1171, 565, 1356, 642], "id": 2, "cls": 2, "conf": 0.8003019094467163, "frame_idx": 234, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1171, 563, 1360, 644], "id": 2, "cls": 2, "conf": 0.8223402500152588, "frame_idx": 235, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1169, 562, 1362, 645], "id": 2, "cls": 2, "conf": 0.8306653499603271, "frame_idx": 236, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1168, 562, 1359, 645], "id": 2, "cls": 2, "conf": 0.8245570659637451, "frame_idx": 237, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1170, 563, 1359, 645], "id": 2, "cls": 2, "conf": 0.818155825138092, "frame_idx": 238, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1168, 563, 1360, 645], "id": 2, "cls": 2, "conf": 0.8151793479919434, "frame_idx": 239, "source": "video/sample.mp4", "class_name": "car"} 
+{"bbox": [1166, 564, 1357, 645], "id": 2, "cls": 2, "conf": 0.8082919120788574, "frame_idx": 240, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1165, 564, 1356, 645], "id": 2, "cls": 2, "conf": 0.8219642043113708, "frame_idx": 241, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1164, 564, 1353, 645], "id": 2, "cls": 2, "conf": 0.7999997138977051, "frame_idx": 242, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1164, 564, 1352, 645], "id": 2, "cls": 2, "conf": 0.7364180088043213, "frame_idx": 243, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1165, 565, 1349, 645], "id": 2, "cls": 2, "conf": 0.7858971357345581, "frame_idx": 244, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1164, 564, 1354, 646], "id": 2, "cls": 2, "conf": 0.7886779308319092, "frame_idx": 245, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1164, 564, 1348, 646], "id": 2, "cls": 2, "conf": 0.818172812461853, "frame_idx": 246, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1163, 564, 1348, 646], "id": 2, "cls": 2, "conf": 0.8523472547531128, "frame_idx": 247, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1164, 564, 1348, 645], "id": 2, "cls": 2, "conf": 0.8364881873130798, "frame_idx": 248, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1163, 563, 1346, 646], "id": 2, "cls": 2, "conf": 0.8150932788848877, "frame_idx": 249, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1163, 564, 1346, 646], "id": 2, "cls": 2, "conf": 0.8284506797790527, "frame_idx": 250, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1163, 563, 1347, 645], "id": 2, "cls": 2, "conf": 0.8243890404701233, "frame_idx": 251, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1163, 564, 1344, 646], "id": 2, "cls": 2, "conf": 0.848281741142273, "frame_idx": 252, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1163, 563, 1341, 646], "id": 2, "cls": 2, "conf": 0.8477445840835571, "frame_idx": 253, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1162, 563, 1339, 648], "id": 2, "cls": 2, "conf": 0.8400436043739319, "frame_idx": 254, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1161, 561, 1336, 647], "id": 2, "cls": 2, "conf": 0.7861170768737793, "frame_idx": 255, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1161, 562, 1338, 649], "id": 2, "cls": 2, "conf": 0.8120461702346802, "frame_idx": 256, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1161, 562, 1336, 648], "id": 2, "cls": 2, "conf": 0.7770818471908569, "frame_idx": 257, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1161, 561, 1332, 648], "id": 2, "cls": 2, "conf": 0.7602912187576294, "frame_idx": 258, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 560, 1331, 649], "id": 2, "cls": 2, "conf": 0.7476798295974731, "frame_idx": 259, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 560, 1330, 649], "id": 2, "cls": 2, "conf": 0.7798804640769958, "frame_idx": 260, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 560, 1328, 649], "id": 2, "cls": 2, "conf": 0.7794782519340515, "frame_idx": 261, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 561, 1328, 649], "id": 2, "cls": 2, "conf": 0.7535544037818909, "frame_idx": 262, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 561, 1326, 649], "id": 2, "cls": 2, "conf": 0.7481237649917603, "frame_idx": 263, 
"source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 561, 1325, 647], "id": 2, "cls": 2, "conf": 0.7650920152664185, "frame_idx": 264, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1158, 562, 1324, 647], "id": 2, "cls": 2, "conf": 0.8215755224227905, "frame_idx": 265, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 561, 1324, 647], "id": 2, "cls": 2, "conf": 0.8252439498901367, "frame_idx": 266, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1158, 561, 1323, 648], "id": 2, "cls": 2, "conf": 0.8128286004066467, "frame_idx": 267, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1157, 560, 1323, 649], "id": 2, "cls": 2, "conf": 0.8222718238830566, "frame_idx": 268, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1158, 560, 1323, 649], "id": 2, "cls": 2, "conf": 0.8110289573669434, "frame_idx": 269, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 560, 1323, 649], "id": 2, "cls": 2, "conf": 0.8318296074867249, "frame_idx": 270, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 561, 1321, 649], "id": 2, "cls": 2, "conf": 0.8325403332710266, "frame_idx": 271, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 560, 1323, 650], "id": 2, "cls": 2, "conf": 0.8335207104682922, "frame_idx": 272, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 560, 1321, 650], "id": 2, "cls": 2, "conf": 0.8333126902580261, "frame_idx": 273, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1158, 561, 1320, 650], "id": 2, "cls": 2, "conf": 0.8144757151603699, "frame_idx": 274, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 561, 1319, 650], "id": 2, "cls": 2, "conf": 0.809233546257019, "frame_idx": 275, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1160, 561, 1317, 650], "id": 2, "cls": 2, "conf": 0.7907527685165405, "frame_idx": 276, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 560, 1318, 650], "id": 2, "cls": 2, "conf": 0.8115890026092529, "frame_idx": 277, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1158, 559, 1317, 651], "id": 2, "cls": 2, "conf": 0.7833464741706848, "frame_idx": 278, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1158, 559, 1317, 651], "id": 2, "cls": 2, "conf": 0.7954601645469666, "frame_idx": 279, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 559, 1317, 651], "id": 2, "cls": 2, "conf": 0.774968683719635, "frame_idx": 280, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 559, 1316, 651], "id": 2, "cls": 2, "conf": 0.7699628472328186, "frame_idx": 281, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1158, 559, 1316, 651], "id": 2, "cls": 2, "conf": 0.7739447951316833, "frame_idx": 282, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1159, 559, 1315, 650], "id": 2, "cls": 2, "conf": 0.803051769733429, "frame_idx": 283, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1158, 558, 1312, 652], "id": 2, "cls": 2, "conf": 0.810187041759491, "frame_idx": 284, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1157, 557, 1311, 653], "id": 2, "cls": 2, "conf": 0.8035591840744019, "frame_idx": 285, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1157, 558, 1311, 653], "id": 2, "cls": 2, "conf": 0.8188391923904419, "frame_idx": 286, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1156, 558, 1311, 653], "id": 2, "cls": 2, 
"conf": 0.8180844187736511, "frame_idx": 287, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1157, 559, 1310, 653], "id": 2, "cls": 2, "conf": 0.8250501155853271, "frame_idx": 288, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1156, 559, 1309, 654], "id": 2, "cls": 2, "conf": 0.8236573338508606, "frame_idx": 289, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1156, 559, 1308, 654], "id": 2, "cls": 2, "conf": 0.8105210661888123, "frame_idx": 290, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1157, 560, 1307, 654], "id": 2, "cls": 2, "conf": 0.8106025457382202, "frame_idx": 291, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1155, 560, 1307, 655], "id": 2, "cls": 2, "conf": 0.788083016872406, "frame_idx": 292, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1157, 560, 1305, 654], "id": 2, "cls": 2, "conf": 0.7796603441238403, "frame_idx": 293, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1157, 560, 1304, 655], "id": 2, "cls": 2, "conf": 0.7901594638824463, "frame_idx": 294, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1155, 560, 1305, 656], "id": 2, "cls": 2, "conf": 0.7907295823097229, "frame_idx": 295, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1156, 560, 1303, 655], "id": 2, "cls": 2, "conf": 0.7933876514434814, "frame_idx": 296, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1157, 559, 1301, 655], "id": 2, "cls": 2, "conf": 0.7832263708114624, "frame_idx": 297, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1156, 559, 1301, 656], "id": 2, "cls": 2, "conf": 0.795276403427124, "frame_idx": 298, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1155, 559, 1301, 656], "id": 2, "cls": 2, "conf": 0.8082300424575806, "frame_idx": 299, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1155, 560, 1299, 656], "id": 2, "cls": 2, "conf": 0.7965103387832642, "frame_idx": 300, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1154, 560, 1300, 657], "id": 2, "cls": 2, "conf": 0.8124801516532898, "frame_idx": 301, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1153, 560, 1300, 657], "id": 2, "cls": 2, "conf": 0.8144661784172058, "frame_idx": 302, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1153, 561, 1299, 658], "id": 2, "cls": 2, "conf": 0.8181474208831787, "frame_idx": 303, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1152, 561, 1298, 658], "id": 2, "cls": 2, "conf": 0.8187706470489502, "frame_idx": 304, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1154, 560, 1298, 656], "id": 2, "cls": 2, "conf": 0.8268204927444458, "frame_idx": 305, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1155, 560, 1297, 655], "id": 2, "cls": 2, "conf": 0.8292365074157715, "frame_idx": 306, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1155, 560, 1295, 656], "id": 2, "cls": 2, "conf": 0.8298918008804321, "frame_idx": 307, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1154, 559, 1297, 657], "id": 2, "cls": 2, "conf": 0.8282919526100159, "frame_idx": 308, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1155, 559, 1298, 657], "id": 2, "cls": 2, "conf": 0.8358256816864014, "frame_idx": 309, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1154, 559, 1297, 657], "id": 2, "cls": 2, "conf": 0.8314154744148254, "frame_idx": 310, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": 
[1154, 559, 1297, 657], "id": 2, "cls": 2, "conf": 0.8324777483940125, "frame_idx": 311, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1154, 560, 1294, 657], "id": 2, "cls": 2, "conf": 0.8399393558502197, "frame_idx": 312, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1153, 559, 1295, 658], "id": 2, "cls": 2, "conf": 0.8377672433853149, "frame_idx": 313, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1152, 559, 1294, 658], "id": 2, "cls": 2, "conf": 0.8295931816101074, "frame_idx": 314, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1151, 559, 1293, 658], "id": 2, "cls": 2, "conf": 0.8257358074188232, "frame_idx": 315, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1151, 559, 1292, 658], "id": 2, "cls": 2, "conf": 0.8370307087898254, "frame_idx": 316, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1151, 560, 1291, 658], "id": 2, "cls": 2, "conf": 0.818547785282135, "frame_idx": 317, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1150, 559, 1292, 659], "id": 2, "cls": 2, "conf": 0.7911444306373596, "frame_idx": 318, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1150, 559, 1292, 659], "id": 2, "cls": 2, "conf": 0.7788093686103821, "frame_idx": 319, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1150, 559, 1293, 659], "id": 2, "cls": 2, "conf": 0.7597206830978394, "frame_idx": 320, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1150, 560, 1291, 659], "id": 2, "cls": 2, "conf": 0.7717625498771667, "frame_idx": 321, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1148, 559, 1291, 660], "id": 2, "cls": 2, "conf": 0.7833176255226135, "frame_idx": 322, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1148, 559, 1292, 660], "id": 2, "cls": 2, "conf": 0.7886781096458435, "frame_idx": 323, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1148, 559, 1292, 660], "id": 2, "cls": 2, "conf": 0.7795507311820984, "frame_idx": 324, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1149, 560, 1291, 660], "id": 2, "cls": 2, "conf": 0.7811378240585327, "frame_idx": 325, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1148, 560, 1291, 661], "id": 2, "cls": 2, "conf": 0.7874495387077332, "frame_idx": 326, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1149, 560, 1290, 662], "id": 2, "cls": 2, "conf": 0.8070158958435059, "frame_idx": 327, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1147, 560, 1291, 664], "id": 2, "cls": 2, "conf": 0.8095881342887878, "frame_idx": 328, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1146, 560, 1290, 663], "id": 2, "cls": 2, "conf": 0.8032857775688171, "frame_idx": 329, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1145, 560, 1290, 664], "id": 2, "cls": 2, "conf": 0.826309084892273, "frame_idx": 330, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1145, 560, 1291, 665], "id": 2, "cls": 2, "conf": 0.799944281578064, "frame_idx": 331, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1145, 561, 1290, 665], "id": 2, "cls": 2, "conf": 0.7787960767745972, "frame_idx": 332, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1144, 560, 1290, 665], "id": 2, "cls": 2, "conf": 0.7718071937561035, "frame_idx": 333, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1142, 559, 1291, 666], "id": 2, "cls": 2, "conf": 0.7858945727348328, "frame_idx": 334, "source": 
"video/sample.mp4", "class_name": "car"} +{"bbox": [1143, 559, 1290, 665], "id": 2, "cls": 2, "conf": 0.809407114982605, "frame_idx": 335, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1142, 559, 1290, 666], "id": 2, "cls": 2, "conf": 0.8050354719161987, "frame_idx": 336, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1141, 559, 1289, 666], "id": 2, "cls": 2, "conf": 0.8001269102096558, "frame_idx": 337, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1140, 558, 1289, 667], "id": 2, "cls": 2, "conf": 0.8002896308898926, "frame_idx": 338, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1140, 559, 1288, 667], "id": 2, "cls": 2, "conf": 0.8237987160682678, "frame_idx": 339, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1139, 558, 1289, 667], "id": 2, "cls": 2, "conf": 0.8150033950805664, "frame_idx": 340, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1136, 558, 1291, 667], "id": 2, "cls": 2, "conf": 0.7948818802833557, "frame_idx": 341, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1138, 559, 1289, 668], "id": 2, "cls": 2, "conf": 0.8127124905586243, "frame_idx": 342, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1136, 558, 1290, 668], "id": 2, "cls": 2, "conf": 0.8126155138015747, "frame_idx": 343, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1135, 558, 1290, 668], "id": 2, "cls": 2, "conf": 0.8102937936782837, "frame_idx": 344, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1136, 558, 1290, 668], "id": 2, "cls": 2, "conf": 0.7925915718078613, "frame_idx": 345, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1138, 559, 1288, 669], "id": 2, "cls": 2, "conf": 0.7755674123764038, "frame_idx": 346, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1136, 558, 1288, 670], "id": 2, "cls": 2, "conf": 0.7737069129943848, "frame_idx": 347, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1136, 558, 1286, 669], "id": 2, "cls": 2, "conf": 0.7875550389289856, "frame_idx": 348, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1131, 557, 1286, 670], "id": 2, "cls": 2, "conf": 0.7827519178390503, "frame_idx": 349, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1131, 556, 1286, 670], "id": 2, "cls": 2, "conf": 0.7984418272972107, "frame_idx": 350, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1130, 555, 1287, 671], "id": 2, "cls": 2, "conf": 0.7734009027481079, "frame_idx": 351, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1130, 556, 1285, 671], "id": 2, "cls": 2, "conf": 0.7766426205635071, "frame_idx": 352, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1128, 555, 1286, 672], "id": 2, "cls": 2, "conf": 0.7817273139953613, "frame_idx": 353, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1131, 555, 1284, 671], "id": 2, "cls": 2, "conf": 0.7750544548034668, "frame_idx": 354, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1128, 554, 1287, 672], "id": 2, "cls": 2, "conf": 0.7669058442115784, "frame_idx": 355, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1130, 555, 1284, 672], "id": 2, "cls": 2, "conf": 0.7651919722557068, "frame_idx": 356, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1128, 554, 1283, 672], "id": 2, "cls": 2, "conf": 0.7686755061149597, "frame_idx": 357, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1126, 553, 1284, 673], "id": 2, "cls": 2, "conf": 
0.7569704055786133, "frame_idx": 358, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1126, 554, 1283, 673], "id": 2, "cls": 2, "conf": 0.788491427898407, "frame_idx": 359, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1123, 553, 1285, 673], "id": 2, "cls": 2, "conf": 0.796739935874939, "frame_idx": 360, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1124, 553, 1284, 674], "id": 2, "cls": 2, "conf": 0.7600229382514954, "frame_idx": 361, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1122, 552, 1285, 675], "id": 2, "cls": 2, "conf": 0.7608688473701477, "frame_idx": 362, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1121, 553, 1285, 676], "id": 2, "cls": 2, "conf": 0.7610014081001282, "frame_idx": 363, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1120, 552, 1285, 675], "id": 2, "cls": 2, "conf": 0.7238069772720337, "frame_idx": 364, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1119, 553, 1284, 675], "id": 2, "cls": 2, "conf": 0.789625883102417, "frame_idx": 365, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1118, 552, 1283, 675], "id": 2, "cls": 2, "conf": 0.7700904607772827, "frame_idx": 366, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1117, 552, 1282, 677], "id": 2, "cls": 2, "conf": 0.7024756669998169, "frame_idx": 367, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1116, 550, 1282, 677], "id": 2, "cls": 2, "conf": 0.7285512685775757, "frame_idx": 368, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1115, 549, 1281, 675], "id": 2, "cls": 2, "conf": 0.7092558145523071, "frame_idx": 369, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1113, 549, 1282, 675], "id": 2, "cls": 2, "conf": 0.7147558331489563, "frame_idx": 370, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1114, 548, 1280, 675], "id": 2, "cls": 2, "conf": 0.7318784594535828, "frame_idx": 371, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1112, 549, 1279, 676], "id": 2, "cls": 2, "conf": 0.7841340899467468, "frame_idx": 372, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1113, 549, 1278, 675], "id": 2, "cls": 2, "conf": 0.7626461386680603, "frame_idx": 373, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1111, 550, 1278, 677], "id": 2, "cls": 2, "conf": 0.7657148241996765, "frame_idx": 374, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1108, 550, 1280, 677], "id": 2, "cls": 2, "conf": 0.7782973647117615, "frame_idx": 375, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1108, 550, 1280, 677], "id": 2, "cls": 2, "conf": 0.7754068970680237, "frame_idx": 376, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1107, 551, 1279, 677], "id": 2, "cls": 2, "conf": 0.7901440858840942, "frame_idx": 377, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1105, 550, 1280, 678], "id": 2, "cls": 2, "conf": 0.811150848865509, "frame_idx": 378, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1105, 550, 1279, 678], "id": 2, "cls": 2, "conf": 0.7904564142227173, "frame_idx": 379, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1105, 550, 1278, 678], "id": 2, "cls": 2, "conf": 0.7392836809158325, "frame_idx": 380, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1104, 548, 1279, 678], "id": 2, "cls": 2, "conf": 0.7411684989929199, "frame_idx": 381, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1104, 
551, 1277, 680], "id": 2, "cls": 2, "conf": 0.7404786944389343, "frame_idx": 382, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1102, 550, 1276, 680], "id": 2, "cls": 2, "conf": 0.7326121926307678, "frame_idx": 383, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1102, 550, 1277, 681], "id": 2, "cls": 2, "conf": 0.7641636729240417, "frame_idx": 384, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1101, 549, 1276, 681], "id": 2, "cls": 2, "conf": 0.7742770314216614, "frame_idx": 385, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1099, 549, 1276, 682], "id": 2, "cls": 2, "conf": 0.7556547522544861, "frame_idx": 386, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1098, 548, 1277, 682], "id": 2, "cls": 2, "conf": 0.702316164970398, "frame_idx": 387, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1096, 548, 1275, 683], "id": 2, "cls": 2, "conf": 0.7168530225753784, "frame_idx": 388, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1093, 547, 1273, 684], "id": 2, "cls": 2, "conf": 0.7561923265457153, "frame_idx": 389, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1093, 548, 1275, 684], "id": 2, "cls": 2, "conf": 0.7371773719787598, "frame_idx": 390, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1093, 549, 1275, 684], "id": 2, "cls": 2, "conf": 0.7662423849105835, "frame_idx": 391, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1090, 548, 1276, 685], "id": 2, "cls": 2, "conf": 0.7733460664749146, "frame_idx": 392, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1090, 548, 1275, 684], "id": 2, "cls": 2, "conf": 0.8063229918479919, "frame_idx": 393, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1088, 547, 1275, 685], "id": 2, "cls": 2, "conf": 0.834899365901947, "frame_idx": 394, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1085, 546, 1275, 686], "id": 2, "cls": 2, "conf": 0.8267676830291748, "frame_idx": 395, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1083, 546, 1274, 686], "id": 2, "cls": 2, "conf": 0.8470121622085571, "frame_idx": 396, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1082, 546, 1272, 685], "id": 2, "cls": 2, "conf": 0.8356623649597168, "frame_idx": 397, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1081, 546, 1271, 686], "id": 2, "cls": 2, "conf": 0.8369763493537903, "frame_idx": 398, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1080, 545, 1272, 686], "id": 2, "cls": 2, "conf": 0.8737363219261169, "frame_idx": 399, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1080, 544, 1271, 687], "id": 2, "cls": 2, "conf": 0.8609719276428223, "frame_idx": 400, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1078, 544, 1272, 689], "id": 2, "cls": 2, "conf": 0.83541339635849, "frame_idx": 401, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1078, 545, 1270, 689], "id": 2, "cls": 2, "conf": 0.8013574481010437, "frame_idx": 402, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1075, 544, 1271, 689], "id": 2, "cls": 2, "conf": 0.7798829078674316, "frame_idx": 403, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1074, 543, 1270, 691], "id": 2, "cls": 2, "conf": 0.8236221671104431, "frame_idx": 404, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1070, 543, 1270, 692], "id": 2, "cls": 2, "conf": 0.8620288372039795, "frame_idx": 405, "source": 
"video/sample.mp4", "class_name": "car"} +{"bbox": [1070, 543, 1268, 692], "id": 2, "cls": 2, "conf": 0.8752257227897644, "frame_idx": 406, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1067, 542, 1268, 693], "id": 2, "cls": 2, "conf": 0.870403528213501, "frame_idx": 407, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1066, 542, 1269, 695], "id": 2, "cls": 2, "conf": 0.8699027299880981, "frame_idx": 408, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1062, 541, 1270, 696], "id": 2, "cls": 2, "conf": 0.8874167799949646, "frame_idx": 409, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1061, 541, 1269, 696], "id": 2, "cls": 2, "conf": 0.8754041194915771, "frame_idx": 410, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1060, 540, 1269, 698], "id": 2, "cls": 2, "conf": 0.8649414777755737, "frame_idx": 411, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1057, 539, 1268, 699], "id": 2, "cls": 2, "conf": 0.8912915587425232, "frame_idx": 412, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1056, 539, 1268, 700], "id": 2, "cls": 2, "conf": 0.8944886922836304, "frame_idx": 413, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1055, 539, 1269, 700], "id": 2, "cls": 2, "conf": 0.8907544612884521, "frame_idx": 414, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1054, 540, 1268, 701], "id": 2, "cls": 2, "conf": 0.8559849262237549, "frame_idx": 415, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1053, 541, 1266, 701], "id": 2, "cls": 2, "conf": 0.8329747319221497, "frame_idx": 416, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1051, 540, 1265, 702], "id": 2, "cls": 2, "conf": 0.8382128477096558, "frame_idx": 417, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1049, 540, 1266, 702], "id": 2, "cls": 2, "conf": 0.8805363178253174, "frame_idx": 418, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1046, 539, 1266, 703], "id": 2, "cls": 2, "conf": 0.8715322017669678, "frame_idx": 419, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1045, 539, 1267, 704], "id": 2, "cls": 2, "conf": 0.842781662940979, "frame_idx": 420, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1041, 539, 1268, 706], "id": 2, "cls": 2, "conf": 0.8441018462181091, "frame_idx": 421, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1038, 539, 1266, 708], "id": 2, "cls": 2, "conf": 0.7819275856018066, "frame_idx": 422, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1037, 539, 1264, 708], "id": 2, "cls": 2, "conf": 0.8135506510734558, "frame_idx": 423, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1033, 538, 1264, 710], "id": 2, "cls": 2, "conf": 0.8242059350013733, "frame_idx": 424, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1032, 538, 1265, 710], "id": 2, "cls": 2, "conf": 0.7836756110191345, "frame_idx": 425, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1031, 538, 1264, 710], "id": 2, "cls": 2, "conf": 0.8388970494270325, "frame_idx": 426, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1029, 537, 1264, 711], "id": 2, "cls": 2, "conf": 0.7970230579376221, "frame_idx": 427, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1027, 537, 1265, 711], "id": 2, "cls": 2, "conf": 0.7321099638938904, "frame_idx": 428, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1025, 538, 1265, 712], "id": 2, "cls": 2, "conf": 
0.7343229651451111, "frame_idx": 429, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1020, 536, 1261, 712], "id": 2, "cls": 2, "conf": 0.787158727645874, "frame_idx": 430, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1018, 537, 1259, 713], "id": 2, "cls": 2, "conf": 0.8460677862167358, "frame_idx": 431, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1013, 536, 1261, 714], "id": 2, "cls": 2, "conf": 0.8292366862297058, "frame_idx": 432, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1011, 536, 1259, 716], "id": 2, "cls": 2, "conf": 0.8152600526809692, "frame_idx": 433, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1008, 535, 1258, 718], "id": 2, "cls": 2, "conf": 0.7996748089790344, "frame_idx": 434, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1007, 535, 1255, 719], "id": 2, "cls": 2, "conf": 0.8389233946800232, "frame_idx": 435, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1008, 535, 1253, 720], "id": 2, "cls": 2, "conf": 0.8631499409675598, "frame_idx": 436, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1002, 534, 1254, 721], "id": 2, "cls": 2, "conf": 0.8657373785972595, "frame_idx": 437, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [998, 534, 1253, 721], "id": 2, "cls": 2, "conf": 0.8603703379631042, "frame_idx": 438, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [995, 532, 1253, 722], "id": 2, "cls": 2, "conf": 0.8645334839820862, "frame_idx": 439, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [994, 532, 1252, 723], "id": 2, "cls": 2, "conf": 0.8768425583839417, "frame_idx": 440, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [991, 530, 1254, 724], "id": 2, "cls": 2, "conf": 0.8931466937065125, "frame_idx": 441, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [986, 530, 1256, 725], "id": 2, "cls": 2, "conf": 0.9038722515106201, "frame_idx": 442, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [985, 530, 1253, 725], "id": 2, "cls": 2, "conf": 0.9084876775741577, "frame_idx": 443, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [983, 530, 1251, 727], "id": 2, "cls": 2, "conf": 0.9005601406097412, "frame_idx": 444, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [980, 529, 1252, 729], "id": 2, "cls": 2, "conf": 0.8964847922325134, "frame_idx": 445, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [977, 529, 1251, 730], "id": 2, "cls": 2, "conf": 0.8957618474960327, "frame_idx": 446, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [974, 529, 1248, 731], "id": 2, "cls": 2, "conf": 0.8834296464920044, "frame_idx": 447, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [970, 527, 1246, 732], "id": 2, "cls": 2, "conf": 0.8654475212097168, "frame_idx": 448, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [966, 526, 1248, 734], "id": 2, "cls": 2, "conf": 0.8783361315727234, "frame_idx": 449, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [962, 526, 1245, 734], "id": 2, "cls": 2, "conf": 0.8720850348472595, "frame_idx": 450, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [959, 525, 1247, 735], "id": 2, "cls": 2, "conf": 0.8909793496131897, "frame_idx": 451, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [957, 525, 1244, 737], "id": 2, "cls": 2, "conf": 0.8911501169204712, "frame_idx": 452, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [954, 525, 1243, 
739], "id": 2, "cls": 2, "conf": 0.8941781520843506, "frame_idx": 453, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [948, 524, 1245, 741], "id": 2, "cls": 2, "conf": 0.8771947622299194, "frame_idx": 454, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [943, 524, 1243, 744], "id": 2, "cls": 2, "conf": 0.8804555535316467, "frame_idx": 455, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [940, 523, 1243, 747], "id": 2, "cls": 2, "conf": 0.8785960078239441, "frame_idx": 456, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [934, 522, 1243, 749], "id": 2, "cls": 2, "conf": 0.9005946516990662, "frame_idx": 457, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [931, 521, 1242, 749], "id": 2, "cls": 2, "conf": 0.8925696611404419, "frame_idx": 458, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [928, 521, 1242, 749], "id": 2, "cls": 2, "conf": 0.8925560116767883, "frame_idx": 459, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [925, 522, 1239, 751], "id": 2, "cls": 2, "conf": 0.8871305584907532, "frame_idx": 460, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [917, 523, 1235, 753], "id": 2, "cls": 2, "conf": 0.8800134658813477, "frame_idx": 461, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [913, 523, 1234, 755], "id": 2, "cls": 2, "conf": 0.8769950270652771, "frame_idx": 462, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [912, 522, 1232, 757], "id": 2, "cls": 2, "conf": 0.8771668672561646, "frame_idx": 463, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [907, 521, 1230, 758], "id": 2, "cls": 2, "conf": 0.8780584931373596, "frame_idx": 464, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [902, 520, 1229, 759], "id": 2, "cls": 2, "conf": 0.9009929299354553, "frame_idx": 465, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [892, 520, 1230, 761], "id": 2, "cls": 2, "conf": 0.880210280418396, "frame_idx": 466, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [889, 519, 1227, 762], "id": 2, "cls": 2, "conf": 0.870464026927948, "frame_idx": 467, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [885, 520, 1225, 767], "id": 2, "cls": 2, "conf": 0.9003344774246216, "frame_idx": 468, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [877, 519, 1226, 767], "id": 2, "cls": 2, "conf": 0.920558512210846, "frame_idx": 469, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [875, 519, 1224, 768], "id": 2, "cls": 2, "conf": 0.9045699238777161, "frame_idx": 470, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [868, 518, 1223, 770], "id": 2, "cls": 2, "conf": 0.9074614644050598, "frame_idx": 471, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [864, 517, 1223, 773], "id": 2, "cls": 2, "conf": 0.9183488488197327, "frame_idx": 472, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [857, 516, 1222, 775], "id": 2, "cls": 2, "conf": 0.9148356914520264, "frame_idx": 473, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [853, 516, 1220, 777], "id": 2, "cls": 2, "conf": 0.9280686378479004, "frame_idx": 474, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [841, 514, 1221, 778], "id": 2, "cls": 2, "conf": 0.9198227524757385, "frame_idx": 475, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [838, 513, 1218, 780], "id": 2, "cls": 2, "conf": 0.8942911028862, "frame_idx": 476, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": 
[830, 513, 1218, 782], "id": 2, "cls": 2, "conf": 0.8980481028556824, "frame_idx": 477, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [826, 513, 1213, 787], "id": 2, "cls": 2, "conf": 0.9096649289131165, "frame_idx": 478, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [819, 512, 1212, 793], "id": 2, "cls": 2, "conf": 0.9269362688064575, "frame_idx": 479, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [811, 509, 1213, 794], "id": 2, "cls": 2, "conf": 0.92948979139328, "frame_idx": 480, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [804, 509, 1211, 796], "id": 2, "cls": 2, "conf": 0.9076160788536072, "frame_idx": 481, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [794, 508, 1210, 798], "id": 2, "cls": 2, "conf": 0.9064416289329529, "frame_idx": 482, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [789, 508, 1208, 800], "id": 2, "cls": 2, "conf": 0.9050999879837036, "frame_idx": 483, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [780, 507, 1204, 803], "id": 2, "cls": 2, "conf": 0.9137296080589294, "frame_idx": 484, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [771, 507, 1204, 807], "id": 2, "cls": 2, "conf": 0.9088245630264282, "frame_idx": 485, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [765, 506, 1204, 810], "id": 2, "cls": 2, "conf": 0.9037410020828247, "frame_idx": 486, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [756, 506, 1203, 812], "id": 2, "cls": 2, "conf": 0.9066951870918274, "frame_idx": 487, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [746, 503, 1201, 818], "id": 2, "cls": 2, "conf": 0.914334774017334, "frame_idx": 488, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [735, 503, 1197, 825], "id": 2, "cls": 2, "conf": 0.9123433232307434, "frame_idx": 489, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [725, 502, 1195, 829], "id": 2, "cls": 2, "conf": 0.9094393849372864, "frame_idx": 490, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [717, 498, 1194, 833], "id": 2, "cls": 2, "conf": 0.9276642203330994, "frame_idx": 491, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [705, 499, 1194, 835], "id": 2, "cls": 2, "conf": 0.9282996654510498, "frame_idx": 492, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [696, 498, 1192, 837], "id": 2, "cls": 2, "conf": 0.9298180937767029, "frame_idx": 493, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [684, 496, 1191, 841], "id": 2, "cls": 2, "conf": 0.9258641600608826, "frame_idx": 494, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [673, 496, 1188, 847], "id": 2, "cls": 2, "conf": 0.923974335193634, "frame_idx": 495, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [661, 498, 1186, 856], "id": 2, "cls": 2, "conf": 0.9190512299537659, "frame_idx": 496, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [646, 495, 1183, 859], "id": 2, "cls": 2, "conf": 0.9168910980224609, "frame_idx": 497, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [632, 495, 1183, 868], "id": 2, "cls": 2, "conf": 0.925777018070221, "frame_idx": 498, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [621, 493, 1182, 873], "id": 2, "cls": 2, "conf": 0.9183085560798645, "frame_idx": 499, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [607, 491, 1180, 878], "id": 2, "cls": 2, "conf": 0.9321070909500122, "frame_idx": 500, "source": "video/sample.mp4", "class_name": 
"car"} +{"bbox": [588, 488, 1177, 882], "id": 2, "cls": 2, "conf": 0.9307034611701965, "frame_idx": 501, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [576, 485, 1174, 888], "id": 2, "cls": 2, "conf": 0.9412079453468323, "frame_idx": 502, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [562, 483, 1173, 893], "id": 2, "cls": 2, "conf": 0.9401066303253174, "frame_idx": 503, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [543, 475, 1171, 897], "id": 2, "cls": 2, "conf": 0.9346688389778137, "frame_idx": 504, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [527, 473, 1169, 903], "id": 2, "cls": 2, "conf": 0.9343288540840149, "frame_idx": 505, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [510, 474, 1164, 914], "id": 2, "cls": 2, "conf": 0.9404311180114746, "frame_idx": 506, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [490, 471, 1161, 920], "id": 2, "cls": 2, "conf": 0.9414466619491577, "frame_idx": 507, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [473, 469, 1159, 927], "id": 2, "cls": 2, "conf": 0.9434319138526917, "frame_idx": 508, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [451, 469, 1158, 938], "id": 2, "cls": 2, "conf": 0.9345313906669617, "frame_idx": 509, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [427, 469, 1156, 946], "id": 2, "cls": 2, "conf": 0.9282017946243286, "frame_idx": 510, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [405, 468, 1152, 952], "id": 2, "cls": 2, "conf": 0.9417479038238525, "frame_idx": 511, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [382, 468, 1150, 966], "id": 2, "cls": 2, "conf": 0.9451406598091125, "frame_idx": 512, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [360, 465, 1148, 976], "id": 2, "cls": 2, "conf": 0.9428954720497131, "frame_idx": 513, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [332, 463, 1148, 984], "id": 2, "cls": 2, "conf": 0.9395127892494202, "frame_idx": 514, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [303, 463, 1144, 992], "id": 2, "cls": 2, "conf": 0.9283111095428467, "frame_idx": 515, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [275, 462, 1136, 1003], "id": 2, "cls": 2, "conf": 0.9324305653572083, "frame_idx": 516, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [245, 461, 1131, 1018], "id": 2, "cls": 2, "conf": 0.9247828125953674, "frame_idx": 517, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [208, 453, 1130, 1032], "id": 2, "cls": 2, "conf": 0.9319226741790771, "frame_idx": 518, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [172, 451, 1129, 1045], "id": 2, "cls": 2, "conf": 0.9351807832717896, "frame_idx": 519, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [134, 449, 1125, 1058], "id": 2, "cls": 2, "conf": 0.9390578269958496, "frame_idx": 520, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [91, 445, 1119, 1068], "id": 2, "cls": 2, "conf": 0.947394609451294, "frame_idx": 521, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [46, 443, 1114, 1070], "id": 2, "cls": 2, "conf": 0.9468377232551575, "frame_idx": 522, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [2, 440, 1110, 1072], "id": 2, "cls": 2, "conf": 0.9386428594589233, "frame_idx": 523, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 438, 1105, 1072], "id": 2, "cls": 2, "conf": 0.9346777200698853, "frame_idx": 524, "source": 
"video/sample.mp4", "class_name": "car"} +{"bbox": [0, 435, 1107, 1072], "id": 2, "cls": 2, "conf": 0.9273584485054016, "frame_idx": 525, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 437, 1096, 1071], "id": 2, "cls": 2, "conf": 0.9241657257080078, "frame_idx": 526, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 432, 1095, 1071], "id": 2, "cls": 2, "conf": 0.9355752468109131, "frame_idx": 527, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 428, 1094, 1070], "id": 2, "cls": 2, "conf": 0.9321312308311462, "frame_idx": 528, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 420, 1082, 1073], "id": 2, "cls": 2, "conf": 0.9156169891357422, "frame_idx": 529, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [2, 409, 1077, 1070], "id": 2, "cls": 2, "conf": 0.8867893815040588, "frame_idx": 530, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [2, 388, 1070, 1071], "id": 2, "cls": 2, "conf": 0.9155814051628113, "frame_idx": 531, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 399, 1066, 1072], "id": 2, "cls": 2, "conf": 0.9372450113296509, "frame_idx": 532, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 389, 1057, 1071], "id": 2, "cls": 2, "conf": 0.9160026907920837, "frame_idx": 533, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 390, 1052, 1070], "id": 2, "cls": 2, "conf": 0.9509764313697815, "frame_idx": 534, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 386, 1042, 1070], "id": 2, "cls": 2, "conf": 0.9340437650680542, "frame_idx": 535, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [2, 381, 1038, 1068], "id": 2, "cls": 2, "conf": 0.9404564499855042, "frame_idx": 536, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [2, 375, 1030, 1066], "id": 2, "cls": 2, "conf": 0.9479154348373413, "frame_idx": 537, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [2, 370, 1024, 1067], "id": 2, "cls": 2, "conf": 0.9565911293029785, "frame_idx": 538, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 365, 1016, 1067], "id": 2, "cls": 2, "conf": 0.9608258008956909, "frame_idx": 539, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [2, 357, 1006, 1064], "id": 2, "cls": 2, "conf": 0.9613184332847595, "frame_idx": 540, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [3, 347, 999, 1064], "id": 2, "cls": 2, "conf": 0.9674457311630249, "frame_idx": 541, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 338, 992, 1064], "id": 2, "cls": 2, "conf": 0.97267746925354, "frame_idx": 542, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 328, 983, 1064], "id": 2, "cls": 2, "conf": 0.9624996781349182, "frame_idx": 543, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 319, 972, 1063], "id": 2, "cls": 2, "conf": 0.9598995447158813, "frame_idx": 544, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 307, 959, 1062], "id": 2, "cls": 2, "conf": 0.9514867663383484, "frame_idx": 545, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 292, 948, 1062], "id": 2, "cls": 2, "conf": 0.9584953784942627, "frame_idx": 546, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 279, 935, 1065], "id": 2, "cls": 2, "conf": 0.9569721221923828, "frame_idx": 547, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 270, 927, 1066], "id": 2, "cls": 2, "conf": 0.972572922706604, "frame_idx": 548, "source": 
"video/sample.mp4", "class_name": "car"} +{"bbox": [2, 258, 915, 1066], "id": 2, "cls": 2, "conf": 0.9626525044441223, "frame_idx": 549, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [2, 241, 898, 1064], "id": 2, "cls": 2, "conf": 0.9489137530326843, "frame_idx": 550, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 221, 885, 1065], "id": 2, "cls": 2, "conf": 0.9458200931549072, "frame_idx": 551, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 204, 868, 1066], "id": 2, "cls": 2, "conf": 0.9462317228317261, "frame_idx": 552, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 193, 856, 1066], "id": 2, "cls": 2, "conf": 0.9367963075637817, "frame_idx": 553, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 180, 836, 1067], "id": 2, "cls": 2, "conf": 0.9550886154174805, "frame_idx": 554, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 173, 820, 1068], "id": 2, "cls": 2, "conf": 0.9146677255630493, "frame_idx": 555, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 170, 797, 1066], "id": 2, "cls": 2, "conf": 0.9364038109779358, "frame_idx": 556, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 171, 779, 1067], "id": 2, "cls": 2, "conf": 0.9397339224815369, "frame_idx": 557, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 171, 751, 1068], "id": 2, "cls": 2, "conf": 0.9423396587371826, "frame_idx": 558, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 175, 729, 1067], "id": 2, "cls": 2, "conf": 0.9324960708618164, "frame_idx": 559, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 181, 700, 1066], "id": 2, "cls": 2, "conf": 0.9049985408782959, "frame_idx": 560, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 188, 672, 1067], "id": 2, "cls": 2, "conf": 0.8566305637359619, "frame_idx": 561, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 195, 637, 1067], "id": 2, "cls": 2, "conf": 0.9080706834793091, "frame_idx": 562, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 199, 603, 1068], "id": 2, "cls": 2, "conf": 0.9104960560798645, "frame_idx": 563, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [1, 220, 559, 1063], "id": 2, "cls": 2, "conf": 0.9200505614280701, "frame_idx": 564, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 235, 516, 1067], "id": 2, "cls": 2, "conf": 0.9269247651100159, "frame_idx": 565, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [0, 250, 470, 1065], "id": 2, "cls": 2, "conf": 0.8854379057884216, "frame_idx": 566, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [3, 256, 409, 1066], "id": 2, "cls": 2, "conf": 0.8114883303642273, "frame_idx": 567, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [4, 239, 349, 1070], "id": 2, "cls": 2, "conf": 0.7934050559997559, "frame_idx": 568, "source": "video/sample.mp4", "class_name": "car"} +{"bbox": [7, 409, 283, 1065], "id": 2, "cls": 2, "conf": 0.7185706496238708, "frame_idx": 569, "source": "video/sample.mp4", "class_name": "car"} diff --git a/models/.gitkeep b/models/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/models/bangchakv2.mpta b/models/bangchakv2.mpta new file mode 100644 index 0000000..b79eb1c Binary files /dev/null and b/models/bangchakv2.mpta differ diff --git a/note.txt b/note.txt new file mode 100644 index 0000000..1667b90 --- /dev/null +++ b/note.txt @@ -0,0 +1 @@ +uvicorn app:app --host 0.0.0.0 --port 8000 
--reload \ No newline at end of file diff --git a/output/sample2_detected.png b/output/sample2_detected.png new file mode 100644 index 0000000..8dd9f79 Binary files /dev/null and b/output/sample2_detected.png differ diff --git a/output/sample2_results.txt b/output/sample2_results.txt new file mode 100644 index 0000000..d90003d --- /dev/null +++ b/output/sample2_results.txt @@ -0,0 +1,17 @@ +Model: car_frontal_detection_v1.pt +Image: sample2.png +Confidence threshold: 0.3 +Detections: 2 + +Detection 1: + Class: Car + Confidence: 0.863 + Bounding box: (86.5, 73.4, 825.6, 625.2) + Size: 739.1x551.9 + +Detection 2: + Class: Frontal + Confidence: 0.504 + Bounding box: (176.6, 307.2, 708.1, 609.0) + Size: 531.5x301.7 + diff --git a/output/sample_detected.jpg b/output/sample_detected.jpg new file mode 100644 index 0000000..24c3281 Binary files /dev/null and b/output/sample_detected.jpg differ diff --git a/output/sample_results.txt b/output/sample_results.txt new file mode 100644 index 0000000..3228df7 --- /dev/null +++ b/output/sample_results.txt @@ -0,0 +1,17 @@ +Model: car_frontal_detection_v1.pt +Image: sample.jpg +Confidence threshold: 0.3 +Detections: 2 + +Detection 1: + Class: Frontal + Confidence: 0.555 + Bounding box: (175.9, 279.7, 527.6, 500.9) + Size: 351.7x221.2 + +Detection 2: + Class: Car + Confidence: 0.418 + Bounding box: (167.7, 196.7, 881.4, 532.7) + Size: 713.8x336.0 + diff --git a/output/test_image_detected.jpg b/output/test_image_detected.jpg new file mode 100644 index 0000000..3ce5925 Binary files /dev/null and b/output/test_image_detected.jpg differ diff --git a/output/test_image_results.txt b/output/test_image_results.txt new file mode 100644 index 0000000..72db288 --- /dev/null +++ b/output/test_image_results.txt @@ -0,0 +1,5 @@ +Model: car_frontal_detection_v1.pt +Image: test_image.jpg +Confidence threshold: 0.3 +Detections: 0 + diff --git a/requirements.base.txt b/requirements.base.txt index b8af923..af22160 100644 --- a/requirements.base.txt +++ b/requirements.base.txt @@ -4,9 +4,4 @@ ultralytics opencv-python scipy filterpy -psycopg2-binary -lap>=0.5.12 -pynvml -PyTurboJPEG -PyNvVideoCodec -cupy-cuda12x \ No newline at end of file +psycopg2-binary \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2afeb0e..6eaf131 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,4 @@ uvicorn websockets fastapi[standard] redis -urllib3<2.0.0 -numpy -requests -watchdog \ No newline at end of file +urllib3<2.0.0 \ No newline at end of file diff --git a/sample.jpg b/sample.jpg new file mode 100644 index 0000000..6ae641c Binary files /dev/null and b/sample.jpg differ diff --git a/sample2.png b/sample2.png new file mode 100644 index 0000000..6415a7e Binary files /dev/null and b/sample2.png differ diff --git a/archive/siwatsystem/database.py b/siwatsystem/database.py similarity index 100% rename from archive/siwatsystem/database.py rename to siwatsystem/database.py diff --git a/archive/siwatsystem/pympta.py b/siwatsystem/pympta.py similarity index 98% rename from archive/siwatsystem/pympta.py rename to siwatsystem/pympta.py index d21232d..53d97d9 100644 --- a/archive/siwatsystem/pympta.py +++ b/siwatsystem/pympta.py @@ -117,6 +117,21 @@ def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client, db_manage return node def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict: + # Restrict to models directory for security + if not zip_source.startswith('models/'): + zip_source = os.path.join('models', zip_source) + + # 
Validate the path is within models directory (prevent path traversal) + try: + abs_zip_path = os.path.abspath(zip_source) + abs_models_path = os.path.abspath('models') + if not abs_zip_path.startswith(abs_models_path): + logger.error(f"Security violation: {zip_source} is outside models directory") + return None + except Exception as e: + logger.error(f"Error validating path {zip_source}: {str(e)}") + return None + logger.info(f"Attempting to load pipeline from {zip_source} to {target_dir}") os.makedirs(target_dir, exist_ok=True) zip_path = os.path.join(target_dir, "pipeline.mpta") diff --git a/test.py b/test.py new file mode 100644 index 0000000..ca34f01 --- /dev/null +++ b/test.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +""" +Test script for car_frontal_detection_v1.pt model +Usage: python test.py --image [--confidence ] [--save-output] +""" +# python test.py --image sample.jpg --confidence 0.6 --save-output + +import argparse +import cv2 +import torch +import numpy as np +from pathlib import Path +import sys + +def load_model_direct(model_path): + """Load model directly with torch.load to handle version compatibility""" + try: + # Try to load with weights_only=False for compatibility + checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) + print(f"Model checkpoint keys: {list(checkpoint.keys())}") + + # Try to get model info + if 'model' in checkpoint: + model_info = checkpoint.get('model', {}) + print(f"Model info available: {hasattr(model_info, 'names') if hasattr(model_info, 'names') else 'No names found'}") + + return checkpoint + except Exception as e: + print(f"Direct torch.load failed: {e}") + return None + +def main(): + parser = argparse.ArgumentParser(description='Test car frontal detection model') + parser.add_argument('--image', required=True, help='Path to input image') + parser.add_argument('--model', default='car_frontal_detection_v1.pt', help='Path to model file') + parser.add_argument('--confidence', type=float, default=0.5, help='Confidence threshold (default: 0.5)') + parser.add_argument('--save-output', action='store_true', help='Save output image with detections') + parser.add_argument('--output-dir', default='output', help='Output directory for results') + parser.add_argument('--use-yolo', action='store_true', default=True, help='Use YOLO loading (default: True)') + + args = parser.parse_args() + + # Check if model file exists + if not Path(args.model).exists(): + print(f"Error: Model file '{args.model}' not found") + sys.exit(1) + + # Check if image file exists + if not Path(args.image).exists(): + print(f"Error: Image file '{args.image}' not found") + sys.exit(1) + + print(f"Loading model: {args.model}") + + model = None + if args.use_yolo: + try: + from ultralytics import YOLO + model = YOLO(args.model) + print(f"Model loaded successfully with YOLO") + print(f"Model classes: {model.names}") + except Exception as e: + print(f"Error loading model with YOLO: {e}") + print("Falling back to direct loading...") + + if model is None: + # Try direct loading for inspection + checkpoint = load_model_direct(args.model) + if checkpoint is None: + print("Failed to load model with any method") + sys.exit(1) + + print("Model loaded directly - this is for inspection only") + print("Available keys in checkpoint:", list(checkpoint.keys())) + + # Try to get model information + if 'model' in checkpoint: + model_obj = checkpoint['model'] + print(f"Model object type: {type(model_obj)}") + if hasattr(model_obj, 'names'): + print(f"Model classes: 
{model_obj.names}") + if hasattr(model_obj, 'yaml'): + print(f"Model YAML config available: {bool(model_obj.yaml)}") + + print("\nTo run inference, you need a compatible Ultralytics version.") + print("Consider upgrading ultralytics: pip install ultralytics --upgrade") + return + + print(f"Loading image: {args.image}") + try: + image = cv2.imread(args.image) + if image is None: + raise ValueError("Could not load image") + print(f"Image shape: {image.shape}") + except Exception as e: + print(f"Error loading image: {e}") + sys.exit(1) + + print(f"Running inference with confidence threshold: {args.confidence}") + try: + results = model(image, conf=args.confidence) + + if len(results) > 0 and len(results[0].boxes) > 0: + print(f"Detections found: {len(results[0].boxes)}") + + # Print detection details + for i, box in enumerate(results[0].boxes): + x1, y1, x2, y2 = box.xyxy[0].cpu().numpy() + conf = box.conf[0].cpu().numpy() + cls = int(box.cls[0].cpu().numpy()) + class_name = model.names[cls] if cls in model.names else f"Class_{cls}" + + print(f"Detection {i+1}:") + print(f" Class: {class_name}") + print(f" Confidence: {conf:.3f}") + print(f" Bounding box: ({x1:.1f}, {y1:.1f}, {x2:.1f}, {y2:.1f})") + print(f" Size: {x2-x1:.1f}x{y2-y1:.1f}") + else: + print("No detections found") + + if args.save_output: + output_dir = Path(args.output_dir) + output_dir.mkdir(exist_ok=True) + + # Draw detections on image + annotated_image = results[0].plot() + + # Save annotated image + input_path = Path(args.image) + output_path = output_dir / f"{input_path.stem}_detected{input_path.suffix}" + cv2.imwrite(str(output_path), annotated_image) + print(f"Output saved to: {output_path}") + + # Also save results as text + results_path = output_dir / f"{input_path.stem}_results.txt" + with open(results_path, 'w') as f: + f.write(f"Model: {args.model}\n") + f.write(f"Image: {args.image}\n") + f.write(f"Confidence threshold: {args.confidence}\n") + f.write(f"Detections: {len(results[0].boxes) if len(results) > 0 else 0}\n\n") + + if len(results) > 0 and len(results[0].boxes) > 0: + for i, box in enumerate(results[0].boxes): + x1, y1, x2, y2 = box.xyxy[0].cpu().numpy() + conf = box.conf[0].cpu().numpy() + cls = int(box.cls[0].cpu().numpy()) + class_name = model.names[cls] if cls in model.names else f"Class_{cls}" + + f.write(f"Detection {i+1}:\n") + f.write(f" Class: {class_name}\n") + f.write(f" Confidence: {conf:.3f}\n") + f.write(f" Bounding box: ({x1:.1f}, {y1:.1f}, {x2:.1f}, {y2:.1f})\n") + f.write(f" Size: {x2-x1:.1f}x{y2-y1:.1f}\n\n") + + print(f"Results saved to: {results_path}") + + except Exception as e: + print(f"Error during inference: {e}") + sys.exit(1) + + print("Test completed successfully!") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test_protocol.py b/test_protocol.py index 6b32fd8..74af7d8 100644 --- a/test_protocol.py +++ b/test_protocol.py @@ -9,7 +9,7 @@ import time async def test_protocol(): """Test the worker protocol implementation""" - uri = "ws://localhost:8001" + uri = "ws://localhost:8000" try: async with websockets.connect(uri) as websocket: @@ -119,7 +119,7 @@ async def test_protocol(): except Exception as e: print(f"✗ Connection failed: {e}") - print("Make sure the worker is running on localhost:8001") + print("Make sure the worker is running on localhost:8000") if __name__ == "__main__": asyncio.run(test_protocol()) \ No newline at end of file diff --git a/worker.md b/worker.md index 72c5e69..302c8ce 100644 --- a/worker.md +++ b/worker.md @@ -2,6 +2,12 
@@ This document outlines the WebSocket-based communication protocol between the CMS backend and a detector worker. As a worker developer, your primary responsibility is to implement a WebSocket server that adheres to this protocol. +The current Python Detector Worker implementation supports advanced computer vision pipelines with: +- Multi-class YOLO detection with parallel processing +- PostgreSQL database integration with automatic schema management +- Redis integration for image storage and pub/sub messaging +- Hierarchical pipeline execution with detection → classification branching + ## 1. Connection The worker must run a WebSocket server, preferably on port `8000`. The backend system, which is managed by a container orchestration service, will automatically discover and establish a WebSocket connection to your worker. @@ -15,86 +21,9 @@ Communication is bidirectional and asynchronous. All messages are JSON objects w - **Worker -> Backend:** You will send messages to the backend to report status, forward detection events, or request changes to session data. - **Backend -> Worker:** The backend will send commands to you to manage camera subscriptions. -### 2.1. Multi-Process Cluster Architecture +## 3. Dynamic Configuration via MPTA File -The backend uses a sophisticated multi-process cluster architecture with Redis-based coordination to manage worker connections at scale: - -**Redis Communication Channels:** - -- `worker:commands` - Commands sent TO workers (subscribe, unsubscribe, setSessionId, setProgressionStage) -- `worker:responses` - Detection responses and state reports FROM workers -- `worker:events` - Worker lifecycle events (connection, disconnection, health status) - -**Distributed State Management:** - -- `worker:states` - Redis hash map storing real-time worker performance metrics and connection status -- `worker:assignments` - Redis hash map tracking camera-to-worker assignments across the cluster -- `worker:owners` - Redis key-based worker ownership leases with 30-second TTL for automatic failover - -**Load Balancing & Failover:** - -- **Assignment Algorithm**: Workers are assigned based on subscription count and CPU usage -- **Distributed Locking**: Assignment operations use Redis locks to prevent race conditions -- **Automatic Failover**: Orphaned workers are detected via lease expiration and automatically reclaimed -- **Horizontal Scaling**: New backend processes automatically join the cluster and participate in load balancing - -**Inter-Process Coordination:** - -- Each backend process maintains local WebSocket connections with workers -- Commands are routed via Redis pub/sub to the process that owns the target worker connection -- Master election ensures coordinated cluster management and prevents split-brain scenarios -- Process identification uses UUIDs for clean process tracking and ownership management - -## 3. 
Message Types and Command Structure - -All worker communication follows a standardized message structure with the following command types: - -**Commands from Backend to Worker:** - -- `setSubscriptionList` - Set complete list of camera subscriptions for declarative state management -- `setSessionId` - Associate a session ID with a display for detection linking -- `setProgressionStage` - Update the progression stage for context-aware processing -- `requestState` - Request immediate state report from worker -- `patchSessionResult` - Response to worker's patch session request - -**Messages from Worker to Backend:** - -- `stateReport` - Periodic heartbeat with performance metrics and subscription status -- `imageDetection` - Real-time detection results with timestamp and data -- `patchSession` - Request to modify display persistent session data - -**Command Structure:** - -```typescript -interface WorkerCommand { - type: string; - subscriptions?: SubscriptionObject[]; // For setSubscriptionList - payload?: { - displayIdentifier?: string; - sessionId?: number | null; - progressionStage?: string | null; - // Additional payload fields based on command type - }; -} - -interface SubscriptionObject { - subscriptionIdentifier: string; // Format: "displayId;cameraId" - rtspUrl: string; - snapshotUrl?: string; - snapshotInterval?: number; // milliseconds - modelUrl: string; // Fresh pre-signed URL (1-hour TTL) - modelId: number; - modelName: string; - cropX1?: number; - cropY1?: number; - cropX2?: number; - cropY2?: number; -} -``` - -## 4. Dynamic Configuration via MPTA File - -To enable modularity and dynamic configuration, the backend will send you a URL to a `.mpta` file in each subscription within the `setSubscriptionList` command. This file is a renamed `.zip` archive that contains everything your worker needs to perform its task. +To enable modularity and dynamic configuration, the backend will send you a URL to a `.mpta` file when it issues a `subscribe` command. This file is a renamed `.zip` archive that contains everything your worker needs to perform its task. **Your worker is responsible for:** @@ -102,75 +31,40 @@ To enable modularity and dynamic configuration, the backend will send you a URL 2. Extracting its contents. 3. Interpreting the contents to configure its internal pipeline. -**The contents of the `.mpta` file are entirely up to the user who configures the model in the CMS.** This allows for maximum flexibility. For example, the archive could contain: +**The current implementation supports comprehensive pipeline configurations including:** -- AI/ML Models: Pre-trained models for libraries like TensorFlow, PyTorch, or ONNX. -- Configuration Files: A `config.json` or `pipeline.yaml` that defines a sequence of operations, specifies model paths, or sets detection thresholds. -- Scripts: Custom Python scripts for pre-processing or post-processing. -- API Integration Details: A JSON file with endpoint information and credentials for interacting with third-party detection services. 
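The three worker responsibilities listed above (download, extract, interpret) are handled in this repository by `load_pipeline_from_zip` in `siwatsystem/pympta.py`. The sketch below is only an illustration of those steps under stated assumptions — the URL, target directory, and the expectation that `pipeline.json` sits at the archive root are illustrative, not the actual worker implementation.

```python
# Hedged sketch of worker-side .mpta handling (illustrative only; the real
# logic lives in siwatsystem/pympta.py). Assumes pipeline.json is at the
# archive root, as shown in the MPTA structure described in this document.
import json
import os
import urllib.request
import zipfile


def fetch_and_extract_mpta(model_url: str, target_dir: str) -> dict:
    """Download a .mpta archive, unzip it, and return its pipeline config."""
    os.makedirs(target_dir, exist_ok=True)
    archive_path = os.path.join(target_dir, "pipeline.mpta")

    # 1. Download the file from the provided URL.
    urllib.request.urlretrieve(model_url, archive_path)

    # 2. Extract its contents (the .mpta file is a renamed .zip).
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(target_dir)

    # 3. Interpret the contents to configure the internal pipeline.
    with open(os.path.join(target_dir, "pipeline.json")) as f:
        return json.load(f)
```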
+- **AI/ML Models**: YOLO models (.pt files) for detection and classification +- **Pipeline Configuration**: `pipeline.json` defining hierarchical detection→classification workflows +- **Multi-class Detection**: Simultaneous detection of multiple object classes (e.g., Car + Frontal) +- **Parallel Processing**: Concurrent execution of classification branches with ThreadPoolExecutor +- **Database Integration**: PostgreSQL configuration for automatic table creation and updates +- **Redis Actions**: Image storage with region cropping and pub/sub messaging +- **Dynamic Field Mapping**: Template-based field resolution for database operations -Essentially, the `.mpta` file is a self-contained package that tells your worker _how_ to process the video stream for a given subscription. +**Enhanced MPTA Structure:** +``` +pipeline.mpta/ +├── pipeline.json # Main configuration with redis/postgresql settings +├── car_detection.pt # Primary YOLO detection model +├── brand_classifier.pt # Classification model for car brands +├── bodytype_classifier.pt # Classification model for body types +└── ... +``` -## 5. Worker State Recovery and Reconnection +The `pipeline.json` now supports advanced features like: +- Multi-class detection with `expectedClasses` validation +- Parallel branch processing with `parallel: true` +- Database actions with `postgresql_update_combined` +- Redis actions with region-specific image cropping +- Branch synchronization with `waitForBranches` -The system provides comprehensive state recovery mechanisms to ensure seamless operation across worker disconnections and backend restarts. +Essentially, the `.mpta` file is a self-contained package that tells your worker *how* to process the video stream for a given subscription, including complex multi-stage AI pipelines with database persistence. -### 5.1. Automatic Resubscription - -**Connection Recovery Flow:** - -1. **Connection Detection**: Backend detects worker reconnection via WebSocket events -2. **State Restoration**: All subscription states are restored from backend memory and Redis -3. **Fresh Model URLs**: New model URLs are generated to handle S3 URL expiration -4. **Session Recovery**: Session IDs and progression stages are automatically restored -5. **Heartbeat Resumption**: Worker immediately begins sending state reports - -### 5.2. State Persistence Architecture - -**Backend State Storage:** - -- **Local State**: Each backend process maintains `DetectorWorkerState` with active subscriptions -- **Redis Coordination**: Assignment mappings stored in `worker:assignments` Redis hash -- **Session Tracking**: Display session IDs tracked in display persistent data -- **Progression Stages**: Current stages maintained in display controllers - -**Recovery Guarantees:** - -- **Zero Configuration Loss**: All subscription parameters are preserved across disconnections -- **Session Continuity**: Active sessions remain linked after worker reconnection -- **Stage Synchronization**: Progression stages are immediately synchronized on reconnection -- **Model Availability**: Fresh model URLs ensure continuous access to detection models - -### 5.3. 
Heartbeat and Health Monitoring - -**Health Check Protocol:** - -- **Heartbeat Interval**: Workers send `stateReport` every 2 seconds -- **Timeout Detection**: Backend marks workers offline after 10-second timeout -- **Automatic Recovery**: Offline workers are automatically rescheduled when they reconnect -- **Performance Tracking**: CPU, memory, and GPU usage monitored for load balancing - -**Failure Scenarios:** - -- **Worker Crash**: Subscriptions are reassigned to other available workers -- **Network Interruption**: Automatic reconnection with full state restoration -- **Backend Restart**: Worker assignments are restored from Redis state -- **Redis Failure**: Local state provides temporary operation until Redis recovers - -### 5.4. Multi-Process Coordination - -**Ownership and Leasing:** - -- **Worker Ownership**: Each worker is owned by a single backend process via Redis lease -- **Lease Renewal**: 30-second TTL leases automatically renewed by owning process -- **Orphan Detection**: Expired leases allow worker reassignment to active processes -- **Graceful Handover**: Clean ownership transfer during process shutdown - -## 6. Messages from Worker to Backend +## 4. Messages from Worker to Backend These are the messages your worker is expected to send to the backend. -### 6.1. State Report (Heartbeat) +### 4.1. State Report (Heartbeat) This message is crucial for the backend to monitor your worker's health and status, including GPU usage. @@ -205,12 +99,21 @@ This message is crucial for the backend to monitor your worker's health and stat > > - `cropX1`, `cropY1`, `cropX2`, `cropY2` (optional, integer) should be included in each camera connection to indicate the crop coordinates for that subscription. -### 6.2. Image Detection +### 4.2. Image Detection Sent when the worker detects a relevant object. The `detection` object should be flat and contain key-value pairs corresponding to the detected attributes. - **Type:** `imageDetection` +**Enhanced Detection Capabilities:** + +The current implementation supports multi-class detection with parallel classification processing. When a vehicle is detected, the system: + +1. **Multi-Class Detection**: Simultaneously detects "Car" and "Frontal" classes +2. **Parallel Processing**: Runs brand and body type classification concurrently +3. **Database Integration**: Automatically creates and updates PostgreSQL records +4. **Redis Storage**: Saves cropped frontal images with expiration + **Payload Example:** ```json @@ -220,20 +123,39 @@ Sent when the worker detects a relevant object. The `detection` object should be "timestamp": "2025-07-14T12:34:56.789Z", "data": { "detection": { - "carModel": "Civic", + "class": "Car", + "confidence": 0.92, "carBrand": "Honda", - "carYear": 2023, + "carModel": "Civic", "bodyType": "Sedan", - "licensePlateText": "ABCD1234", - "licensePlateConfidence": 0.95 + "branch_results": { + "car_brand_cls_v1": { + "class": "Honda", + "confidence": 0.89, + "brand": "Honda" + }, + "car_bodytype_cls_v1": { + "class": "Sedan", + "confidence": 0.85, + "body_type": "Sedan" + } + } }, "modelId": 101, - "modelName": "US-LPR-and-Vehicle-ID" + "modelName": "Car Frontal Detection V1" } } ``` -### 6.3. 
Patch Session +**Database Integration:** + +Each detection automatically: +- Creates a record in `gas_station_1.car_frontal_info` table +- Generates a unique `session_id` for tracking +- Updates the record with classification results after parallel processing completes +- Stores cropped frontal images in Redis with the session_id as key + +### 4.3. Patch Session > **Note:** Patch messages are only used when the worker can't keep up and needs to retroactively send detections. Normally, detections should be sent in real-time using `imageDetection` messages. Use `patchSession` only to update session data after the fact. @@ -249,9 +171,9 @@ Allows the worker to request a modification to an active session's data. The `da "sessionId": 12345, "data": { "currentCar": { - "carModel": "Civic", - "carBrand": "Honda", - "licensePlateText": "ABCD1234" + "carModel": "Civic", + "carBrand": "Honda", + "licensePlateText": "ABCD1234" } } } @@ -265,33 +187,24 @@ The `data` object in the `patchSession` message is merged with the existing `Dis ```typescript interface DisplayPersistentData { - progressionStage: - | 'welcome' - | 'car_fueling' - | 'car_waitpayment' - | 'car_postpayment' - | null; - qrCode: string | null; - adsPlayback: { - playlistSlotOrder: number; // The 'order' of the current slot - adsId: number | null; - adsUrl: string | null; - } | null; - currentCar: { - carModel?: string; - carBrand?: string; - carYear?: number; - bodyType?: string; - licensePlateText?: string; - licensePlateType?: string; - } | null; - fuelPump: { - /* FuelPumpData structure */ - } | null; - weatherData: { - /* WeatherResponse structure */ - } | null; - sessionId: number | null; + progressionStage: "welcome" | "car_fueling" | "car_waitpayment" | "car_postpayment" | null; + qrCode: string | null; + adsPlayback: { + playlistSlotOrder: number; // The 'order' of the current slot + adsId: number | null; + adsUrl: string | null; + } | null; + currentCar: { + carModel?: string; + carBrand?: string; + carYear?: number; + bodyType?: string; + licensePlateText?: string; + licensePlateType?: string; + } | null; + fuelPump: { /* FuelPumpData structure */ } | null; + weatherData: { /* WeatherResponse structure */ } | null; + sessionId: number | null; } ``` @@ -302,91 +215,68 @@ interface DisplayPersistentData { - **`null`** values will set the corresponding field to `null`. - Nested objects are merged recursively. -## 7. Commands from Backend to Worker +## 5. Commands from Backend to Worker -These are the commands your worker will receive from the backend. The subscription system uses a **fully declarative approach** with `setSubscriptionList` - the backend sends the complete desired subscription list, and workers handle reconciliation internally. +These are the commands your worker will receive from the backend. -### 7.1. Set Subscription List (Declarative Subscriptions) +### 5.1. Subscribe to Camera -**The primary subscription command that replaces individual subscribe/unsubscribe operations.** +Instructs the worker to process a camera's RTSP stream using the configuration from the specified `.mpta` file. -Instructs the worker to process the complete list of camera streams. The worker must reconcile this list with its current subscriptions, adding new ones, removing obsolete ones, and updating existing ones as needed. 
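The reconciliation described in the paragraph above (part of the section this revision removes) reduces to a set difference over `subscriptionIdentifier` values. A minimal sketch follows, assuming the worker keeps its current subscriptions keyed by identifier; the function name and bucket names are illustrative only, not part of the actual worker.

```python
# Illustrative sketch of declarative subscription reconciliation.
def diff_subscriptions(current_ids: set[str], desired: list[dict]) -> dict:
    """Bucket the desired subscription list into identifiers to add, remove, or re-check."""
    desired_ids = {s["subscriptionIdentifier"] for s in desired}
    return {
        "add": desired_ids - current_ids,       # start these streams
        "remove": current_ids - desired_ids,    # stop these streams
        "recheck": current_ids & desired_ids,   # compare configs, update if changed
    }


# Example: one existing camera kept, one new camera added.
buckets = diff_subscriptions(
    {"display-001;cam-001"},
    [{"subscriptionIdentifier": "display-001;cam-001"},
     {"subscriptionIdentifier": "display-002;cam-001"}],
)
```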
- -- **Type:** `setSubscriptionList` +- **Type:** `subscribe` **Payload:** ```json { - "type": "setSubscriptionList", - "subscriptions": [ - { - "subscriptionIdentifier": "display-001;cam-001", - "rtspUrl": "rtsp://user:pass@host:port/stream1", - "snapshotUrl": "http://go2rtc/snapshot/1", - "snapshotInterval": 5000, - "modelUrl": "http://storage/models/us-lpr.mpta?token=fresh-token", - "modelName": "US-LPR-and-Vehicle-ID", - "modelId": 102, - "cropX1": 100, - "cropY1": 200, - "cropX2": 300, - "cropY2": 400 - }, - { - "subscriptionIdentifier": "display-002;cam-001", - "rtspUrl": "rtsp://user:pass@host:port/stream1", - "snapshotUrl": "http://go2rtc/snapshot/1", - "snapshotInterval": 5000, - "modelUrl": "http://storage/models/vehicle-detect.mpta?token=fresh-token", - "modelName": "Vehicle Detection", - "modelId": 201, - "cropX1": 0, - "cropY1": 0, - "cropX2": 1920, - "cropY2": 1080 - } - ] + "type": "subscribe", + "payload": { + "subscriptionIdentifier": "display-001;cam-002", + "rtspUrl": "rtsp://user:pass@host:port/stream", + "snapshotUrl": "http://go2rtc/snapshot/1", + "snapshotInterval": 5000, + "modelUrl": "http://storage/models/us-lpr.mpta", + "modelName": "US-LPR-and-Vehicle-ID", + "modelId": 102, + "cropX1": 100, + "cropY1": 200, + "cropX2": 300, + "cropY2": 400 + } } ``` -**Declarative Subscription Behavior:** - -- **Complete State Definition**: The backend sends the complete desired subscription list for this worker -- **Worker-Side Reconciliation**: Workers compare the new list with current subscriptions and handle differences -- **Fresh Model URLs**: Each command includes fresh pre-signed S3 URLs (1-hour TTL) for ML models -- **Load Balancing**: The backend intelligently distributes subscriptions across available workers -- **State Recovery**: Complete subscription list is sent on worker reconnection - -**Worker Reconciliation Responsibility:** - -When receiving a `setSubscriptionList` command, your worker must: - -1. **Compare with Current State**: Identify new subscriptions, removed subscriptions, and updated subscriptions -2. **Add New Subscriptions**: Start processing new camera streams with the provided configuration -3. **Remove Obsolete Subscriptions**: Stop processing camera streams not in the new list -4. **Update Existing Subscriptions**: Handle configuration changes (model updates, crop coordinates, etc.) -5. **Maintain Single Streams**: Ensure only one RTSP stream per camera, even with multiple display bindings -6. **Report Final State**: Send updated `stateReport` confirming the actual subscription state - > **Note:** > -> - `cropX1`, `cropY1`, `cropX2`, `cropY2` (optional, integer) specify the crop coordinates for the camera stream -> - `snapshotUrl` and `snapshotInterval` (optional) enable periodic snapshot capture -> - Multiple subscriptions may share the same `rtspUrl` but have different `subscriptionIdentifier` values +> - `cropX1`, `cropY1`, `cropX2`, `cropY2` (optional, integer) specify the crop coordinates for the camera stream. These values are configured per display and passed in the subscription payload. If not provided, the worker should process the full frame. > -> **Camera Stream Optimization:** -> When multiple subscriptions share the same camera (same `rtspUrl`), your worker must: +> **Important:** +> If multiple displays are bound to the same camera, your worker must ensure that only **one stream** is opened per camera. 
When you receive multiple subscriptions for the same camera (with different `subscriptionIdentifier` values), you should: > -> - Open the RTSP stream **once** for that camera -> - Capture each frame/snapshot **once** per cycle -> - Process the shared stream for each subscription's requirements (crop coordinates, model) -> - Route detection results separately for each `subscriptionIdentifier` -> - Apply display-specific crop coordinates during processing -> -> This optimization reduces bandwidth usage and ensures consistent detection timing across displays. +> - Open the RTSP stream **once** for that camera if using RTSP. +> - Capture each snapshot only once per cycle, and reuse it for all display subscriptions sharing that camera. +> - Capture each frame/image only once per cycle. +> - Reuse the same captured image and snapshot for all display subscriptions that share the camera, processing and routing detection results separately for each display as needed. +> This avoids unnecessary load and bandwidth usage, and ensures consistent detection results and snapshots across all displays sharing the same camera. -### 7.2. Request State +### 5.2. Unsubscribe from Camera + +Instructs the worker to stop processing a camera's stream. + +- **Type:** `unsubscribe` + +**Payload:** + +```json +{ + "type": "unsubscribe", + "payload": { + "subscriptionIdentifier": "display-001;cam-002" + } +} +``` + +### 5.3. Request State Direct request for the worker's current state. Respond with a `stateReport` message. @@ -400,7 +290,7 @@ Direct request for the worker's current state. Respond with a `stateReport` mess } ``` -### 7.3. Patch Session Result +### 5.4. Patch Session Result Backend's response to a `patchSession` message. @@ -419,11 +309,9 @@ Backend's response to a `patchSession` message. } ``` -### 7.4. Set Session ID +### 5.5. Set Session ID -**Real-time session association for linking detection events to user sessions.** - -Allows the backend to instruct the worker to associate a session ID with a display. This enables linking detection events to specific user sessions. The system automatically propagates session changes across all worker processes via Redis pub/sub. +Allows the backend to instruct the worker to associate a session ID with a subscription. This is useful for linking detection events to a specific session. The session ID can be `null` to indicate no active session. - **Type:** `setSessionId` @@ -451,94 +339,11 @@ Or to clear the session: } ``` -**Session Management Flow:** +> **Note:** +> +> - The worker should store the session ID for the given subscription and use it in subsequent detection or patch messages as appropriate. If `sessionId` is `null`, the worker should treat the subscription as having no active session. -1. **Session Creation**: When a new session is created (user interaction), the backend immediately sends `setSessionId` to all relevant workers -2. **Cross-Process Distribution**: The command is distributed across multiple backend processes via Redis `worker:commands` channel -3. **Worker State Synchronization**: Workers maintain session IDs for each display and apply them to all matching subscriptions -4. **Automatic Recovery**: Session IDs are restored when workers reconnect, ensuring no session context is lost -5. 
**Multi-Subscription Support**: A single session ID applies to all camera subscriptions for the given display - -**Worker Responsibility:** - -- Store the session ID for the given `displayIdentifier` -- Apply the session ID to **all active subscriptions** that start with `displayIdentifier;` (e.g., `display-001;cam-001`, `display-001;cam-002`) -- Include the session ID in subsequent `imageDetection` and `patchSession` messages -- Handle session clearing when `sessionId` is `null` -- Restore session IDs from backend state after reconnection - -**Multi-Process Coordination:** - -The session ID command uses the distributed worker communication system: - -- Commands are routed via Redis pub/sub to the process managing the target worker -- Automatic failover ensures session updates reach workers even during process changes -- Lease-based worker ownership prevents duplicate session notifications - -### 7.5. Set Progression Stage - -**Real-time progression stage synchronization for dynamic content adaptation.** - -Notifies workers about the current progression stage of a display, enabling context-aware content selection and detection behavior. The system automatically tracks stage changes and avoids redundant updates. - -- **Type:** `setProgressionStage` - -**Payload:** - -```json -{ - "type": "setProgressionStage", - "payload": { - "displayIdentifier": "display-001", - "progressionStage": "car_fueling" - } -} -``` - -Or to clear the progression stage: - -```json -{ - "type": "setProgressionStage", - "payload": { - "displayIdentifier": "display-001", - "progressionStage": null - } -} -``` - -**Available Progression Stages:** - -- `"welcome"` - Initial state, awaiting user interaction -- `"car_fueling"` - Vehicle is actively fueling -- `"car_waitpayment"` - Fueling complete, awaiting payment -- `"car_postpayment"` - Payment completed, transaction finishing -- `null` - No active progression stage - -**Progression Stage Flow:** - -1. **Automatic Detection**: Display controllers automatically detect progression stage changes based on display persistent data -2. **Change Filtering**: The system compares current stage with last sent stage to avoid redundant updates -3. **Instant Propagation**: Stage changes are immediately sent to all workers associated with the display -4. **Cross-Process Distribution**: Commands are distributed via Redis `worker:commands` channel to all backend processes -5. **State Recovery**: Progression stages are restored when workers reconnect - -**Worker Responsibility:** - -- Store the progression stage for the given `displayIdentifier` -- Apply the stage to **all active subscriptions** for that display -- Use progression stage for context-aware detection and content adaptation -- Handle stage clearing when `progressionStage` is `null` -- Restore progression stages from backend state after reconnection - -**Use Cases:** - -- **Fuel Station Displays**: Adapt content based on fueling progress (welcome ads vs. payment prompts) -- **Dynamic Detection**: Adjust detection sensitivity based on interaction stage -- **Content Personalization**: Select appropriate advertisements for current user journey stage -- **Analytics**: Track user progression through interaction stages - -## 8. 
Subscription Identifier Format +## Subscription Identifier Format The `subscriptionIdentifier` used in all messages is constructed as: @@ -557,14 +362,14 @@ When the backend sends a `setSessionId` command, it will only provide the `displ - The worker must match the `displayIdentifier` to all active subscriptions for that display (i.e., all `subscriptionIdentifier` values that start with `displayIdentifier;`). - The worker should set or clear the session ID for all matching subscriptions. -## 9. Example Communication Log +## 6. Example Communication Log -This section shows a typical sequence of messages between the backend and the worker, including the new declarative subscription model, session ID management, and progression stage synchronization. +This section shows a typical sequence of messages between the backend and the worker. Patch messages are not included, as they are only used when the worker cannot keep up. -> **Note:** Unsubscribe is triggered during load rebalancing or when displays/cameras are removed from the system. The system automatically handles worker reconnection with full state recovery. +> **Note:** Unsubscribe is triggered when a user removes a camera or when the node is too heavily loaded and needs rebalancing. 1. **Connection Established** & **Heartbeat** - - **Worker -> Backend** + * **Worker -> Backend** ```json { "type": "stateReport", @@ -575,25 +380,22 @@ This section shows a typical sequence of messages between the backend and the wo "cameraConnections": [] } ``` -2. **Backend Sets Subscription List** - - **Backend -> Worker** +2. **Backend Subscribes Camera** + * **Backend -> Worker** ```json { - "type": "setSubscriptionList", - "subscriptions": [ - { - "subscriptionIdentifier": "display-001;entry-cam-01", - "rtspUrl": "rtsp://192.168.1.100/stream1", - "modelUrl": "http://storage/models/vehicle-id.mpta?token=fresh-token", - "modelName": "Vehicle Identification", - "modelId": 201, - "snapshotInterval": 5000 - } - ] + "type": "subscribe", + "payload": { + "subscriptionIdentifier": "display-001;entry-cam-01", + "rtspUrl": "rtsp://192.168.1.100/stream1", + "modelUrl": "http://storage/models/vehicle-id.mpta", + "modelName": "Vehicle Identification", + "modelId": 201 + } } ``` -3. **Worker Acknowledges with Reconciled State** - - **Worker -> Backend** +3. **Worker Acknowledges in Heartbeat** + * **Worker -> Backend** ```json { "type": "stateReport", @@ -611,44 +413,13 @@ This section shows a typical sequence of messages between the backend and the wo ] } ``` -4. **Backend Sets Session ID** - - - **Backend -> Worker** - - ```json - { - "type": "setSessionId", - "payload": { - "displayIdentifier": "display-001", - "sessionId": 12345 - } - } - ``` - -5. **Backend Sets Progression Stage** - - - **Backend -> Worker** - - ```json - { - "type": "setProgressionStage", - "payload": { - "displayIdentifier": "display-001", - "progressionStage": "welcome" - } - } - ``` - -6. **Worker Detects a Car with Session Context** - - - **Worker -> Backend** - +4. **Worker Detects a Car** + * **Worker -> Backend** ```json { "type": "imageDetection", "subscriptionIdentifier": "display-001;entry-cam-01", "timestamp": "2025-07-15T10:00:00.000Z", - "sessionId": 12345, "data": { "detection": { "carBrand": "Honda", @@ -662,89 +433,56 @@ This section shows a typical sequence of messages between the backend and the wo } } ``` - -7. 
**Progression Stage Change** - - - **Backend -> Worker** - + * **Worker -> Backend** ```json { - "type": "setProgressionStage", - "payload": { - "displayIdentifier": "display-001", - "progressionStage": "car_fueling" + "type": "imageDetection", + "subscriptionIdentifier": "display-001;entry-cam-01", + "timestamp": "2025-07-15T10:00:01.000Z", + "data": { + "detection": { + "carBrand": "Toyota", + "carModel": "Corolla", + "bodyType": "Sedan", + "licensePlateText": "CMS-1234", + "licensePlateConfidence": 0.97 + }, + "modelId": 201, + "modelName": "Vehicle Identification" } } ``` - -8. **Worker Reconnection with State Recovery** - - - **Worker Disconnects and Reconnects** - - **Worker -> Backend** (Immediate heartbeat after reconnection) - + * **Worker -> Backend** ```json { - "type": "stateReport", - "cpuUsage": 70.0, - "memoryUsage": 38.0, - "gpuUsage": 55.0, - "gpuMemoryUsage": 20.0, - "cameraConnections": [] - } - ``` - - - **Backend -> Worker** (Automatic subscription list restoration with fresh model URLs) - - ```json - { - "type": "setSubscriptionList", - "subscriptions": [ - { - "subscriptionIdentifier": "display-001;entry-cam-01", - "rtspUrl": "rtsp://192.168.1.100/stream1", - "modelUrl": "http://storage/models/vehicle-id.mpta?token=fresh-reconnect-token", - "modelName": "Vehicle Identification", - "modelId": 201, - "snapshotInterval": 5000 - } - ] - } - ``` - - - **Backend -> Worker** (Session ID recovery) - - ```json - { - "type": "setSessionId", - "payload": { - "displayIdentifier": "display-001", - "sessionId": 12345 + "type": "imageDetection", + "subscriptionIdentifier": "display-001;entry-cam-01", + "timestamp": "2025-07-15T10:00:02.000Z", + "data": { + "detection": { + "carBrand": "Ford", + "carModel": "Focus", + "bodyType": "Hatchback", + "licensePlateText": "CMS-5678", + "licensePlateConfidence": 0.96 + }, + "modelId": 201, + "modelName": "Vehicle Identification" } } ``` - - - **Backend -> Worker** (Progression stage recovery) - +5. **Backend Unsubscribes Camera** + * **Backend -> Worker** ```json { - "type": "setProgressionStage", + "type": "unsubscribe", "payload": { - "displayIdentifier": "display-001", - "progressionStage": "car_fueling" + "subscriptionIdentifier": "display-001;entry-cam-01" } } ``` - -9. **Backend Updates Subscription List** (Load balancing or system cleanup) - - **Backend -> Worker** (Empty list removes all subscriptions) - ```json - { - "type": "setSubscriptionList", - "subscriptions": [] - } - ``` -10. **Worker Acknowledges Subscription Removal** - - **Worker -> Backend** (Updated heartbeat showing no active connections after reconciliation) +6. **Worker Acknowledges Unsubscription** + * **Worker -> Backend** ```json { "type": "stateReport", @@ -755,18 +493,7 @@ This section shows a typical sequence of messages between the backend and the wo "cameraConnections": [] } ``` - -**Key Improvements in Communication Flow:** - -1. **Fully Declarative Subscriptions**: Complete subscription list sent in single command, worker handles reconciliation -2. **Worker-Side Reconciliation**: Workers compare desired vs. current state and make necessary changes internally -3. **Session Context**: All detection events include session IDs for proper user linking -4. **Progression Stages**: Real-time stage updates enable context-aware content selection -5. **State Recovery**: Complete automatic recovery of subscription lists, session IDs, and progression stages -6. **Fresh Model URLs**: S3 URL expiration is handled transparently with 1-hour TTL tokens -7. 
**Load Balancing**: Backend intelligently distributes complete subscription lists across available workers - -## 10. HTTP API: Image Retrieval +## 7. HTTP API: Image Retrieval In addition to the WebSocket protocol, the worker exposes an HTTP endpoint for retrieving the latest image frame from a camera. @@ -781,13 +508,11 @@ GET /camera/{camera_id}/image ### Response - **Success (200):** Returns the latest JPEG image from the camera stream. - - - `Content-Type: image/jpeg` - - Binary JPEG data. + - `Content-Type: image/jpeg` + - Binary JPEG data. - **Error (404):** If the camera is not found or no frame is available. - - - JSON error response. + - JSON error response. - **Error (500):** Internal server error. @@ -800,9 +525,9 @@ GET /camera/display-001;cam-001/image ### Example Response - **Headers:** - ``` - Content-Type: image/jpeg - ``` + ``` + Content-Type: image/jpeg + ``` - **Body:** Binary JPEG image. ### Notes