Ascend-Research / CascadedGaze

The official PyTorch implementation for CascadedGaze: Efficiency in Global Context Extraction for Image Restoration, TMLR'24.
MIT License

# CascadedGaze: Efficiency in Global Context Extraction for Image Restoration


The official PyTorch implementation of the paper

CascadedGaze: Efficiency in Global Context Extraction for Image Restoration \
Amirhosein Ghasemabadi, Muhammad Kamran Janjua, Mohammad Salameh, Chunhua Zhou, Fengyu Sun, Di Niu \
Accepted at Transactions on Machine Learning Research (TMLR), 2024.

Installation

This implementation is based on BasicSR, an open-source toolbox for image/video restoration tasks, as well as on NAFNet, Restormer, and Multi Output Deblur.

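Tested with python 3.9.5, pytorch 1.11.0, and cuda 11.3. Install the dependencies and set up the package without compiling the CUDA extensions: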
pip install -r requirements.txt
python setup.py develop --no_cuda_ext
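As an optional sanity check, the installed versions can be compared with the tested setup listed above. The sketch below is not part of the repository: it uses only the standard library (`importlib.metadata`), the helper names `strip_local` and `find_mismatches` are hypothetical, and it assumes PyTorch is installed under the package name `torch`. Local build suffixes such as `+cu113` are ignored when comparing.

```python
import sys
from importlib import metadata

# Tested setup from this README; other versions are not claimed to be unsupported.
TESTED_PYTHON = (3, 9, 5)
TESTED_PACKAGES = {"torch": "1.11.0"}


def strip_local(version):
    """Drop a PEP 440 local suffix such as '+cu113' from a version string."""
    return version.split("+", 1)[0]


def find_mismatches(expected, get_version=metadata.version):
    """Return {package: (expected, found)} for packages whose version differs.

    `found` is None when the package is not installed.
    """
    mismatches = {}
    for name, wanted in expected.items():
        try:
            found = strip_local(get_version(name))
        except metadata.PackageNotFoundError:
            found = None
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    if tuple(sys.version_info[:3]) != TESTED_PYTHON:
        tested = ".".join(map(str, TESTED_PYTHON))
        print(f"python {sys.version.split()[0]} differs from tested {tested}")
    for name, (wanted, found) in find_mismatches(TESTED_PACKAGES).items():
        print(f"{name}: tested {wanted}, found {found}")
```

The CUDA version that PyTorch was built against can be checked separately with `torch.version.cuda`, which reports `11.3` for the tested setup.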

Quick Start

We have provided demo-denoising.ipynb to show how to load images from the validation dataset and use the model to restore images.

CascadedGaze implementation

The implementation of our proposed CascadedGaze Net, CascadedGaze block, and the Global Context Extractor module can be found in /CascadedGaze/basicsr/models/archs/CGNet_arch.py

The implementation of Multi-Head CascadedGaze Net can be found in /CascadedGaze/basicsr/models/archs/CGNetMultiHead_arch.py

Denoising on SIDD

1. Data Preparation

Download the training set (from the SIDD dataset website) and place it in ./datasets/SIDD/Data/.
Download the evaluation data in lmdb format and place it in ./datasets/SIDD/test/.

After downloading, the directory should look like this:

./datasets/
└── SIDD/
    ├── Data/
    │   ├── 0001
    │   │   ├── GT_SRGB.PNG
    │   │   ├── NOISY_SRGB.PNG
    │   │   ....
    │   └── 0200
    │       ├── GT_SRGB.PNG
    │       ├── NOISY_SRGB.PNG    
    ├── train/
    └── test/
        ├── input.lmdb
        └── target.lmdb
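Before training, it can help to confirm that every scene folder under `Data/` holds both a clean and a noisy image. The following is a hypothetical helper, not a script from this repository. It assumes the layout shown above and matches file names with the glob patterns `GT_SRGB*.PNG` and `NOISY_SRGB*.PNG`, so that names carrying a numeric suffix (assumed here for the raw SIDD download, e.g. `GT_SRGB_010.PNG`) are accepted as well.

```python
from pathlib import Path


def find_incomplete_scenes(data_dir):
    """Return names of scene folders in data_dir missing a GT or a NOISY sRGB image."""
    incomplete = []
    for scene in sorted(p for p in Path(data_dir).iterdir() if p.is_dir()):
        has_gt = any(scene.glob("GT_SRGB*.PNG"))
        has_noisy = any(scene.glob("NOISY_SRGB*.PNG"))
        if not (has_gt and has_noisy):
            incomplete.append(scene.name)
    return incomplete


if __name__ == "__main__":
    data_dir = Path("./datasets/SIDD/Data")
    if not data_dir.is_dir():
        print(f"{data_dir} not found")
    else:
        bad = find_incomplete_scenes(data_dir)
        print("all scenes complete" if not bad else f"incomplete scenes: {bad}")
```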

2. Training

python -m torch.distributed.launch --nproc_per_node=8 --master_port=8081 basicsr/train.py -opt options/train/SIDD/CascadedGaze-SIDD.yml --launcher pytorch
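In BasicSR-style distributed training, `torch.distributed.launch` starts one process per GPU and each process builds its own data loader, so the number of images per optimizer step grows with the number of processes. The sketch below assumes the option file follows the BasicSR convention of a `datasets.train.batch_size_per_gpu` key; the helper name `effective_batch_size` is hypothetical, and reading the YAML file assumes PyYAML is installed.

```python
from pathlib import Path


def effective_batch_size(opt, world_size):
    """Images per optimizer step when each of `world_size` processes loads
    `batch_size_per_gpu` images (BasicSR-style distributed training)."""
    per_gpu = opt["datasets"]["train"]["batch_size_per_gpu"]
    return per_gpu * world_size


if __name__ == "__main__":
    # Option file from the training command above.
    option_file = Path("options/train/SIDD/CascadedGaze-SIDD.yml")
    if option_file.exists():
        import yaml  # PyYAML, which BasicSR uses to read option files

        with option_file.open() as f:
            opt = yaml.safe_load(f)
        # world_size = nproc_per_node * nnodes; the command above uses 8 processes on one node.
        print("effective batch size:", effective_batch_size(opt, world_size=8))
```

When fewer than eight GPUs are available, lowering `--nproc_per_node` also lowers the effective batch size, which may change results relative to the provided configuration.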

3. Evaluation

Note: Due to the file size limitation, we are not able to share the pre-trained models in this code submission. However, they will be provided with an open-source release of the code.

Testing the model

4. Model complexity and inference speed
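The repository's own complexity and speed measurements are not reproduced in this section. As a generic starting point, the sketch below counts the parameters of any `torch.nn.Module` and times a callable after a few warm-up runs. The function names are hypothetical, and counting MACs would require a separate profiling tool that is not covered here.

```python
import statistics
import time


def count_parameters(model):
    """Sum element counts over model.parameters() (any torch.nn.Module works)."""
    return sum(p.numel() for p in model.parameters())


def time_inference(run, warmup=10, repeats=50, sync=lambda: None):
    """Return the median wall-clock time of `run()` in milliseconds.

    `sync` should block until queued GPU work has finished
    (e.g. torch.cuda.synchronize); the default is a no-op for CPU timing.
    """
    for _ in range(warmup):
        run()
    sync()  # make sure warm-up work is finished before timing starts
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        run()
        sync()
        times.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(times)
```

For a model on the GPU, one would call something like `time_inference(lambda: model(x), sync=torch.cuda.synchronize)` inside a `torch.no_grad()` block; without synchronization, asynchronous CUDA kernel launches make the measured times too small.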

Gaussian Image denoising

1. Data Preparation

Clone the Restormer GitHub project and follow its instructions to download the training and test datasets.

2. Training

To train the CascadedGaze model, follow these steps:

3. Evaluation

Note: Pretrained models will be released soon.

Testing the model

4. Model complexity and inference speed

Deblurring on GoPro

1. Data Preparation

Download the training set (from the GoPro dataset website) and place it in ./datasets/GoPro/train/.
Download the evaluation data in lmdb format (from the GoPro dataset website) and place it in ./datasets/GoPro/test/.

After downloading, the directory should look like this:

./datasets/
└── GoPro/
    ├── train/
    │   ├── input/
    │   └── target/
    └── test/
        ├── input.lmdb
        └── target.lmdb
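Because deblurring is trained on aligned blurred/sharp pairs, a quick check that the two training folders match can catch an incomplete download. This hypothetical helper is not part of the repository; it assumes that each blurred frame in `input/` has a sharp counterpart with the same file name in `target/`, and reports any names that appear on only one side.

```python
from pathlib import Path


def unpaired_files(train_dir):
    """Compare file names in train_dir/input and train_dir/target.

    Returns (only_in_input, only_in_target) as sorted lists of names.
    """
    train_dir = Path(train_dir)
    inputs = {p.name for p in (train_dir / "input").iterdir() if p.is_file()}
    targets = {p.name for p in (train_dir / "target").iterdir() if p.is_file()}
    return sorted(inputs - targets), sorted(targets - inputs)


if __name__ == "__main__":
    train_dir = Path("./datasets/GoPro/train")
    if (train_dir / "input").is_dir() and (train_dir / "target").is_dir():
        only_input, only_target = unpaired_files(train_dir)
        print("only in input:", only_input)
        print("only in target:", only_target)
    else:
        print(f"expected input/ and target/ under {train_dir}")
```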

2. Training

python -m torch.distributed.launch --nproc_per_node=8 --master_port=8081 basicsr/train.py -opt options/train/GoPro/CascadedGazeMH-GoPro.yml --launcher pytorch

3. Evaluation

Note: Pretrained models will be released soon.

Testing the model

4. Model complexity and inference speed

Visualizing the training logs

Citation

If you use CascadedGaze or this codebase in your work, please consider citing:

@article{
ghasemabadi2024cascadedgaze,
title={CascadedGaze: Efficiency in Global Context Extraction for Image Restoration},
author={Amirhosein Ghasemabadi and Muhammad Kamran Janjua and Mohammad Salameh and Chunhua Zhou and Fengyu Sun and Di Niu},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2024},
url={https://openreview.net/forum?id=C3FXHxMVuq},
note={}
}