Sorry for the inconvenient usage, and thanks for your error report. I think it is a bug in the codebase itself.

In this case, could you try the following init_cfg?

```python
init_cfg=dict(type='Pretrained', checkpoint='mmcls://vit_large_p16', prefix='backbone.')
```

The error is caused by the `backbone.` prefix; you can check MMCV here.
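For what it's worth, the effect of `prefix` can be sketched roughly like this (a simplified illustration of the intended behavior, not MMCV's actual code; the function name is hypothetical):

```python
def filter_by_prefix(state_dict, prefix='backbone.'):
    """Keep only the checkpoint keys under `prefix` and strip it,
    so the sub-module's weights can be loaded into the backbone."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
```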
Looking forward to your reply.
Best,
Hi Menzhang,
Thanks for your quick response!

But I don't think it's a problem with the 'backbone' prefix here. Sorry I didn't make this clear enough (let me try again below, lol).

If you take another look at the warning message in figure 1, you'll notice the parameter names for the multi-head attention are different: the pretrained checkpoint uses `qkv.weight`, `qkv.bias`, `proj.weight`, `proj.bias`, while the defined ViT model uses `attn.in_proj_weight`, `attn.in_proj_bias`, `attn.out_proj.weight`, `attn.out_proj.bias`.
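For illustration, a minimal, hypothetical sketch of the renaming that would reconcile the two conventions (the helper name and the suffix-matching heuristic are my own; only the key names come from the warning message above):

```python
def remap_attention_keys(state_dict):
    """Rename mmcls/timm-style attention keys to the
    torch.nn.MultiheadAttention naming used by the mmseg ViT."""
    rename = {
        'qkv.weight': 'attn.in_proj_weight',
        'qkv.bias': 'attn.in_proj_bias',
        'proj.weight': 'attn.out_proj.weight',
        'proj.bias': 'attn.out_proj.bias',
    }
    new_state_dict = {}
    for key, value in state_dict.items():
        # Replace the matching suffix, keeping the layer prefix intact.
        for old, new in rename.items():
            if key.endswith(old):
                key = key[: -len(old)] + new
                break
        new_state_dict[key] = value
    return new_state_dict
```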
I think the current code already handles the 'backbone' prefix issue pretty well, because as you can see in the warning message, there is no 'backbone.' in it. (Sorry, I didn't get rid of it in my own demonstration Jupyter code, just for simplicity.)

Therefore, the current codebase is basically training SETR from scratch, because it only loads a small fraction of the pre-trained ViT weights.

So I'm just wondering, could you take a look at this issue and investigate when and how this mismatch happened, and how we could fix it?

Thanks!
Yes. Before this month we seldom used the mmcls pretrained model links (we usually downloaded the .pth file and loaded it directly), so there may be some potential bugs lying in the current usage method. I will try to check out this issue as soon as possible.
Cool, thanks!

Just wondering: for now, could you please provide a correct version of the ImageNet-21k pre-trained ViT-Large checkpoint that we can use for SETR's training? Thanks!
Hi @Tsingularity, we have checked the problem, and it is caused by our codebase: the pretrained ViT models are different (i.e., the ViT pretrained model linked from MMClassification and the model our MMSegmentation wants to use are DIFFERENT).

For quick use of SETR, you can use our old version: follow this pr to find the download links of the original pretrained model and convert its keys, then change the SETR config to use it.
For example, the current init cfg is:

```python
init_cfg=dict(type='Pretrained', checkpoint='mmcls://vit_large_p16')
```
You can find the old model link in the history version:

```python
pretrained=\
    'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',  # noqa
```
And convert its keys by using:

```shell
python ./tools/model_converters/vit2mmseg.py ${SRC} ${DST}
```
NOTE: This ViT pretrained model is an older version from Google; Google has since released a new version of pretrained models which have better performance in downstream tasks.
If you want to use the new version of the pretrained model from Google, you can follow the doc from Segmenter. Theoretically, its transferred models, such as `vit_tiny_p16_384.pth`, could also be used for SETR. You can have a try.
To sum up: you can check out the history version of the SETR-related files to solve the problem.

By the way, for long-term use of OpenMMLab, we will update our codebase to fix these potential bugs and errors. Downstream codebases such as MMDetection, MMSegmentation and others will support MMClassification and use its abundant backbone models.

Best.
Let me try to use the ViT-Large pretrained models from Google:

(1) Old JAX model

Transfer the keys:

```shell
python tools/model_converters/vit2mmseg.py \
    jx_vit_large_p16_384-b3be5167.pth pretrain/vit_large_p16_384.pth
```

Change the config file: https://github.com/open-mmlab/mmsegmentation/blob/master/configs/setr/setr_vit-large_pup_8x1_768x768_80k_cityscapes.py#L9-L12. Just make sure it uses pretrained='pretrain/vit_large_p16_384.pth' correctly; see the override sketch below.
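As a rough illustration, the override could look like this (a minimal sketch assuming mmseg's standard config inheritance; the `_base_` path is hypothetical):

```python
# Hypothetical config fragment: inherit the SETR config and point
# `pretrained` at the converted local checkpoint from the step above.
_base_ = ['./setr_vit-large_pup_8x1_768x768_80k_cityscapes.py']

model = dict(pretrained='pretrain/vit_large_p16_384.pth')
```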
(2) New JAX model

Download:

Transfer the keys:

```shell
python tools/model_converters/vitjax2mmseg.py \
    Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz \
    pretrain/vit_large_p16_384.pth
```

Change the config file: https://github.com/open-mmlab/mmsegmentation/blob/master/configs/setr/setr_vit-large_pup_8x1_768x768_80k_cityscapes.py#L9-L12. Just make sure it uses pretrained='pretrain/vit_large_p16_384.pth' correctly, as in (1).
We will fix this problem in this pr.

Sorry for the late response, and huge thanks for the quick response! I'll take a look and see how it works now.
Describe the bug

When running the default training script for SETR, it initializes the backbone ViT model with the default ViT-Large checkpoint downloaded from mmcls (mmcls://vit_large_p16). However, it reports a lot of missing keys during model initialization:

In order to investigate what's going on, I created a notebook and checked the downloaded checkpoint keys against the model's state_dict keys; a sketch of that comparison follows below.

Here's the list of the downloaded ViT checkpoint keys:

Here's the key list of the model's state_dict():

And you can clearly see a mismatch for all the attention weights.
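A minimal sketch of that comparison (the original notebook was not shared; the config and checkpoint paths here are assumptions):

```python
import torch
from mmcv import Config
from mmseg.models import build_segmentor

# Build the SETR model from its config (path is an assumption).
cfg = Config.fromfile('configs/setr/setr_vit-large_pup_8x1_768x768_80k_cityscapes.py')
model = build_segmentor(cfg.model)

# A locally downloaded copy of the mmcls ViT-Large checkpoint (filename is
# an assumption). Its keys may carry a 'backbone.' prefix to strip first.
ckpt = torch.load('vit_large_p16.pth', map_location='cpu')
ckpt_keys = set(ckpt.get('state_dict', ckpt).keys())
model_keys = set(model.backbone.state_dict().keys())

# Keys present on one side but not the other reveal the attention-name mismatch.
print('in checkpoint only:', sorted(ckpt_keys - model_keys)[:8])
print('in model only:', sorted(model_keys - ckpt_keys)[:8])
```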
It looks like this is due to the mismatched implementations of multi-head attention in mmcls and mmseg:
For mmcls, this is their code for multi-head attention: https://github.com/open-mmlab/mmclassification/blob/a7f8e96b31c10ab3e9c133293ca406e6e548475b/mmcls/models/utils/attention.py#L294

For mmseg, this is the code for multi-head attention (from mmcv): https://github.com/open-mmlab/mmcv/blob/5de2b130d37301432ecbe0a51c31ef979b3d7a26/mmcv/cnn/bricks/transformer.py#L407
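To make the difference concrete, here is a schematic sketch (not the actual mmcls/mmcv code; the class below is a toy illustration): the mmcls-style attention registers explicit `qkv`/`proj` Linear layers, while mmcv's MultiheadAttention wraps `torch.nn.MultiheadAttention`, whose parameters are registered under different names.

```python
import torch.nn as nn

# mmcls/timm-style attention: explicit Linear layers, so the state_dict keys
# end in qkv.weight / qkv.bias / proj.weight / proj.bias.
class TimmStyleAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

# torch.nn.MultiheadAttention (which mmcv wraps) registers its parameters as
# in_proj_weight / in_proj_bias / out_proj.weight / out_proj.bias instead.
attn = nn.MultiheadAttention(embed_dim=1024, num_heads=16)
print(sorted(name for name, _ in attn.named_parameters()))
# ['in_proj_bias', 'in_proj_weight', 'out_proj.bias', 'out_proj.weight']
```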
So could you please take a look at this and fix it? Feel free to let me know if there's anything I can help with on my end.
Thanks!
Reproduction
What command or script did you run?
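(The exact command was not included in the report; presumably the standard mmseg training entry point, e.g. `python tools/train.py ${CONFIG_FILE}`, with the SETR config as the argument.)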
Did you make any modifications on the code or config? Did you understand what you have modified? No.
What dataset did you use? ADE20K.
Environment
```
sys.platform: linux
Python: 3.7.11 (default, Jul 27 2021, 14:32:16) [GCC 7.5.0]
CUDA available: True
GPU 0: NVIDIA RTX A6000
CUDA_HOME: /usr/local/cuda
NVCC: Build cuda_11.2.r11.2/compiler.29558016_0
GCC: gcc (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
PyTorch: 1.10.2
PyTorch compiling details: PyTorch built with:
TorchVision: 0.11.3
OpenCV: 4.5.5
MMCV: 1.4.4
MMCV Compiler: GCC 7.3
MMCV CUDA Compiler: 11.3
MMSegmentation: 0.21.1+b7a3f72
```
Error traceback
No error actually (due to the non-strict load_state_dict), but clearly the ViT model is not loading correctly.
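For context, a generic PyTorch illustration (not mmseg's code) of why the non-strict load stays silent: with strict=False, load_state_dict only returns the mismatched keys instead of raising.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)

# strict=False: mismatches are reported in the return value, not raised,
# which is why training starts even though almost nothing was loaded.
result = model.load_state_dict({'unexpected.weight': torch.zeros(4, 4)}, strict=False)
print(result.missing_keys)     # ['weight', 'bias']
print(result.unexpected_keys)  # ['unexpected.weight']
```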
Bug fix
see above.