open-mmlab / mmdetection

OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io
Apache License 2.0

SWIN Backbone: KeyError: 'stages.2.blocks.6.attn.w_msa.relative_position_bias_table' #6667

Closed: pfuerste closed this issue 3 years ago

pfuerste commented 3 years ago

I am trying to use Swin as a backbone for Deformable-DETR. I built a config from /swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco.py and my earlier DDETR configs, changing only the backbone part, but train.py throws the following error:

2021-12-02 10:42:01,371 - mmdet - INFO - Use load_from_http loader
Traceback (most recent call last):
  File "/home/fuerste/mmdetection/tools/train.py", line 207, in <module>
    main()
  File "/home/fuerste/mmdetection/tools/train.py", line 181, in main
    model.init_weights()
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/mmcv/runner/base_module.py", line 117, in init_weights
    m.init_weights()
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/mmdet/models/backbones/swin.py", line 728, in init_weights
    table_current = self.state_dict()[table_key]
KeyError: 'stages.2.blocks.6.attn.w_msa.relative_position_bias_table'

I have no problem running DDETR with this config when I keep the default backbone (ResNet-50). I think something in my config is not right; could you tell me what I am missing? Here it is:

epochs = 20
lr = 2e-5
samples_per_gpu = 4
workers_per_gpu = 1

_base_ = '/home/fuerste/mmdetection/configs/deformable_detr/deformable_detr_r50_16x2_50e_coco.py'

model = dict(
    type='DeformableDETR',
    backbone=dict(
        _delete_=True,
        type='SwinTransformer',
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        with_cp=False,
        convert_weights=True,
        init_cfg=dict(type='Pretrained', checkpoint='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth')),
    bbox_head=dict(
        type='DeformableDETRHead',
        num_classes=1))

dataset_type = 'CocoDataset'
classes = (
    'animal',
)

# https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py
# By default, the step policy multiplies every lr by 0.1 at each step
lr_config = dict(policy='step',
                 step=[int(0.75 * epochs)],
                 warmup='linear',
                 warmup_iters=500,
                 warmup_ratio=0.001)
workflow = [('train', 1), ('val', 1)]
runner = dict(type='EpochBasedRunner', max_epochs=epochs)
optimizer = dict(
    type='AdamW',
    lr=lr,
    weight_decay=0.0001,
    paramwise_cfg=dict(
        custom_keys={
            'backbone': dict(lr_mult=0.1),
            'sampling_offsets': dict(lr_mult=0.1),
            'reference_points': dict(lr_mult=0.1),
            #'absolute_pos_embed': dict(decay_mult=0.),
            #'relative_position_bias_table': dict(decay_mult=0.),
            #'norm': dict(decay_mult=0.)
        }))

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

data = dict(
    samples_per_gpu=samples_per_gpu,
    workers_per_gpu=workers_per_gpu,
    train=dict(
        img_prefix='/home/datasets/camera_traps/Caltech-Camera-Traps/CCT20-benchmark/eccv_18_all_images_sm',
        classes=classes,
        filter_empty_gt=False,
        ann_file='/home/fuerste/thesis_root/data/cct20/annotations/one_cat/train_annotations.json'),
    val=dict(
        img_prefix='/home/datasets/camera_traps/Caltech-Camera-Traps/CCT20-benchmark/eccv_18_all_images_sm',
        classes=classes,
        filter_empty_gt=False,
        separate_eval=True,
        ann_file=['/home/fuerste/thesis_root/data/cct20/annotations/one_cat/cis_val_annotations.json',
                  '/home/fuerste/thesis_root/data/cct20/annotations/one_cat/trans_val_annotations.json']),
    test=dict(
        img_prefix='/home/datasets/camera_traps/Caltech-Camera-Traps/CCT20-benchmark/eccv_18_all_images_sm',
        classes=classes,
        filter_empty_gt=False,
        separate_eval=True,
        ann_file=['/home/fuerste/thesis_root/data/cct20/annotations/one_cat/cis_test_annotations.json',
                  '/home/fuerste/thesis_root/data/cct20/annotations/one_cat/trans_test_annotations.json']))

log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(
            type='WandbLoggerHook',
            init_kwargs=dict(
                project='ddetr_test_lr',
                config={},
                tags=["gaia5", "ddetr", "single_class", "default_scale", f"batch_size {samples_per_gpu}", f"epochs {epochs}", f"lr {lr}"]
            ))])

load_from = '/home/fuerste/mmdetection/checkpoints/deformable_detr_r50_16x2_50e_coco_20210419_220030-a12b9512.pth'
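The backbone override above relies on `_delete_=True`. When a config inherits from `_base_`, nested dicts are merged key by key, so without it the ResNet-specific fields of the base backbone would be carried over into the Swin backbone's arguments. Below is a simplified sketch of that merge rule for illustration only; `merge_cfg` is a hypothetical helper, not mmcv's API, and the abridged base model fields are assumptions rather than a copy of the real base config.

```python
import copy


def merge_cfg(child, base):
    """Simplified sketch of mmcv-style config inheritance (not mmcv's own code).

    Child keys override base keys; nested dicts are merged recursively unless
    the child dict carries `_delete_=True`, in which case it replaces the
    base dict outright.
    """
    merged = copy.deepcopy(base)
    for key, value in child.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            value = dict(value)  # shallow copy so pop() leaves the caller's dict intact
            if value.pop('_delete_', False):
                merged[key] = value  # replace the base entry wholesale
            else:
                merged[key] = merge_cfg(value, merged[key])
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Abridged, hypothetical base model; the ResNet fields are assumptions.
base_model = dict(
    type='DeformableDETR',
    backbone=dict(type='ResNet', depth=50, out_indices=(1, 2, 3)),
    neck=dict(type='ChannelMapper', in_channels=[512, 1024, 2048]))

swin_backbone = dict(type='SwinTransformer', embed_dims=96)

leaky = merge_cfg(dict(backbone=swin_backbone), base_model)
clean = merge_cfg(dict(backbone=dict(_delete_=True, **swin_backbone)), base_model)

print(leaky['backbone'])  # {'type': 'SwinTransformer', 'depth': 50, 'out_indices': (1, 2, 3), 'embed_dims': 96}
print(clean['backbone'])  # {'type': 'SwinTransformer', 'embed_dims': 96}
```

Note that the neck is merged normally in both cases, which is why the neck still expects ResNet channel widths unless it is overridden too.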
RangiLyu commented 3 years ago

It seems that you were trying to load swin_small pretrained weights into swin_tiny.
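To see where the key in the KeyError comes from: Swin-T uses depths [2, 2, 6, 2] (as in the config above), while Swin-S uses [2, 2, 18, 2], so stage index 2 of the small checkpoint has blocks 6 to 17 that do not exist in the tiny model. The sketch below assumes the converted key pattern shown in the traceback (after `convert_weights=True`) and lists the bias-table keys that a Swin-S checkpoint has and a Swin-T model lacks; `bias_table_keys` is a hypothetical helper.

```python
def bias_table_keys(depths):
    """All relative_position_bias_table keys for a Swin model with these depths.

    Follows the converted key pattern from the traceback:
    stages.<stage>.blocks.<block>.attn.w_msa.relative_position_bias_table
    """
    return {
        f'stages.{stage}.blocks.{block}.attn.w_msa.relative_position_bias_table'
        for stage, depth in enumerate(depths)
        for block in range(depth)
    }


SWIN_T_DEPTHS = [2, 2, 6, 2]   # depths in the config above (Swin-T)
SWIN_S_DEPTHS = [2, 2, 18, 2]  # Swin-S, i.e. swin_small_patch4_window7_224.pth

tiny_keys = bias_table_keys(SWIN_T_DEPTHS)
small_keys = bias_table_keys(SWIN_S_DEPTHS)

# Keys present in the Swin-S checkpoint but absent from the Swin-T model,
# sorted numerically by block index rather than as strings.
missing_in_tiny = sorted(small_keys - tiny_keys, key=lambda k: int(k.split('.')[3]))

print(len(missing_in_tiny))  # 12
print(missing_in_tiny[0])    # stages.2.blocks.6.attn.w_msa.relative_position_bias_table
```

The first missing key is exactly the one in the traceback. With a checkpoint whose depths match the model, this difference is empty.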

pfuerste commented 3 years ago

True, I got that mixed up. Thanks!

pfuerste commented 3 years ago

Okay, I switched the pretrained weights from small to tiny, but now I get another error. It seems that SWIN's output has a different dimensionality than ResNet's, but I do not know how to find out how big either one is or should be. I guess I have to change something like num_heads or out_indices in the backbone, or something in the channel_mapper?

Here is the error:

Traceback (most recent call last):
  File "/home/fuerste/mmdetection/tools/train.py", line 207, in <module>
    main()
  File "/home/fuerste/mmdetection/tools/train.py", line 196, in main
    train_detector(
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/mmdet/apis/train.py", line 174, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/mmcv/runner/epoch_based_runner.py", line 50, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/mmcv/runner/epoch_based_runner.py", line 29, in run_iter
    outputs = self.model.train_step(data_batch, self.optimizer,
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/mmcv/parallel/data_parallel.py", line 67, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/mmdet/models/detectors/base.py", line 238, in train_step
    losses = self(**data)
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/mmcv/runner/fp16_utils.py", line 98, in new_func
    return old_func(*args, **kwargs)
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/mmdet/models/detectors/base.py", line 172, in forward
    return self.forward_train(img, img_metas, **kwargs)
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/mmdet/models/detectors/single_stage.py", line 82, in forward_train
    x = self.extract_feat(img)
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/mmdet/models/detectors/single_stage.py", line 45, in extract_feat
    x = self.neck(x)
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/fuerste/miniconda3/envs/openmmlab/lib/python3.9/site-packages/mmdet/models/necks/channel_mapper.py", line 92, in forward
    assert len(inputs) == len(self.convs)
AssertionError
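The assertion is raised in the neck. ChannelMapper builds one input convolution per entry of `in_channels`, and the neck inherited from the R50 Deformable DETR base config (assumed here to be `in_channels=[512, 1024, 2048]`, matching ResNet stages 1 to 3) expects three feature maps, while the Swin backbone above returns four because of `out_indices=(0, 1, 2, 3)`. A minimal sketch of that count check follows; `swin_out_channels` and `check_neck_inputs` are hypothetical helpers, and the channel widths assume Swin doubles its width at every patch-merging step.

```python
def swin_out_channels(embed_dims, out_indices, num_stages=4):
    """Channel width of each selected Swin stage.

    Patch merging doubles the width after every stage, so Swin-T
    (embed_dims=96) yields 96, 192, 384 and 768 channels.
    """
    return [embed_dims * 2 ** i for i in range(num_stages) if i in out_indices]


def check_neck_inputs(backbone_channels, neck_in_channels):
    """Mirror ChannelMapper's `len(inputs) == len(self.convs)` check."""
    if len(backbone_channels) != len(neck_in_channels):
        raise AssertionError(
            f'backbone returns {len(backbone_channels)} feature maps, '
            f'neck has {len(neck_in_channels)} input convs')


swin_t = swin_out_channels(96, out_indices=(0, 1, 2, 3))
r50_neck = [512, 1024, 2048]  # assumed in_channels inherited from the R50 DDETR base config

try:
    check_neck_inputs(swin_t, r50_neck)
except AssertionError as err:
    print('mismatch:', err)  # mismatch: backbone returns 4 feature maps, neck has 3 input convs

check_neck_inputs(swin_t, [96, 192, 384, 768])  # passes
```

This only mirrors the count check. If the counts matched but the widths did not, the error would instead surface inside the neck's convolutions, so `in_channels` has to match both.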
pfuerste commented 3 years ago

Never mind, searching the configs helped. I set neck=dict(in_channels=[96, 192, 384, 768]).
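For reference, the fix can be written as the keys to change inside the existing `model = dict(...)`. Option A is what this thread settled on. Option B is an untested alternative that keeps three inputs like the ResNet-50 setup and relies on ChannelMapper adding the fourth level through an extra convolution (assuming the base config's `num_outs=4`). The variable names are hypothetical and only serve to compare the two options side by side.

```python
# Keys to change inside the existing `model = dict(...)`; everything else stays as is.

# Option A, the fix from this thread: keep all four Swin-T stages.
option_a = dict(
    backbone=dict(out_indices=(0, 1, 2, 3)),
    neck=dict(in_channels=[96, 192, 384, 768]))

# Option B (not tried in this thread): mirror the ResNet-50 setup, which feeds
# three stages to ChannelMapper and lets it add the fourth level via an extra conv.
option_b = dict(
    backbone=dict(out_indices=(1, 2, 3)),
    neck=dict(in_channels=[192, 384, 768]))

for option in (option_a, option_b):
    # ChannelMapper needs exactly one in_channels entry per backbone output.
    assert len(option['backbone']['out_indices']) == len(option['neck']['in_channels'])
```

Option A feeds the highest-resolution Swin stage into the neck, which costs more memory than Option B.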