open-mmlab / mmsegmentation

OpenMMLab Semantic Segmentation Toolbox and Benchmark.
https://mmsegmentation.readthedocs.io/en/main/
Apache License 2.0

the default pretrained vit ckpt loading fails due to unmatched key names #1288

Closed Tsingularity closed 2 years ago

Tsingularity commented 2 years ago

Describe the bug

When running the default training script for SETR, it initializes the backbone ViT model with the default downloaded ViT-Large checkpoint from mmcls (mmcls://vit_large_p16). However, it reports a lot of missing keys during model initialization: [screenshot: missing-key warnings]

In order to investigate what's going on, I created a notebook and checked the downloaded checkpoint keys against the model's state_dict keys.

Here's the list of the downloaded ViT checkpoint keys: [screenshot: checkpoint keys]

Here's the key list of the model's state_dict(): [screenshot: model state_dict keys]

And you can clearly see a mismatch for all the attention weights.
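For reference, here is a minimal sketch of that check. The helper names (`_load_checkpoint`, `build_segmentor`) are the mmcv/mmseg APIs as I understand them for this era of the codebase and may differ between versions:

```python
# Sketch of the notebook check: _load_checkpoint resolves the mmcls:// prefix
# and build_segmentor builds the SETR model from the config (API names may
# differ by mmcv/mmseg version).
from mmcv import Config
from mmcv.runner import _load_checkpoint
from mmseg.models import build_segmentor

cfg = Config.fromfile('configs/setr/setr_pup_512x512_160k_b16_ade20k.py')
model = build_segmentor(cfg.model)

ckpt = _load_checkpoint('mmcls://vit_large_p16')
ckpt = ckpt.get('state_dict', ckpt)
# The mmcls checkpoint is a full classifier, so strip its 'backbone.' prefix
# before comparing against the segmentor's ViT backbone.
ckpt = {k[len('backbone.'):]: v for k, v in ckpt.items()
        if k.startswith('backbone.')}

ckpt_keys = set(ckpt)
model_keys = set(model.backbone.state_dict())
print('only in checkpoint:', sorted(ckpt_keys - model_keys)[:8])
print('only in model:     ', sorted(model_keys - ckpt_keys)[:8])
```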

It looks like this is due to the mismatched implementations of MultiheadAttention in mmcls and mmseg:

For mmcls, this is their code for MultiheadAttention: https://github.com/open-mmlab/mmclassification/blob/a7f8e96b31c10ab3e9c133293ca406e6e548475b/mmcls/models/utils/attention.py#L294

For mmseg (which uses MMCV's implementation), this is the code for MultiheadAttention: https://github.com/open-mmlab/mmcv/blob/5de2b130d37301432ecbe0a51c31ef979b3d7a26/mmcv/cnn/bricks/transformer.py#L407

So could you please take a look at this and fix it? Feel free to let me know if there's anything I can help with on my end.

Thanks!

Reproduction

  1. What command or script did you run?

    ./tools/dist_train.sh configs/setr/setr_pup_512x512_160k_b16_ade20k.py 1
  2. Did you make any modifications on the code or config? Did you understand what you have modified? No.

  3. What dataset did you use? ADE20K.

Environment

sys.platform: linux
Python: 3.7.11 (default, Jul 27 2021, 14:32:16) [GCC 7.5.0]
CUDA available: True
GPU 0: NVIDIA RTX A6000
CUDA_HOME: /usr/local/cuda
NVCC: Build cuda_11.2.r11.2/compiler.29558016_0
GCC: gcc (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
PyTorch: 1.10.2
PyTorch compiling details: PyTorch built with:
TorchVision: 0.11.3
OpenCV: 4.5.5
MMCV: 1.4.4
MMCV Compiler: GCC 7.3
MMCV CUDA Compiler: 11.3
MMSegmentation: 0.21.1+b7a3f72

Error traceback

No error actually (due to the non-strict load_state_dict), but clearly the ViT model is not loaded correctly.

Bug fix

see above.

MengzhangLI commented 2 years ago

Sorry for the inconvenience and thanks for your error report. I think it is a bug in the codebase itself.

In this case, could you try to use init_cfg=dict(type='Pretrained', checkpoint='mmcls://vit_large_p16', prefix='backbone.')?

The error is caused by the backbone. prefix; you can check MMCV here.
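For example, a minimal sketch of how that suggestion would look in the config (only the checkpoint and prefix values come from this thread; the surrounding structure is assumed rather than copied from the repo):

```python
# Sketch only: override the backbone's init_cfg with the suggested
# checkpoint and prefix; the enclosing model structure is assumed.
model = dict(
    backbone=dict(
        init_cfg=dict(
            type='Pretrained',
            checkpoint='mmcls://vit_large_p16',
            prefix='backbone.')))
```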

Looking forward to your reply.

Best,

Tsingularity commented 2 years ago

Hi Mengzhang,

Thanks for your quick response!

But I don't think the 'backbone' prefix is the problem here. Sorry I didn't make this clear enough; let me try again below.

If you take another look at the warning message in the first screenshot, you'll notice the parameter names for MultiheadAttention are different: the pretrained checkpoint uses qkv.weight, qkv.bias, proj.weight, proj.bias, while the defined ViT model uses attn.in_proj_weight, attn.in_proj_bias, attn.out_proj.weight, attn.out_proj.bias.

I think the current code already handles the 'backbone' prefix issue well: as you can see in the warning message, there's no 'backbone.' in it. Sorry I didn't strip it in my own demonstration Jupyter code, just for simplicity.

Therefore, the current codebase is basically training SETR from scratch, because it only loads a small fraction of the pre-trained ViT weights.
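Reusing the `model` and the stripped `ckpt` dict from the sketch in the bug report above, a rough way to quantify that:

```python
# A non-strict load returns the unmatched keys, so coverage can be counted.
result = model.backbone.load_state_dict(ckpt, strict=False)
total = len(model.backbone.state_dict())
print(f'loaded {total - len(result.missing_keys)} of {total} backbone params')
print(f'ignored {len(result.unexpected_keys)} checkpoint keys')
```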

So I'm just wondering: could you take a look at this issue and investigate when and how this mismatch happened, and how we could fix it?

Thanks!

MengzhangLI commented 2 years ago

Yes. Before this month we seldom used the mmcls pretrained model links (we usually downloaded the .pth file and loaded it directly).

There must be some potential bugs in the current usage path. I will try to look into this issue as soon as possible.

Tsingularity commented 2 years ago

cool thanks!

Just wondering, for now, could you please provide a correct version of the ImageNet-21k pre-trained ViT-Large checkpoint that we can use for SETR training? Thanks!

MengzhangLI commented 2 years ago

Hi, @Tsingularity. We have checked the problem; it is caused by our codebase: the ViT pretrained model is different (i.e., the ViT pretrained model linked from MMClassification and the model our MMSegmentation wants to use are DIFFERENT).

For quick use of SETR, you can use our old version: follow this PR to find the download links for the original pretrained model and convert its keys, then change the SETR config to use it.

For example, the current init_cfg is:

init_cfg=dict(type='Pretrained', checkpoint='mmcls://vit_large_p16')

You can find the old model link in the history version:

    pretrained=\
    'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',  # noqa

And convert its keys by running:

python ./tools/model_converters/vit2mmseg.py ${SRC} ${DST} 
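For context, here is a hedged sketch of the kind of renaming such a converter performs for the attention weights discussed above. The real vit2mmseg.py also renames other keys (patch embedding, layer norms, FFN, block prefixes) and may differ in its details:

```python
# Illustrative only: remap the timm-style attention keys (qkv / proj) to the
# names the mmseg ViT expects (in_proj_* / out_proj.*), as listed in this
# thread. The actual vit2mmseg.py script covers more keys than this.
import os
import torch

def remap_attention_keys(src):
    dst = {}
    for k, v in src.items():
        k = k.replace('attn.qkv.weight', 'attn.in_proj_weight')
        k = k.replace('attn.qkv.bias', 'attn.in_proj_bias')
        k = k.replace('attn.proj.weight', 'attn.out_proj.weight')
        k = k.replace('attn.proj.bias', 'attn.out_proj.bias')
        dst[k] = v
    return dst

src = torch.load('jx_vit_large_p16_384-b3be5167.pth', map_location='cpu')
os.makedirs('pretrain', exist_ok=True)
torch.save(remap_attention_keys(src), 'pretrain/vit_large_p16_384.pth')
```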

NOTE: This ViT pretrained model is the older version from Google; Google has released a new version of pretrained models that have better performance on downstream tasks.

If you want to use the new pretrained models from Google, you can follow the doc from Segmenter. Theoretically, its converted models such as vit_tiny_p16_384.pth could also be used for SETR. You can give it a try.

To sum up: you can check out the historical versions of the SETR-related files to work around the problem.

By the way, for long-term use of OpenMMLab, we will update our codebase to fix these potential bugs and errors. Downstream codebases such as MMDetection, MMSegmentation and others will support MMClassification and use its abundant backbone models.

Best.

MengzhangLI commented 2 years ago

Let me walk through using the ViT-Large pretrained models from Google:

(1) Old JAX model:

Download: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth

Transfer key:

python tools/model_converters/vit2mmseg.py \
jx_vit_large_p16_384-b3be5167.pth pretrain/vit_large_p16_384.pth

Change the config file: https://github.com/open-mmlab/mmsegmentation/blob/master/configs/setr/setr_vit-large_pup_8x1_768x768_80k_cityscapes.py#L9-L12. Just make sure it uses pretrained='pretrain/vit_large_p16_384.pth' correctly.
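A minimal sketch of that override (only the checkpoint path comes from this thread; the config structure below is assumed, not copied from the repo):

```python
# Sketch only: point the config at the converted checkpoint. Depending on the
# mmseg version this is either the top-level `pretrained` field or the
# backbone's init_cfg; the structure here is an assumption.
model = dict(
    backbone=dict(
        init_cfg=dict(
            type='Pretrained',
            checkpoint='pretrain/vit_large_p16_384.pth')))
```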

(2) New JAX model

Download:

https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz

Transfer key:

python tools/model_converters/vitjax2mmseg.py \
L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz \
pretrain/vit_large_p16_384.pth

Change the config file: https://github.com/open-mmlab/mmsegmentation/blob/master/configs/setr/setr_vit-large_pup_8x1_768x768_80k_cityscapes.py#L9-L12. Just make sure it uses pretrained='pretrain/vit_large_p16_384.pth' correctly.

RockeyCoss commented 2 years ago

We will fix this problem in this PR.

Tsingularity commented 2 years ago

Sorry for the late reply, and huge thanks for the quick response! I'll take a look and see how it works now.