PyTorch Implementation of "Multi-Stage Partitioned Transformer for Efficient Image Deraining"
Images captured outdoors may contain rain, which obscures the clean scene and significantly degrades visual quality. Because rain scenes vary with rain density and wind direction, removing rain streaks from a rainy image is difficult. Building on the recent success of transformers in vision tasks, we propose a novel Multi-stage Partitioned Transformer (MPT) specifically for image deraining. MPT separates the attention module and the multi-layer perceptron (MLP) to decompose a rainy image into a rain layer and a clean background. It uses the proposed global and local rain-aware attention mechanisms to estimate the rain layer. In addition, we add atrous convolutions to the MLP to aggregate contextualized background features and produce a clean background at multiple stages. MPT is a parameter-economical and computationally efficient deraining model that effectively removes rain streaks from the input rainy image. Experimental results demonstrate that MPT performs favorably against state-of-the-art deraining models on benchmark datasets.
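The abstract credits the atrous (dilated) convolutions added to the MLP with aggregating contextual background features. The sketch below is not MPT code; it only illustrates, under assumed kernel sizes and dilation rates, why dilation helps: a k-tap kernel with dilation d spans d(k-1)+1 pixels, so each stride-1 layer widens the receptive field by d(k-1) without adding parameters. The helper names are hypothetical.

```python
from typing import List, Tuple

# Receptive-field arithmetic for stacked stride-1 atrous (dilated) convolutions.
# Kernel sizes and dilation rates used below are illustrative assumptions,
# not the settings used in MPT.


def effective_kernel(k: int, d: int) -> int:
    """Span (in pixels) covered by a k-tap kernel with dilation d."""
    return d * (k - 1) + 1


def receptive_field(layers: List[Tuple[int, int]]) -> int:
    """Receptive field of stacked stride-1 convs given (kernel, dilation) pairs.

    Each layer widens the field by d * (k - 1) pixels.
    """
    rf = 1
    for k, d in layers:
        rf += d * (k - 1)
    return rf


if __name__ == "__main__":
    plain = [(3, 1)] * 3                # three ordinary 3x3 convs
    atrous = [(3, 1), (3, 2), (3, 4)]   # same depth and weights, growing dilation
    print(receptive_field(plain))       # 7
    print(receptive_field(atrous))      # 15
```

With the same number of 3x3 weights, the dilated stack sees roughly twice as much context per side, which is the kind of effect the abstract attributes to the atrous MLP.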
SPA-Data (CVPR 2019): 28,500 training pairs, 1,000 testing pairs
For example, Rain100L should be organized under ./data/Rain100L as follows:
./data/Rain100L
+--- train
| +--- norain
| +--- rain
|
+--- test
| +--- norain
| +--- rain
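Before training, it can help to confirm that a dataset root matches the layout above. The snippet is a minimal sketch; `missing_dirs` and `EXPECTED` are hypothetical names, not part of the MPT repository, and the check only looks for the four folders shown in the tree.

```python
import os
import tempfile
from pathlib import Path
from typing import List

# Sub-folders expected under each dataset root (mirrors the tree above).
EXPECTED = [f"{split}/{kind}" for split in ("train", "test") for kind in ("norain", "rain")]


def missing_dirs(root: str) -> List[str]:
    """Return the expected sub-folders that do not exist under `root`."""
    base = Path(root)
    return [rel for rel in EXPECTED if not (base / rel).is_dir()]


if __name__ == "__main__":
    # Demo on a throwaway directory: start empty, then create the layout.
    with tempfile.TemporaryDirectory() as tmp:
        print(missing_dirs(tmp))  # lists all four sub-folders
        for rel in EXPECTED:
            os.makedirs(os.path.join(tmp, rel))
        print(missing_dirs(tmp))  # []
```

On real data, `missing_dirs("./data/Rain100L")` should return an empty list.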
This implementation is adapted from "RCDNet_simple".
git clone https://github.com/WENYICAT/MPT.git
cd MPT
conda create -n MPT python=3.8
conda activate MPT
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch -c conda-forge
pip install opencv-python tqdm ptflops glog scikit-image tensorboardX torchsummary
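A quick way to confirm the environment is complete is to check that every package above can be found without importing it. This is a hedged sketch: `missing_packages` is a hypothetical helper, and the mapping from pip names to import names (e.g. `opencv-python` imports as `cv2`, `scikit-image` as `skimage`) is the part most likely to trip people up.

```python
import importlib.util
from typing import Dict, List

# pip package name -> module name used in `import ...`
REQUIRED = {
    "torch": "torch",
    "torchvision": "torchvision",
    "torchaudio": "torchaudio",
    "opencv-python": "cv2",
    "tqdm": "tqdm",
    "ptflops": "ptflops",
    "glog": "glog",
    "scikit-image": "skimage",
    "tensorboardX": "tensorboardX",
    "torchsummary": "torchsummary",
}


def missing_packages(required: Dict[str, str]) -> List[str]:
    """Return pip names whose top-level module cannot be located."""
    return [pkg for pkg, mod in required.items() if importlib.util.find_spec(mod) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    print("missing:", missing if missing else "none")
```

If nothing is missing, `torch.cuda.is_available()` should also return `True` on a machine with a working CUDA setup.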
Taking training on Rain100L (200 training pairs) as an example, unzip the dataset to ./data. The training images are then located with the following paths:
data_path = r"./data/Rain100L/train/rain/rain-*.png"
gt_path = r"./data/Rain100L/train/norain/norain-*.png"
If you use other datasets, organize their files in the same way.
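The two glob patterns only work if every rainy image has a matching ground-truth image. The sketch below pairs them, assuming names of the form `rain-<id>.png` / `norain-<id>.png` that share the same `<id>`; `pair_images` is a hypothetical helper, not the loader used by the repository.

```python
import glob
import os
import tempfile
from typing import List, Tuple


def pair_images(data_path: str, gt_path: str,
                rain_prefix: str = "rain-", gt_prefix: str = "norain-") -> List[Tuple[str, str]]:
    """Pair each rainy image with the ground-truth image sharing its suffix."""
    # Index ground-truth files by the part of the file name after the prefix.
    gt_by_suffix = {os.path.basename(p)[len(gt_prefix):]: p for p in glob.glob(gt_path)}
    pairs = []
    for rain in sorted(glob.glob(data_path)):
        suffix = os.path.basename(rain)[len(rain_prefix):]
        if suffix not in gt_by_suffix:
            raise FileNotFoundError(f"no ground truth for {rain}")
        pairs.append((rain, gt_by_suffix[suffix]))
    return pairs


if __name__ == "__main__":
    # Self-contained demo: two empty fake image pairs in a temporary folder.
    with tempfile.TemporaryDirectory() as tmp:
        for sub, prefix in (("rain", "rain-"), ("norain", "norain-")):
            os.makedirs(os.path.join(tmp, sub))
            for i in (1, 2):
                open(os.path.join(tmp, sub, f"{prefix}{i}.png"), "w").close()
        for rain, gt in pair_images(os.path.join(tmp, "rain", "rain-*.png"),
                                    os.path.join(tmp, "norain", "norain-*.png")):
            print(os.path.basename(rain), "->", os.path.basename(gt))
```

On the real dataset, call `pair_images(data_path, gt_path)` with the two patterns above; Rain100L training should yield 200 pairs, and a `FileNotFoundError` points to a file that breaks the naming convention.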
Training (two GPUs):
$ python -m torch.distributed.launch --nproc_per_node=2 --master_port=25911 train_main_syn_parallel.py --use_gpu="0,1" --batchSize=12 --resume=-1 --model_dir="./checkpoints/Rain100L/"
Testing (one GPU):
$ python -m torch.distributed.launch --nproc_per_node=1 --master_port=25911 test.py --use_gpu="0" --model_dir="./checkpoints/Rain100L/" --save_path="./results/Rain100L/"
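In both commands `--nproc_per_node` matches the number of GPU ids passed to `--use_gpu` (two for training, one for testing). The sketch below keeps those in sync when switching GPU counts; `launch_cmd` is a hypothetical helper that only builds the argument list, and it could be executed with `subprocess.run(cmd, check=True)`.

```python
import shlex
from typing import List


def launch_cmd(script: str, gpus: str, port: int = 25911, **opts: object) -> List[str]:
    """Build a torch.distributed.launch command with one process per GPU id.

    `gpus` follows the --use_gpu format used above, e.g. "0,1".
    """
    n_procs = len([g for g in gpus.split(",") if g.strip()])
    cmd = ["python", "-m", "torch.distributed.launch",
           f"--nproc_per_node={n_procs}", f"--master_port={port}",
           script, f"--use_gpu={gpus}"]
    # Remaining keyword arguments become --key=value flags for the script.
    cmd += [f"--{key}={value}" for key, value in opts.items()]
    return cmd


if __name__ == "__main__":
    train = launch_cmd("train_main_syn_parallel.py", "0,1", batchSize=12, resume=-1,
                       model_dir="./checkpoints/Rain100L/")
    test = launch_cmd("test.py", "0", model_dir="./checkpoints/Rain100L/",
                      save_path="./results/Rain100L/")
    print(shlex.join(train))
    print(shlex.join(test))
```

Recent PyTorch releases recommend `torchrun` over `torch.distributed.launch`; the helper keeps the launcher used in the commands above, which matches the PyTorch 1.11 environment.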
Place the pre-trained models in ./weights/ and modify the command accordingly, e.g. train_main_syn_parallel.py --resume=1