Zheng Chen, Yulun Zhang, Jinjin Gu, Yongbing Zhang, Linghe Kong, and Xin Yuan, "Cross Aggregation Transformer for Image Restoration", NeurIPS, 2022 (Spotlight)
[paper] [arXiv] [supplementary material] [visual results] [pretrained models]
Abstract: Recently, Transformer architecture has been introduced into image restoration to replace convolution neural network (CNN) with surprising results. Considering the high computational complexity of Transformer with global attention, some methods use the local square window to limit the scope of self-attention. However, these methods lack direct interaction among different windows, which limits the establishment of long-range dependencies. To address the above issue, we propose a new image restoration model, Cross Aggregation Transformer (CAT). The core of our CAT is the Rectangle-Window Self-Attention (Rwin-SA), which utilizes horizontal and vertical rectangle window attention in different heads parallelly to expand the attention area and aggregate the features cross different windows. We also introduce the Axial-Shift operation for different window interactions. Furthermore, we propose the Locality Complementary Module to complement the self-attention mechanism, which incorporates the inductive bias of CNN (e.g., translation invariance and locality) into Transformer, enabling global-local coupling. Extensive experiments demonstrate that our CAT outperforms recent state-of-the-art methods on several image restoration applications.
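The rectangle-window idea in the abstract can be illustrated with a small NumPy sketch. This is not the repository's implementation: the function name `rect_window_partition`, the window sizes, and the way channels are split between the horizontal and vertical groups are all assumptions chosen for illustration. It only shows how one feature map yields wide windows for some heads and tall windows for others.

```python
import numpy as np

def rect_window_partition(x, win_h, win_w):
    """Split an (H, W, C) feature map into non-overlapping win_h x win_w windows.

    Returns an array of shape (num_windows, win_h * win_w, C): one token
    sequence per window, the unit that self-attention would operate on.
    """
    H, W, C = x.shape
    assert H % win_h == 0 and W % win_w == 0, "pad the input first"
    x = x.reshape(H // win_h, win_h, W // win_w, win_w, C)
    x = x.transpose(0, 2, 1, 3, 4)  # (nH, nW, win_h, win_w, C)
    return x.reshape(-1, win_h * win_w, C)

# Hypothetical sizes: a 16x16 map with 8 channels, half the channels per group.
feat = np.arange(16 * 16 * 8, dtype=np.float32).reshape(16, 16, 8)
horizontal = rect_window_partition(feat[..., :4], win_h=4, win_w=16)  # wide windows
vertical = rect_window_partition(feat[..., 4:], win_h=16, win_w=4)    # tall windows
print(horizontal.shape, vertical.shape)  # (4, 64, 4) (4, 64, 4)
```

Both groups attend over 64 tokens per window, but the wide windows span the full width and the tall windows span the full height, so together they cover a larger area than a square window with the same token count.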
Visual comparison for SR (x4): HQ, LQ, SwinIR, CAT (ours).
# Clone the GitHub repo and go to the default directory 'CAT'.
git clone https://github.com/zhengchen1999/CAT.git
cd CAT
conda create -n CAT python=3.8
conda activate CAT
pip install -r requirements.txt
python setup.py develop
The training and testing sets used can be downloaded as follows:
Task | Training Set | Testing Set | Visual Results |
---|---|---|---|
image SR | DIV2K (800 training images, 100 validation images) + Flickr2K (2650 images) [complete training dataset DF2K] | Set5 + Set14 + BSD100 + Urban100 + Manga109 [complete testing dataset download] | here |
grayscale JPEG compression artifact reduction | DIV2K (800 training images) + Flickr2K (2650 images) + WED (4744 images) + BSD500 (400 training and testing images) [complete training dataset DFWB] | Classic5 + LIVE1 + Urban100 [complete testing dataset download] | here |
real image denoising | SIDD (320 training images) [complete training dataset SIDD] | SIDD + DND [complete testing dataset download] | here |
Here the visual results are generated under SR (x4), JPEG compression artifact reduction (q10), and real image denoising.
Download training and testing datasets and put them into the corresponding folders of datasets/ and restormer/datasets. See datasets for details of the directory structure.
Task | Method | Params (M) | FLOPs (G) | Dataset | PSNR (dB) | SSIM | Model Zoo | Visual Results |
---|---|---|---|---|---|---|---|---|
SR | CAT-R | 16.60 | 292.7 | Urban100 | 27.45 | 0.8254 | Google Drive | Google Drive |
SR | CAT-A | 16.60 | 360.7 | Urban100 | 27.89 | 0.8339 | Google Drive | Google Drive |
SR | CAT-R-2 | 11.93 | 216.3 | Urban100 | 27.59 | 0.8285 | Google Drive | Google Drive |
SR | CAT-A-2 | 16.60 | 387.9 | Urban100 | 27.99 | 0.8357 | Google Drive | Google Drive |
CAR | CAT | 16.20 | 346.4 | LIVE1 | 29.89 | 0.8295 | Google Drive | Google Drive |
real-DN | CAT | 25.77 | 53.2 | SIDD | 40.01 | 0.9600 | Google Drive | Google Drive |
The performance is reported on Urban100 (x4, SR), LIVE1 (q=10, CAR), and SIDD (real-DN). FLOPs are measured with a 128 x 128 test input.
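For readers unfamiliar with the PSNR column, the metric can be sketched as follows. This is the standard definition for 8-bit images, not the repository's evaluation code: details such as evaluating on the Y channel or cropping borders, which SR benchmarks often apply, are left out, and the function name `psnr` is a placeholder.

```python
import numpy as np

def psnr(reference, restored, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two same-shaped images."""
    reference = np.asarray(reference, dtype=np.float64)
    restored = np.asarray(restored, dtype=np.float64)
    mse = np.mean((reference - restored) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(max_value ** 2 / mse)

# Synthetic example: a random image and a noisy copy of it.
rng = np.random.default_rng(0)
hq = rng.integers(0, 256, size=(32, 32)).astype(np.float64)
noisy = np.clip(hq + rng.normal(0.0, 5.0, size=hq.shape), 0, 255)
print(f"{psnr(hq, noisy):.2f} dB")
```

Because PSNR is logarithmic in the mean squared error, a gap of a few tenths of a dB between rows of the table corresponds to a reduction in mean squared error of several percent.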
Cd to 'CAT' and run the setup script.
# If already in CAT and set up, please ignore
python setup.py develop
Download training (DF2K, already processed) and testing (Set5, Set14, BSD100, Urban100, Manga109, already processed) datasets and place them in datasets/.
Run the following scripts. The training configuration is in options/train/.
# CAT-R, SR, input=64x64, 4 GPUs
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_sr_x2.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_sr_x3.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_sr_x4.yml --launcher pytorch
# CAT-A, SR, input=64x64, 4 GPUs
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_sr_x2.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_sr_x3.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_sr_x4.yml --launcher pytorch
# CAT-R-2, SR, input=64x64, 4 GPUs
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_2_sr_x2.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_2_sr_x3.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_2_sr_x4.yml --launcher pytorch
# CAT-A-2, SR, input=64x64, 4 GPUs
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_2_sr_x2.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_2_sr_x3.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_2_sr_x4.yml --launcher pytorch
The training experiment is in experiments/.
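The twelve SR training commands above follow one pattern: model variant, scale, and a config file named after both. A small helper can generate them. The helper name `sr_train_commands` and its parameters are hypothetical; the command template and config paths are copied from the listing above.

```python
def sr_train_commands(nproc=4, master_port=4321):
    """Build the SR training commands for every model variant and scale."""
    variants = ["CAT_R", "CAT_A", "CAT_R_2", "CAT_A_2"]
    scales = [2, 3, 4]
    commands = []
    for variant in variants:
        for scale in scales:
            config = f"options/train/train_{variant}_sr_x{scale}.yml"
            commands.append(
                "python -m torch.distributed.launch "
                f"--nproc_per_node={nproc} --master_port={master_port} "
                f"basicsr/train.py -opt {config} --launcher pytorch"
            )
    return commands

for command in sr_train_commands():
    print(command)
```

The defaults reproduce the 4-GPU setting shown above. Running with a different GPU count may also require changes to the YML configuration, which this sketch does not cover.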
Cd to 'CAT' and run the setup script
# If already in CAT and set up, please ignore
python setup.py develop
Download training (DFWB, already processed) and testing (Classic5, LIVE1, Urban100, already processed) datasets and place them in datasets/.
Run the following scripts. The training configuration is in options/train/.
# CAT, CAR, input=128x128, 4 GPUs
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_car_q10.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_car_q20.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_car_q30.yml --launcher pytorch
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_car_q40.yml --launcher pytorch
The training experiment is in experiments/.
Cd to 'CAT/restormer' and run the setup script
# If already in restormer and set up, please ignore
python setup.py develop --no_cuda_ext
Download the training dataset (SIDD-train, which contains the validation dataset, already processed) and place it in datasets/ (restormer/datasets/).
Run the following scripts. The training configuration is in options/ (restormer/options/).
# CAT, Real DN, Progressive Learning, 8 GPUs
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train_RealDenoising_CAT.yml --launcher pytorch
The training experiment is in experiments/ (restormer/experiments/).
Cd to 'CAT' and run the setup script
# If already in CAT and set up, please ignore
python setup.py develop
Download the pre-trained models and place them in experiments/pretrained_models/.
We provide pre-trained models for image SR: CAT-R, CAT-A, CAT-R-2, and CAT-A-2 (x2, x3, x4).
Download testing (Set5, Set14, BSD100, Urban100, Manga109, already processed) datasets and place them in datasets/.
Run the following scripts. The testing configuration is in options/test/ (e.g., test_CAT_R_sr_x2.yml).
Note 1: You can set use_chop: True (default: False) in YML to chop the image for testing.
# No self-ensemble
# CAT-R, SR, reproduces results in Table 2 of the main paper
python basicsr/test.py -opt options/test/test_CAT_R_sr_x2.yml
python basicsr/test.py -opt options/test/test_CAT_R_sr_x3.yml
python basicsr/test.py -opt options/test/test_CAT_R_sr_x4.yml
# CAT-A, SR, reproduces results in Table 2 of the main paper
python basicsr/test.py -opt options/test/test_CAT_A_sr_x2.yml
python basicsr/test.py -opt options/test/test_CAT_A_sr_x3.yml
python basicsr/test.py -opt options/test/test_CAT_A_sr_x4.yml
# CAT-R-2, SR, reproduces results in Table 1 of the supplementary material
python basicsr/test.py -opt options/test/test_CAT_R_2_sr_x2.yml
python basicsr/test.py -opt options/test/test_CAT_R_2_sr_x3.yml
python basicsr/test.py -opt options/test/test_CAT_R_2_sr_x4.yml
# CAT-A-2, SR, reproduces results in Table 1 of the supplementary material
python basicsr/test.py -opt options/test/test_CAT_A_2_sr_x2.yml
python basicsr/test.py -opt options/test/test_CAT_A_2_sr_x3.yml
python basicsr/test.py -opt options/test/test_CAT_A_2_sr_x4.yml
The output is in results/.
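Note 1 mentions chopping the image for testing. One common way to do this is to run the model on tiles and stitch the outputs, which keeps peak memory bounded by the tile size. The sketch below shows that general idea only. Whether use_chop works this way in the repository is an assumption, and `chop_forward`, `nearest_upsample`, and the tile size are hypothetical names and values.

```python
import numpy as np

def chop_forward(image, model, scale, tile=64):
    """Run `model` on non-overlapping tiles of an (H, W, C) image and stitch."""
    H, W, C = image.shape
    out = np.zeros((H * scale, W * scale, C), dtype=np.float32)
    for top in range(0, H, tile):
        for left in range(0, W, tile):
            patch = image[top:top + tile, left:left + tile]  # edge tiles may be smaller
            h, w = patch.shape[:2]
            out[top * scale:(top + h) * scale,
                left * scale:(left + w) * scale] = model(patch)
    return out

def nearest_upsample(patch, scale=2):
    """Stand-in for an SR network: nearest-neighbour x2 upsampling."""
    return patch.repeat(scale, axis=0).repeat(scale, axis=1)

lq = np.random.rand(100, 150, 3).astype(np.float32)
full = nearest_upsample(lq)
tiled = chop_forward(lq, nearest_upsample, scale=2, tile=64)
print(np.allclose(full, tiled))  # True
```

The stand-in model is purely local, so tiling reproduces the full-image result exactly. A real network sees less context near tile borders, which is why practical implementations often overlap the tiles and blend or crop the seams.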
Cd to 'CAT' and run the setup script
# If already in CAT and set up, please ignore
python setup.py develop
Download the pre-trained models and place them in experiments/pretrained_models/.
We provide pre-trained models for JPEG compression artifact reduction: CAT (q10, q20, q30, q40).
Download testing (Classic5, LIVE1, Urban100, already processed) datasets and place them in datasets/.
Run the following scripts. The testing configuration is in options/test/ (e.g., test_CAT_car_q10.yml).
# No self-ensemble
# CAT, CAR, reproduces results in Table 3 of the main paper
python basicsr/test.py -opt options/test/test_CAT_car_q10.yml
python basicsr/test.py -opt options/test/test_CAT_car_q20.yml
python basicsr/test.py -opt options/test/test_CAT_car_q30.yml
python basicsr/test.py -opt options/test/test_CAT_car_q40.yml
The output is in results/.
Cd to 'CAT' and run the setup script
# If already in CAT and set up, please ignore
python setup.py develop
Download the pre-trained models and place them in experiments/pretrained_models/.
Download testing (SIDD, DND) datasets and place them in datasets/.
Run the following scripts. The testing configuration is in options/test/.
# No self-ensemble
# CAT, real DN, reproduces results in Table 4 of the main paper
# testing on SIDD
python test_real_denoising_sidd.py --save_images
evaluate_sidd.m
# testing on DND
python test_real_denoising_dnd.py --save_images
The output is in results/.
We achieved state-of-the-art performance on image SR, JPEG compression artifact reduction and real image denoising. Detailed results can be found in the paper. All visual results of CAT can be downloaded here.
- results in Table 1 of the supplementary material
- visual comparison (x4) in the main paper
- visual comparison (x4) in the supplementary material
- results in Table 3 of the supplementary material (test on **Urban100**)
- visual comparison (q=10) in the main paper
- visual comparison (q=10) in the supplementary material
*: We re-test the SIDD with all official pre-trained models.
If you find the code helpful in your research or work, please cite the following paper(s).
@inproceedings{chen2022cross,
title={Cross Aggregation Transformer for Image Restoration},
author={Chen, Zheng and Zhang, Yulun and Gu, Jinjin and Zhang, Yongbing and Kong, Linghe and Yuan, Xin},
booktitle={NeurIPS},
year={2022}
}