microsoft / Swin-Transformer

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".
https://arxiv.org/abs/2103.14030
MIT License

IncompatibleKeys when finetuned SimMIM on SwinV2 #271

Open lifuguan opened 1 year ago

lifuguan commented 1 year ago

I tried to fine-tune a SimMIM pre-trained Swin-V2 model following get_started.md:

python -m torch.distributed.launch --nproc_per_node 16 main_simmim_ft.py \ 
--cfg configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml --batch-size 128 --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--output <output-directory> --tag <job-tag>]

The link of pre-trained weight is shown below: https://msravcghub.blob.core.windows.net/simmim-release/swinv2/pretrain/swinv2_base_22k_125k.pth

Problem report

Unfortunately, the logger reports many incompatible keys when loading the pre-trained weight file:


[2022-09-26 07:52:25 simmim_finetune](utils_simmim.py 112): INFO _IncompatibleKeys(missing_keys=['layers.0.blocks.0.attn.relative_coords_table', 'layers.0.blocks.0.attn.relative_position_index', 'layers.0.blocks.0.attn.cpb_mlp.0.weight', 'layers.0.blocks.0.attn.cpb_mlp.0.bias', 'layers.0.blocks.0.attn.cpb_mlp.2.weight', 'layers.0.blocks.1.attn_mask', 'layers.0.blocks.1.attn.relative_coords_table', 'layers.0.blocks.1.attn.relative_position_index', 'layers.0.blocks.1.attn.cpb_mlp.0.weight', 'layers.0.blocks.1.attn.cpb_mlp.0.bias', 'layers.0.blocks.1.attn.cpb_mlp.2.weight', 'layers.1.blocks.0.attn.relative_coords_table', 'layers.1.blocks.0.attn.relative_position_index', 'layers.1.blocks.0.attn.cpb_mlp.0.weight', 'layers.1.blocks.0.attn.cpb_mlp.0.bias', 'layers.1.blocks.0.attn.cpb_mlp.2.weight', 'layers.1.blocks.1.attn_mask', 'layers.1.blocks.1.attn.relative_coords_table', 'layers.1.blocks.1.attn.relative_position_index', 'layers.1.blocks.1.attn.cpb_mlp.0.weight', 'layers.1.blocks.1.attn.cpb_mlp.0.bias', 'layers.1.blocks.1.attn.cpb_mlp.2.weight', 'layers.2.blocks.0.attn.relative_coords_table', 'layers.2.blocks.0.attn.relative_position_index', 'layers.2.blocks.0.attn.cpb_mlp.0.weight', 'layers.2.blocks.0.attn.cpb_mlp.0.bias', 'layers.2.blocks.0.attn.cpb_mlp.2.weight', 'layers.2.blocks.1.attn.relative_coords_table', 'layers.2.blocks.1.attn.relative_position_index', 'layers.2.blocks.1.attn.cpb_mlp.0.weight', 'layers.2.blocks.1.attn.cpb_mlp.0.bias', 'layers.2.blocks.1.attn.cpb_mlp.2.weight', 'layers.2.blocks.2.attn.relative_coords_table', 'layers.2.blocks.2.attn.relative_position_index', 'layers.2.blocks.2.attn.cpb_mlp.0.weight', 'layers.2.blocks.2.attn.cpb_mlp.0.bias', 'layers.2.blocks.2.attn.cpb_mlp.2.weight', 'layers.2.blocks.3.attn.relative_coords_table', 'layers.2.blocks.3.attn.relative_position_index', 'layers.2.blocks.3.attn.cpb_mlp.0.weight', 'layers.2.blocks.3.attn.cpb_mlp.0.bias', 'layers.2.blocks.3.attn.cpb_mlp.2.weight', 'layers.2.blocks.4.attn.relative_coords_table', 'layers.2.blocks.4.attn.relative_position_index', 'layers.2.blocks.4.attn.cpb_mlp.0.weight', 'layers.2.blocks.4.attn.cpb_mlp.0.bias', 'layers.2.blocks.4.attn.cpb_mlp.2.weight', 'layers.2.blocks.5.attn.relative_coords_table', 'layers.2.blocks.5.attn.relative_position_index', 'layers.2.blocks.5.attn.cpb_mlp.0.weight', 'layers.2.blocks.5.attn.cpb_mlp.0.bias', 'layers.2.blocks.5.attn.cpb_mlp.2.weight', 'layers.2.blocks.6.attn.relative_coords_table', 'layers.2.blocks.6.attn.relative_position_index', 'layers.2.blocks.6.attn.cpb_mlp.0.weight', 'layers.2.blocks.6.attn.cpb_mlp.0.bias', 'layers.2.blocks.6.attn.cpb_mlp.2.weight', 'layers.2.blocks.7.attn.relative_coords_table', 'layers.2.blocks.7.attn.relative_position_index', 'layers.2.blocks.7.attn.cpb_mlp.0.weight', 'layers.2.blocks.7.attn.cpb_mlp.0.bias', 'layers.2.blocks.7.attn.cpb_mlp.2.weight', 'layers.2.blocks.8.attn.relative_coords_table', 'layers.2.blocks.8.attn.relative_position_index', 'layers.2.blocks.8.attn.cpb_mlp.0.weight', 'layers.2.blocks.8.attn.cpb_mlp.0.bias', 'layers.2.blocks.8.attn.cpb_mlp.2.weight', 'layers.2.blocks.9.attn.relative_coords_table', 'layers.2.blocks.9.attn.relative_position_index', 'layers.2.blocks.9.attn.cpb_mlp.0.weight', 'layers.2.blocks.9.attn.cpb_mlp.0.bias', 'layers.2.blocks.9.attn.cpb_mlp.2.weight', 'layers.2.blocks.10.attn.relative_coords_table', 'layers.2.blocks.10.attn.relative_position_index', 'layers.2.blocks.10.attn.cpb_mlp.0.weight', 'layers.2.blocks.10.attn.cpb_mlp.0.bias', 'layers.2.blocks.10.attn.cpb_mlp.2.weight', 
'layers.2.blocks.11.attn.relative_coords_table', 'layers.2.blocks.11.attn.relative_position_index', 'layers.2.blocks.11.attn.cpb_mlp.0.weight', 'layers.2.blocks.11.attn.cpb_mlp.0.bias', 'layers.2.blocks.11.attn.cpb_mlp.2.weight', 'layers.2.blocks.12.attn.relative_coords_table', 'layers.2.blocks.12.attn.relative_position_index', 'layers.2.blocks.12.attn.cpb_mlp.0.weight', 'layers.2.blocks.12.attn.cpb_mlp.0.bias', 'layers.2.blocks.12.attn.cpb_mlp.2.weight', 'layers.2.blocks.13.attn.relative_coords_table', 'layers.2.blocks.13.attn.relative_position_index', 'layers.2.blocks.13.attn.cpb_mlp.0.weight', 'layers.2.blocks.13.attn.cpb_mlp.0.bias', 'layers.2.blocks.13.attn.cpb_mlp.2.weight', 'layers.2.blocks.14.attn.relative_coords_table', 'layers.2.blocks.14.attn.relative_position_index', 'layers.2.blocks.14.attn.cpb_mlp.0.weight', 'layers.2.blocks.14.attn.cpb_mlp.0.bias', 'layers.2.blocks.14.attn.cpb_mlp.2.weight', 'layers.2.blocks.15.attn.relative_coords_table', 'layers.2.blocks.15.attn.relative_position_index', 'layers.2.blocks.15.attn.cpb_mlp.0.weight', 'layers.2.blocks.15.attn.cpb_mlp.0.bias', 'layers.2.blocks.15.attn.cpb_mlp.2.weight', 'layers.2.blocks.16.attn.relative_coords_table', 'layers.2.blocks.16.attn.relative_position_index', 'layers.2.blocks.16.attn.cpb_mlp.0.weight', 'layers.2.blocks.16.attn.cpb_mlp.0.bias', 'layers.2.blocks.16.attn.cpb_mlp.2.weight', 'layers.2.blocks.17.attn.relative_coords_table', 'layers.2.blocks.17.attn.relative_position_index', 'layers.2.blocks.17.attn.cpb_mlp.0.weight', 'layers.2.blocks.17.attn.cpb_mlp.0.bias', 'layers.2.blocks.17.attn.cpb_mlp.2.weight', 'layers.3.blocks.0.attn.relative_coords_table', 'layers.3.blocks.0.attn.relative_position_index', 'layers.3.blocks.0.attn.cpb_mlp.0.weight', 'layers.3.blocks.0.attn.cpb_mlp.0.bias', 'layers.3.blocks.0.attn.cpb_mlp.2.weight', 'layers.3.blocks.1.attn.relative_coords_table', 'layers.3.blocks.1.attn.relative_position_index', 'layers.3.blocks.1.attn.cpb_mlp.0.weight', 'layers.3.blocks.1.attn.cpb_mlp.0.bias', 'layers.3.blocks.1.attn.cpb_mlp.2.weight', 'head.weight', 'head.bias'], unexpected_keys=['mask_token', 'layers.0.blocks.0.attn.rpe_mlp.0.weight', 'layers.0.blocks.0.attn.rpe_mlp.0.bias', 'layers.0.blocks.0.attn.rpe_mlp.2.weight', 'layers.0.blocks.1.attn.rpe_mlp.0.weight', 'layers.0.blocks.1.attn.rpe_mlp.0.bias', 'layers.0.blocks.1.attn.rpe_mlp.2.weight', 'layers.1.blocks.0.attn.rpe_mlp.0.weight', 'layers.1.blocks.0.attn.rpe_mlp.0.bias', 'layers.1.blocks.0.attn.rpe_mlp.2.weight', 'layers.1.blocks.1.attn.rpe_mlp.0.weight', 'layers.1.blocks.1.attn.rpe_mlp.0.bias', 'layers.1.blocks.1.attn.rpe_mlp.2.weight', 'layers.2.blocks.0.attn.rpe_mlp.0.weight', 'layers.2.blocks.0.attn.rpe_mlp.0.bias', 'layers.2.blocks.0.attn.rpe_mlp.2.weight', 'layers.2.blocks.1.attn.rpe_mlp.0.weight', 'layers.2.blocks.1.attn.rpe_mlp.0.bias', 'layers.2.blocks.1.attn.rpe_mlp.2.weight', 'layers.2.blocks.2.attn.rpe_mlp.0.weight', 'layers.2.blocks.2.attn.rpe_mlp.0.bias', 'layers.2.blocks.2.attn.rpe_mlp.2.weight', 'layers.2.blocks.3.attn.rpe_mlp.0.weight', 'layers.2.blocks.3.attn.rpe_mlp.0.bias', 'layers.2.blocks.3.attn.rpe_mlp.2.weight', 'layers.2.blocks.4.attn.rpe_mlp.0.weight', 'layers.2.blocks.4.attn.rpe_mlp.0.bias', 'layers.2.blocks.4.attn.rpe_mlp.2.weight', 'layers.2.blocks.5.attn.rpe_mlp.0.weight', 'layers.2.blocks.5.attn.rpe_mlp.0.bias', 'layers.2.blocks.5.attn.rpe_mlp.2.weight', 'layers.2.blocks.6.attn.rpe_mlp.0.weight', 'layers.2.blocks.6.attn.rpe_mlp.0.bias', 'layers.2.blocks.6.attn.rpe_mlp.2.weight', 
'layers.2.blocks.7.attn.rpe_mlp.0.weight', 'layers.2.blocks.7.attn.rpe_mlp.0.bias', 'layers.2.blocks.7.attn.rpe_mlp.2.weight', 'layers.2.blocks.8.attn.rpe_mlp.0.weight', 'layers.2.blocks.8.attn.rpe_mlp.0.bias', 'layers.2.blocks.8.attn.rpe_mlp.2.weight', 'layers.2.blocks.9.attn.rpe_mlp.0.weight', 'layers.2.blocks.9.attn.rpe_mlp.0.bias', 'layers.2.blocks.9.attn.rpe_mlp.2.weight', 'layers.2.blocks.10.attn.rpe_mlp.0.weight', 'layers.2.blocks.10.attn.rpe_mlp.0.bias', 'layers.2.blocks.10.attn.rpe_mlp.2.weight', 'layers.2.blocks.11.attn.rpe_mlp.0.weight', 'layers.2.blocks.11.attn.rpe_mlp.0.bias', 'layers.2.blocks.11.attn.rpe_mlp.2.weight', 'layers.2.blocks.12.attn.rpe_mlp.0.weight', 'layers.2.blocks.12.attn.rpe_mlp.0.bias', 'layers.2.blocks.12.attn.rpe_mlp.2.weight', 'layers.2.blocks.13.attn.rpe_mlp.0.weight', 'layers.2.blocks.13.attn.rpe_mlp.0.bias', 'layers.2.blocks.13.attn.rpe_mlp.2.weight', 'layers.2.blocks.14.attn.rpe_mlp.0.weight', 'layers.2.blocks.14.attn.rpe_mlp.0.bias', 'layers.2.blocks.14.attn.rpe_mlp.2.weight', 'layers.2.blocks.15.attn.rpe_mlp.0.weight', 'layers.2.blocks.15.attn.rpe_mlp.0.bias', 'layers.2.blocks.15.attn.rpe_mlp.2.weight', 'layers.2.blocks.16.attn.rpe_mlp.0.weight', 'layers.2.blocks.16.attn.rpe_mlp.0.bias', 'layers.2.blocks.16.attn.rpe_mlp.2.weight', 'layers.2.blocks.17.attn.rpe_mlp.0.weight', 'layers.2.blocks.17.attn.rpe_mlp.0.bias', 'layers.2.blocks.17.attn.rpe_mlp.2.weight', 'layers.3.blocks.0.attn.rpe_mlp.0.weight', 'layers.3.blocks.0.attn.rpe_mlp.0.bias', 'layers.3.blocks.0.attn.rpe_mlp.2.weight', 'layers.3.blocks.1.attn.rpe_mlp.0.weight', 'layers.3.blocks.1.attn.rpe_mlp.0.bias', 'layers.3.blocks.1.attn.rpe_mlp.2.weight'])
[2022-09-26 07:52:25 simmim_finetune](utils_simmim.py 116): INFO >>>>>>>>>> loaded successfully 'model_zoo/swinv2_base_22k_simmim_125k.pth'

This suggests that the pre-trained file is not being loaded properly.
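For context, this _IncompatibleKeys report is what PyTorch's load_state_dict returns when called with strict=False: instead of raising an error, it collects the keys that could not be matched. A minimal, self-contained illustration (the toy module and key name are made up for the example):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
state = {"weight_old_name": torch.zeros(2, 4)}  # key that does not match the model

# With strict=False, load_state_dict returns an _IncompatibleKeys
# namedtuple instead of raising a RuntimeError.
msg = model.load_state_dict(state, strict=False)
print(msg.missing_keys)     # ['weight', 'bias']
print(msg.unexpected_keys)  # ['weight_old_name']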

Weight visualization

Furthermore, I double-checked the weight keys of the pre-trained file with the following code:

import torch

# Inspect the checkpoint keys; avoid shadowing the built-in `dict`.
ckpt = torch.load("model_zoo/swinv2_base_22k_simmim_125k.pth", map_location="cpu")
print(ckpt["model"].keys())

It shows that the keys are incompatible with the expected model: every key carries an unwanted encoder. prefix.

odict_keys(['encoder.mask_token', 'encoder.patch_embed.proj.weight', 'encoder.patch_embed.proj.bias', 'encoder.patch_embed.norm.weight', 'encoder.patch_embed.norm.bias', 'encoder.layers.0.blocks.0.norm1.weight', 'encoder.layers.0.blocks.0.norm1.bias', 'encoder.layers.0.blocks.0.attn.logit_scale', 'encoder.layers.0.blocks.0.attn.q_bias', 'encoder.layers.0.blocks.0.attn.v_bias', 'encoder.layers.0.blocks.0.attn.relative_coords_table', 'encoder.layers.0.blocks.0.attn.relative_position_index', 'encoder.layers.0.blocks.0.attn.rpe_mlp.0.weight', 'encoder.layers.0.blocks.0.attn.rpe_mlp.0.bias', 'encoder.layers.0.blocks.0.attn.rpe_mlp.2.weight', 'encoder.layers.0.blocks.0.attn.qkv.weight', 'encoder.layers.0.blocks.0.attn.proj.weight', 'encoder.layers.0.blocks.0.attn.proj.bias', 'encoder.layers.0.blocks.0.norm2.weight', 'encoder.layers.0.blocks.0.norm2.bias', 'encoder.layers.0.blocks.0.mlp.fc1.weight', 'encoder.layers.0.blocks.0.mlp.fc1.bias', 'encoder.layers.0.blocks.0.mlp.fc2.weight', 'encoder.layers.0.blocks.0.mlp.fc2.bias', 
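One possible workaround (a sketch only, not necessarily how the repo's loader handles it) is to strip the encoder. prefix that the SimMIM pre-training wrapper adds before loading the state dict into a plain SwinV2 backbone; the model variable here is assumed to be your SwinV2 instance:

import torch

ckpt = torch.load("model_zoo/swinv2_base_22k_simmim_125k.pth", map_location="cpu")
state = ckpt["model"]

# Drop the 'encoder.' prefix so the keys match a plain SwinV2 backbone;
# any key without the prefix (if present) is discarded here.
stripped = {k[len("encoder."):]: v for k, v in state.items()
            if k.startswith("encoder.")}

# msg = model.load_state_dict(stripped, strict=False)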

I would appreciate it if this could be fixed soon!

zdaxie commented 1 year ago

Hi @lifuguan,

Thanks a lot for your feedback!

We checked the code and found the problem, though some of the mismatched keys are intentional (see below). The actual bug was a naming mismatch between cpb_mlp in the released model and rpe_mlp in the provided checkpoints. We have fixed this issue, so you can pull the recent updates and check it.
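If you cannot pull the fix right away, the rename can also be applied on the checkpoint side; a minimal sketch inferred from the key names above (the output filename is illustrative):

import torch

ckpt = torch.load("model_zoo/swinv2_base_22k_simmim_125k.pth", map_location="cpu")

# Map the checkpoint's 'rpe_mlp' names onto the model's 'cpb_mlp' names.
ckpt["model"] = {k.replace("rpe_mlp", "cpb_mlp"): v
                 for k, v in ckpt["model"].items()}

torch.save(ckpt, "model_zoo/swinv2_base_22k_simmim_125k_renamed.pth")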

The missing relative_coords_table and relative_position_index, however, are intentional. Since the input resolution is 192 during pre-training and 224 during fine-tuning, we simply delete these two buffers when loading pre-trained models, to avoid loading parameters with the wrong shapes; the model then regenerates them at the correct resolution. You can refer to the following for more details:

https://github.com/microsoft/Swin-Transformer/blob/main/utils_simmim.py#L189
https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer_v2.py#L97
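In spirit, the deletion step amounts to something like the following sketch (variable names are illustrative; see the linked utils_simmim.py for the actual implementation):

import torch

ckpt = torch.load("model_zoo/swinv2_base_22k_simmim_125k.pth", map_location="cpu")
state = ckpt["model"]

# Remove resolution-dependent buffers so SwinV2 regenerates them at the
# fine-tuning resolution (224) rather than the pre-training one (192).
for k in [k for k in state
          if "relative_coords_table" in k or "relative_position_index" in k]:
    del state[k]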

So currently, after pulling the latest version of the repo, the IncompatibleKeys message you get when fine-tuning the provided pre-trained models should look like this:

[2022-09-29 16:01:49 simmim_finetune](utils_simmim.py 119): INFO _IncompatibleKeys(missing_keys=['layers.0.blocks.0.attn.relative_coords_table', 'layers.0.blocks.0.attn.relative_position_index', 'layers.0.blocks.1.attn_mask', 'layers.0.blocks.1.attn.relative_coords_table', 'layers.0.blocks.1.attn.relative_position_index', 'layers.1.blocks.0.attn.relative_coords_table', 'layers.1.blocks.0.attn.relative_position_index', 'layers.1.blocks.1.attn_mask', 'layers.1.blocks.1.attn.relative_coords_table', 'layers.1.blocks.1.attn.relative_position_index', 'layers.2.blocks.0.attn.relative_coords_table', 'layers.2.blocks.0.attn.relative_position_index', 'layers.2.blocks.1.attn.relative_coords_table', 'layers.2.blocks.1.attn.relative_position_index', 'layers.2.blocks.2.attn.relative_coords_table', 'layers.2.blocks.2.attn.relative_position_index', 'layers.2.blocks.3.attn.relative_coords_table', 'layers.2.blocks.3.attn.relative_position_index', 'layers.2.blocks.4.attn.relative_coords_table', 'layers.2.blocks.4.attn.relative_position_index', 'layers.2.blocks.5.attn.relative_coords_table', 'layers.2.blocks.5.attn.relative_position_index', 'layers.2.blocks.6.attn.relative_coords_table', 'layers.2.blocks.6.attn.relative_position_index', 'layers.2.blocks.7.attn.relative_coords_table', 'layers.2.blocks.7.attn.relative_position_index', 'layers.2.blocks.8.attn.relative_coords_table', 'layers.2.blocks.8.attn.relative_position_index', 'layers.2.blocks.9.attn.relative_coords_table', 'layers.2.blocks.9.attn.relative_position_index', 'layers.2.blocks.10.attn.relative_coords_table', 'layers.2.blocks.10.attn.relative_position_index', 'layers.2.blocks.11.attn.relative_coords_table', 'layers.2.blocks.11.attn.relative_position_index', 'layers.2.blocks.12.attn.relative_coords_table', 'layers.2.blocks.12.attn.relative_position_index', 'layers.2.blocks.13.attn.relative_coords_table', 'layers.2.blocks.13.attn.relative_position_index', 'layers.2.blocks.14.attn.relative_coords_table', 'layers.2.blocks.14.attn.relative_position_index', 'layers.2.blocks.15.attn.relative_coords_table', 'layers.2.blocks.15.attn.relative_position_index', 'layers.2.blocks.16.attn.relative_coords_table', 'layers.2.blocks.16.attn.relative_position_index', 'layers.2.blocks.17.attn.relative_coords_table', 'layers.2.blocks.17.attn.relative_position_index', 'layers.3.blocks.0.attn.relative_coords_table', 'layers.3.blocks.0.attn.relative_position_index', 'layers.3.blocks.1.attn.relative_coords_table', 'layers.3.blocks.1.attn.relative_position_index', 'head.weight', 'head.bias'], unexpected_keys=['mask_token'])

Hope this will solve your problem.

Rowan-L commented 1 year ago

First of all, thanks a lot for solving the problem of loading pre-trained models. However, a new problem has appeared: Swin-V2 does not seem to work when the tested images have different resolutions.