zhengchen1999 / DAT

PyTorch code for our ICCV 2023 paper "Dual Aggregation Transformer for Image Super-Resolution"
Apache License 2.0

Load model via dat_arch.py #4

Open Julesaiyy opened 10 months ago

Julesaiyy commented 10 months ago

Thank you so much for your wonderful work. I am having a problem: I want to load your pre-trained model via dat_arch.py and then run it on a real-life image, but loading the model fails. Do I need to install basicsr for this? Is there a solution that does not require basicsr?

zhengchen1999 commented 10 months ago

Hi. Thanks for your interest in our work.

You can modify dat_arch.py: comment out line 17 (from basicsr.utils.registry import ARCH_REGISTRY) and line 700 (@ARCH_REGISTRY.register()). Then you can run dat_arch.py directly. The following code loads a pre-trained model (e.g., DAT_x2.pth).

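if __name__ == '__main__':
    import torch  # harmless if dat_arch.py already imports torch at the top

    # Hyperparameters must match the checkpoint being loaded (here: DAT_x2.pth).
    model = DAT(
        upscale=2,
        in_chans=3,
        img_size=64,
        img_range=1.,
        depth=[6, 6, 6, 6, 6, 6],
        embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6],
        expansion_factor=4,
        resi_connection='1conv',
        split_size=[8, 32],
    )

    # The weights are stored under the 'params' key of the checkpoint.
    # map_location='cpu' lets the file load on machines without a GPU.
    state_dict = torch.load('experiments/pretrained_models/DAT/DAT_x2.pth', map_location='cpu')['params']
    model.load_state_dict(state_dict)
    model.eval()  # switch to inference mode before feeding images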
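
If editing two lines by hand is inconvenient, a fallback import is another option. The sketch below is not the author's method: it is an alternative that assumes the only use of ARCH_REGISTRY in dat_arch.py is the @ARCH_REGISTRY.register() decorator. The class _NoOpRegistry and the DummyArch example are hypothetical names introduced here for illustration.

```python
# Sketch: use the real registry when basicsr is installed, otherwise a no-op stand-in.
try:
    from basicsr.utils.registry import ARCH_REGISTRY
except ImportError:
    class _NoOpRegistry:
        """Hypothetical stand-in: register() returns the class unchanged."""

        def register(self, obj=None):
            if obj is None:
                # Used as a decorator factory: @ARCH_REGISTRY.register()
                return lambda o: o
            # Used as a plain call: ARCH_REGISTRY.register(SomeClass)
            return obj

    ARCH_REGISTRY = _NoOpRegistry()


# Illustrative class only; in dat_arch.py the decorated class would be DAT.
@ARCH_REGISTRY.register()
class DummyArch:
    pass
```

With this block in place of the original import, the decorator line can stay as it is, and the file should work both with and without basicsr installed.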

Other pre-trained models load the same way. You also need to preprocess the input image so that it conforms to the model's input requirements (e.g., img_range=1, which means pixel values scaled to [0, 1]).
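
A minimal sketch of that preprocessing for a 256×256 image, written with NumPy only so that it runs without the model. It assumes the image is already loaded as an H×W×3 uint8 array in RGB order; the RGB order is an assumption here, not something stated in this thread. The function names preprocess and postprocess are hypothetical helpers, and the random array stands in for a real image.

```python
import numpy as np


def preprocess(img_uint8):
    """HxWx3 uint8 image -> 1x3xHxW float32 array with values in [0, 1]."""
    if img_uint8.ndim != 3 or img_uint8.shape[2] != 3:
        raise ValueError("expected an HxWx3 image")
    x = img_uint8.astype(np.float32) / 255.0   # scale to [0, 1] to match img_range=1
    x = np.transpose(x, (2, 0, 1))             # HWC -> CHW
    return np.ascontiguousarray(x[None, ...])  # add the batch dimension


def postprocess(y):
    """1x3xHxW float array in [0, 1] -> HxWx3 uint8 image."""
    y = np.clip(y[0], 0.0, 1.0)                # drop batch dimension, clamp overshoot
    y = np.transpose(y, (1, 2, 0))             # CHW -> HWC
    return (y * 255.0).round().astype(np.uint8)


# Placeholder for a real 256x256 RGB image.
img = np.random.default_rng(0).integers(0, 256, size=(256, 256, 3), dtype=np.uint8)
x = preprocess(img)
print(x.shape, x.dtype)
```

With PyTorch available, torch.from_numpy(x) turns the array into a tensor that can be passed to the model inside torch.no_grad(). For the x2 model a 256×256 input should yield a 1×3×512×512 output, which postprocess can convert back to an image after calling .cpu().numpy() on it.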

If you have any other problems, please let us know. Thanks.

Julesaiyy commented 10 months ago

Thanks so much. I got it. For example, if I have a 256×256 image as input, how do I preprocess it?