open-mmlab / mmpretrain

OpenMMLab Pre-training Toolbox and Benchmark
https://mmpretrain.readthedocs.io/en/latest/
Apache License 2.0

[Bug] Error in loading Swin pretrained weights, KeyError: 'backbone.stages.0.blocks.0.attn.w_msa.relative_position_bias_table' #794

Closed vothaianh1997 closed 2 years ago

vothaianh1997 commented 2 years ago

!python tools/test.py ./configs/swin_transformer/swin-tiny_cats-dogs.py work_dirs/swin-tiny_cats-dogs/latest.pth --metrics=accuracy --metric-options=topk=1

/usr/local/lib/python3.7/dist-packages/mmcv/cnn/bricks/transformer.py:33: UserWarning: Fail to import MultiScaleDeformableAttention from mmcv.ops.multi_scale_deform_attn, You should install mmcv-full if you need this module.
  warnings.warn('Fail to import MultiScaleDeformableAttention from '
/content/mmclassification/mmcls/utils/setup_env.py:33: UserWarning: Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
  f'Setting OMP_NUM_THREADS environment variable for each process '
/content/mmclassification/mmcls/utils/setup_env.py:43: UserWarning: Setting MKL_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
  f'Setting MKL_NUM_THREADS environment variable for each process '
load checkpoint from local path: work_dirs/swin-tiny_cats-dogs/latest.pth
Traceback (most recent call last):
  File "tools/test.py", line 243, in <module>
    main()
  File "tools/test.py", line 170, in main
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
  File "/usr/local/lib/python3.7/dist-packages/mmcv/runner/checkpoint.py", line 563, in load_checkpoint
    load_state_dict(model, state_dict, strict, logger)
  File "/usr/local/lib/python3.7/dist-packages/mmcv/runner/checkpoint.py", line 80, in load_state_dict
    load(module)
  File "/usr/local/lib/python3.7/dist-packages/mmcv/runner/checkpoint.py", line 78, in load
    load(child, prefix + name + '.')
  File "/usr/local/lib/python3.7/dist-packages/mmcv/runner/checkpoint.py", line 75, in load
    err_msg)
  File "/content/mmclassification/mmcls/models/backbones/swin_transformer.py", line 457, in _load_from_state_dict
    *args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1372, in _load_from_state_dict
    hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  File "/content/mmclassification/mmcls/models/backbones/swin_transformer.py", line 514, in _prepare_relative_position_bias_table
    relative_position_bias_table_current = state_dict_model[key]
KeyError: 'backbone.stages.0.blocks.0.attn.w_msa.relative_position_bias_table'

vothaianh1997 commented 2 years ago

this link https://drive.google.com/file/d/12Aa88pVrerVGUMOADfW_3cNGymFSXqd_/view?usp=sharing

tcexeexe commented 2 years ago

Hello. I have run into the same problem when using the Swin Transformer. The config file is ./configs/swin_transformer/swin-base_16xb64_in1k.py. Training works fine, but when I run inference with the command below, the error occurs:

python tools/test.py ./configs/swin_transformer/swin-base_16xb64_in1k.py ./work_dirs/swin-base_16xb64_in1k/epoch_1.pth --out "result_220422.json"

I am not using a pretrained model.

/home/heji/code/CVPR2022/mmclassification/mmcls/utils/setup_env.py:33: UserWarning: Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
  f'Setting OMP_NUM_THREADS environment variable for each process '
/home/heji/code/CVPR2022/mmclassification/mmcls/utils/setup_env.py:43: UserWarning: Setting MKL_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
  f'Setting MKL_NUM_THREADS environment variable for each process '
load checkpoint from local path: ./work_dirs/swin-base_16xb64_in1k/epoch_1.pth
Traceback (most recent call last):
  File "tools/test.py", line 243, in <module>
    main()
  File "tools/test.py", line 170, in main
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
  File "/home/heji/anaconda3/envs/openmmlab/lib/python3.7/site-packages/mmcv/runner/checkpoint.py", line 563, in load_checkpoint
    load_state_dict(model, state_dict, strict, logger)
  File "/home/heji/anaconda3/envs/openmmlab/lib/python3.7/site-packages/mmcv/runner/checkpoint.py", line 80, in load_state_dict
    load(module)
  File "/home/heji/anaconda3/envs/openmmlab/lib/python3.7/site-packages/mmcv/runner/checkpoint.py", line 78, in load
    load(child, prefix + name + '.')
  File "/home/heji/anaconda3/envs/openmmlab/lib/python3.7/site-packages/mmcv/runner/checkpoint.py", line 75, in load
    err_msg)
  File "/home/heji/code/CVPR2022/mmclassification/mmcls/models/backbones/swin_transformer.py", line 457, in _load_from_state_dict
    *args, **kwargs)
  File "/home/heji/anaconda3/envs/openmmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1127, in _load_from_state_dict
    hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  File "/home/heji/code/CVPR2022/mmclassification/mmcls/models/backbones/swin_transformer.py", line 514, in _prepare_relative_position_bias_table
    relative_position_bias_table_current = state_dict_model[key]
KeyError: 'backbone.stages.0.blocks.0.attn.w_msa.relative_position_bias_table'

My config file is:

_base_ = ['./pipelines/rand_aug.py']

# dataset settings
dataset_type = 'Filelist'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=512,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies={{_base_.rand_increasing_policies}},
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(
            pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
            interpolation='bicubic')),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=img_norm_cfg['mean'][::-1],
        fill_std=img_norm_cfg['std'][::-1]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        size=(512, -1),
        backend='pillow',
        interpolation='bicubic'),
    # dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
# heji
data = dict(
    samples_per_gpu=8,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        ann_file='../data/phase1/trainset_label.txt',
        data_prefix='../data/phase1/trainset/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file='../data/phase1/valset_label.txt',
        data_prefix='../data/phase1/valset/',
        pipeline=test_pipeline),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type=dataset_type,
        ann_file='../data/phase1/valset_label.txt',
        data_prefix='../data/phase1/valset/',
        pipeline=test_pipeline))

# fp16 = dict(loss_scale=768.)
evaluation = dict(interval=10, metric='accuracy')
# load_from = './weights/swin_base_224_b16x64_300e_imagenet_20210616_190742-93230b0d.pth'

Does anybody know what the problem is?
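Reading the traceback above, the line that fails is `state_dict_model[key]`, and the key it complains about starts with `backbone.`. One possible reading, which is not confirmed anywhere in this thread, is that the key carries a module prefix that the dictionary being indexed does not use. The sketch below reproduces that kind of mismatch with plain dictionaries; `lookup_with_prefix` and the sample values are hypothetical and are not part of mmcls.

```python
def lookup_with_prefix(state_dict_model, key, prefix):
    """Look up `key`, retrying with `prefix` stripped if the first try misses.

    Hypothetical helper for illustration only; it is not the mmcls code.
    """
    if key in state_dict_model:
        return state_dict_model[key]
    if prefix and key.startswith(prefix):
        stripped = key[len(prefix):]
        if stripped in state_dict_model:
            return state_dict_model[stripped]
    raise KeyError(key)


# Key as it appears in the error message (prefixed with the submodule name).
prefixed_key = "backbone.stages.0.blocks.0.attn.w_msa.relative_position_bias_table"

# A dictionary whose keys are relative to the backbone itself (assumed layout).
backbone_state = {
    "stages.0.blocks.0.attn.w_msa.relative_position_bias_table": "table",
}

try:
    backbone_state[prefixed_key]  # plain indexing misses, as in the traceback
except KeyError as err:
    print("plain lookup failed:", err)

print("prefix-aware lookup:", lookup_with_prefix(backbone_state, prefixed_key, "backbone."))
```

If this reading is right, the mismatch would not depend on the dataset or on whether pretrained weights are used, which would be consistent with both reports above.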

Ezra-Yu commented 2 years ago

Can you give me the version information? Try to run pip list | grep "mmcls\|torch\|mmcv".
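If `grep` is not available (for example on Windows), roughly the same information can be collected from Python. This is a minimal sketch: the `PACKAGES` tuple and the `installed_versions` helper are illustrative names, and on Python 3.7 (the version in the tracebacks above) it assumes the `importlib_metadata` backport is installed.

```python
try:
    from importlib import metadata  # standard library on Python 3.8+
except ImportError:
    import importlib_metadata as metadata  # backport, assumed installed on 3.7

# Distribution names of interest; adjust as needed.
PACKAGES = ("mmcls", "mmcv", "mmcv-full", "torch", "torchvision")


def installed_versions(names=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:12s} {version or 'not installed'}")
```

Unlike the `grep` pattern, this only reports the names listed in `PACKAGES`, so packages such as torchaudio have to be added explicitly.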

tcexeexe commented 2 years ago

> Can you give me the version information? Try to run pip list | grep "mmcls\|torch\|mmcv".

@Ezra-Yu
mmcls 0.22.1
mmcv-full 1.4.8
torch 1.8.2
torchaudio 0.8.2
torchvision 0.9.2

vothaianh1997 commented 2 years ago

> Can you give me the version information? Try to run pip list | grep "mmcls\|torch\|mmcv".

@Ezra-Yu
mmcls 0.22.1
mmcv-full 1.4.8
torch 1.8.2
torchaudio 0.8.2
torchvision 0.9.2

vothaianh1997 commented 2 years ago

Please help me if you can, @Ezra-Yu.

vothaianh1997 commented 2 years ago

Can you give me a link to a Colab notebook that runs the MMClassification Swin Transformer without errors? I need this code for my graduation essay. Please help me, @Ezra-Yu.

Ezra-Yu commented 2 years ago

@vothaianh1997 I will check if there is a bug in the code base.

For a quick fix, you can use a lower version of mmcls. Just run:

pip uninstall mmcls          # uninstall the current mmcls
pip install mmcls==0.22.0    # install a lower version of mmcls

With this approach, changes to the code under '$MMClassification/mmcls' will not take effect, but modifying the config files is fine.

Ezra-Yu commented 2 years ago

Here is an example of using Swin.

Ezra-Yu commented 2 years ago

We have fixed this bug in the latest version. Please update mmcls to >= v0.23.0.

I will close this issue.
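To summarise the version advice in this thread, an installed mmcls version string can be checked as in the rough sketch below. The function names and status labels are made up for illustration, and only the three versions actually mentioned above are classified; nothing is assumed about other releases.

```python
import re

FIXED_FROM = (0, 23, 0)          # "please update mmcls >=v0.23.0"
REPORTED_AFFECTED = (0, 22, 1)   # version reported by both users
SUGGESTED_FALLBACK = (0, 22, 0)  # quick-fix version suggested earlier


def parse_version(text):
    """Parse the leading 'X.Y.Z' (optionally prefixed with 'v') into a tuple."""
    match = re.match(r"v?(\d+)\.(\d+)\.(\d+)", text.strip())
    if match is None:
        raise ValueError(f"unrecognised version string: {text!r}")
    return tuple(int(part) for part in match.groups())


def thread_status(version_text):
    """Classify an mmcls version according to what this thread says about it."""
    version = parse_version(version_text)
    if version >= FIXED_FROM:
        return "fixed"
    if version == REPORTED_AFFECTED:
        return "reported affected"
    if version == SUGGESTED_FALLBACK:
        return "suggested workaround"
    return "not discussed"


if __name__ == "__main__":
    for candidate in ("0.22.0", "0.22.1", "v0.23.0"):
        print(candidate, "->", thread_status(candidate))
```

Comparing tuples of integers rather than raw strings avoids ordering mistakes such as "0.9.0" sorting after "0.23.0".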