
Multi-stage Partitioned Transformer (MPT)

PyTorch implementation of "Multi-Stage Partitioned Transformer for Efficient Image Deraining"

Introduction

Images captured outdoors often contain rain, which obscures the underlying scene and significantly degrades visual quality. Because rain appearance varies with rain density and wind direction, removing rain streaks from a rainy image is difficult. Motivated by the recent success of transformers in vision tasks, we propose a novel Multi-stage Partitioned Transformer (MPT) for image deraining. MPT separates the attention module and the multi-layer perceptron (MLP) to decompose a rainy image into a rain layer and a clean background. It uses the proposed global and local rain-aware attention mechanism to estimate the rain layer, and it adds atrous convolutions to the MLP to aggregate contextualized background features and produce a clean background at multiple stages. MPT is a parameter-economical and computationally efficient deraining model that effectively removes rain streaks from the input rainy image. Experimental results demonstrate that MPT performs favorably against state-of-the-art models on benchmark deraining datasets.
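
To make the decomposition concrete, below is a minimal PyTorch sketch of one partitioned block: an attention branch estimates the rain layer, and an MLP with atrous (dilated) convolutions refines the clean background. The module names, channel sizes, dilation rates, and the plain multi-head attention used here are illustrative assumptions only; they are not the authors' implementation, which uses the proposed global and local rain-aware attention.

import torch
import torch.nn as nn

class AtrousMLP(nn.Module):
    """MLP with dilated (atrous) convolutions to aggregate contextual
    background features. The dilation rates here are illustrative guesses."""
    def __init__(self, channels, expansion=2, dilations=(1, 2, 3)):
        super().__init__()
        hidden = channels * expansion
        self.expand = nn.Conv2d(channels, hidden, kernel_size=1)
        self.branches = nn.ModuleList([
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=d, dilation=d, groups=hidden)
            for d in dilations
        ])
        self.project = nn.Conv2d(hidden, channels, kernel_size=1)
        self.act = nn.GELU()

    def forward(self, x):
        x = self.act(self.expand(x))
        x = sum(branch(x) for branch in self.branches)   # multi-rate context aggregation
        return self.project(self.act(x))

class PartitionedBlockSketch(nn.Module):
    """Hypothetical sketch: the attention branch estimates the rain layer,
    the atrous-conv MLP branch refines the clean background."""
    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.norm_attn = nn.LayerNorm(channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.norm_mlp = nn.LayerNorm(channels)
        self.mlp = AtrousMLP(channels)

    def forward(self, x):                                # x: (B, C, H, W)
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)            # (B, H*W, C)

        # Attention branch: estimate the rain layer from the rainy features.
        t = self.norm_attn(tokens)
        rain_tokens, _ = self.attn(t, t, t)
        rain = rain_tokens.transpose(1, 2).reshape(b, c, h, w)

        # Remove the estimated rain, then refine the background with the atrous MLP.
        background = x - rain
        normed = self.norm_mlp(background.flatten(2).transpose(1, 2))
        background = background + self.mlp(normed.transpose(1, 2).reshape(b, c, h, w))
        return background, rain

if __name__ == "__main__":
    block = PartitionedBlockSketch(channels=32)
    clean, rain = block(torch.randn(1, 32, 64, 64))
    print(clean.shape, rain.shape)                       # both (1, 32, 64, 64)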

Network Architecture of MPT

An architecture overview of the Multi-stage Partitioned Transformer (MPT).

Architecture of Partitioned Transformer Block (PTB).

Performance comparison on the five test datasets in terms of deraining quality and model size (number of parameters in millions).

Dataset Descriptions

Synthetic datasets

Real-world dataset

./data/Rain100L
+--- train
|   +--- norain
|   +--- rain
|
+--- test
|   +--- norain
|   +--- rain

Quality Metrics

All PSNR and SSIM results are computed on the Y channel of the YCbCr color space.
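
A minimal sketch of how the Y-channel evaluation can be reproduced with OpenCV and scikit-image (both installed in the Installation section below). The file names are placeholders, and the exact color conversion and cropping used by the official evaluation scripts may differ.

import cv2
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def y_channel(img_bgr):
    # Luma (Y) channel of an 8-bit BGR image loaded with OpenCV.
    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2YCrCb)[:, :, 0]

derained = cv2.imread("./results/Rain100L/norain-001.png")               # placeholder path
ground_truth = cv2.imread("./data/Rain100L/test/norain/norain-001.png")  # placeholder path

y_out, y_gt = y_channel(derained), y_channel(ground_truth)
print("PSNR:", peak_signal_noise_ratio(y_gt, y_out, data_range=255))
print("SSIM:", structural_similarity(y_gt, y_out, data_range=255))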

Installation

This implementation is modified from "RCDNet_simple".

git clone https://github.com/WENYICAT/MPT.git
cd MPT
conda create -n MPT python=3.8
source activate MPT
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch -c conda-forge
pip install opencv-python tqdm ptflops glog scikit-image tensorboardX torchsummary
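
As an optional sanity check (not part of the original instructions), the following confirms that the CUDA build of PyTorch is installed and visible:

import torch, torchvision

print(torch.__version__, torchvision.__version__)   # expect 1.11.0 and 0.12.0
print("CUDA available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())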

Training

Take training on Rain100L (200 training pairs) as an example: download the dataset and unzip it to ./data so that the files follow the layout shown above. The training paths are then:

 data_path = r"./data/Rain100L/train/rain/rain-*.png"
 gt_path = r"./data/Rain100L/train/norain/norain-*.png"

Note that if you use other datasets, please arrange the files in the same way.
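
For this layout, a minimal PyTorch Dataset that pairs each rainy image with its ground truth could look like the sketch below; the class and helper names are hypothetical and not taken from the repository's data loader.

import glob
import os

import cv2
import torch
from torch.utils.data import Dataset

class RainPairDataset(Dataset):
    """Pairs rain-*.png with the matching norain-*.png (hypothetical helper)."""
    def __init__(self, root="./data/Rain100L/train"):
        self.rain_paths = sorted(glob.glob(os.path.join(root, "rain", "rain-*.png")))
        self.gt_paths = [
            os.path.join(root, "norain", os.path.basename(p).replace("rain-", "norain-"))
            for p in self.rain_paths
        ]

    def __len__(self):
        return len(self.rain_paths)

    def __getitem__(self, idx):
        # Load as RGB float tensors in [0, 1], channel-first for PyTorch.
        rain = cv2.cvtColor(cv2.imread(self.rain_paths[idx]), cv2.COLOR_BGR2RGB)
        gt = cv2.cvtColor(cv2.imread(self.gt_paths[idx]), cv2.COLOR_BGR2RGB)
        to_tensor = lambda im: torch.from_numpy(im).permute(2, 0, 1).float() / 255.0
        return to_tensor(rain), to_tensor(gt)

Wrapping this in a torch.utils.data.DataLoader with the batch size used below (12) then yields (rain, ground-truth) batches for training.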

To train the model, run:

$ python -m torch.distributed.launch --nproc_per_node=2 --master_port=25911 train_main_syn_parallel.py --use_gpu="0,1" --batchSize=12 --resume=-1 --model_dir="./checkpoints/Rain100L/"

Testing

$ python -m torch.distributed.launch --nproc_per_node=1 --master_port=25911 test.py --use_gpu="0" --model_dir="./checkpoints/Rain100L/" --save_path="./results/Rain100L/"

Pretrained Model

Place the pre-trained weights in ./weights/ and run train_main_syn_parallel.py with --resume=1 to load them.
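
A generic way to inspect such a checkpoint before resuming is shown below; the file name and the wrapping dictionary key are assumptions for illustration, not the repository's actual format.

import torch

# Hypothetical checkpoint path; the actual file name in ./weights/ may differ.
state = torch.load("./weights/MPT_Rain100L.pth", map_location="cpu")

# Checkpoints are often either a raw state_dict or a dict wrapping one;
# the "state_dict" key below is an assumption.
state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
for name, value in list(state_dict.items())[:5]:
    shape = tuple(value.shape) if hasattr(value, "shape") else type(value).__name__
    print(name, shape)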

Deraining Results

Deraining quantitative comparison on the synthetic datasets
Deraining quantitative comparison on the real-world dataset
Image Deraining on Rain100L
Image Deraining on Rain100H
Image Deraining on Rain800
Image Deraining on Rain1400
Image Deraining on Rain1200
Image Deraining on real-world SPA-Data

Citation