SwinTransformer / Swin-Transformer-Semantic-Segmentation

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" on Semantic Segmentation.
https://arxiv.org/abs/2103.14030
Apache License 2.0

Inplace Operation: Gradient Fails #60

Open azhangmn opened 2 years ago

azhangmn commented 2 years ago

Hi there, I am trying to train swin_upernet on a custom dataset. My config file is as follows:

_base_ = [
    'configs/_base_/models/upernet_swin.py', 'configs/_base_/datasets/xxx.py',
    'configs/_base_/default_runtime.py', 'configs/_base_/schedules/schedule_160k.py'
]
model = dict(
    backbone=dict(
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=8,
        ape=False,
        drop_path_rate=0.3,
        patch_norm=True,
        use_checkpoint=False
    ),
    decode_head=dict(
        in_channels=[128, 256, 512, 1024],
        num_classes=150
    ),
    auxiliary_head=dict(
        in_channels=512,
        num_classes=150
    ))

# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
                                                 'relative_position_bias_table': dict(decay_mult=0.),
                                                 'norm': dict(decay_mult=0.)}))

lr_config = dict(_delete_=True, policy='poly',
                 warmup='linear',
                 warmup_iters=1500,
                 warmup_ratio=1e-6,
                 power=1.0, min_lr=0.0, by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data=dict(samples_per_gpu=2)

However, when I try to run training, I get the error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2, 512, 27, 27]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

This is with PyTorch autograd anomaly detection enabled, which I turned on to try to locate the source of the error.

Has anyone dealt with this before? How can I fix it?
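For readers who want to turn on the same diagnostic, here is a minimal sketch of running a backward pass under autograd anomaly detection. The helper name `backward_with_anomaly_detection` and the toy loss are hypothetical and only for illustration; they are not part of mmseg or of the poster's training script.

```python
import torch


def backward_with_anomaly_detection(loss_fn):
    """Compute a loss and backpropagate with anomaly detection enabled.

    `loss_fn` is a hypothetical zero-argument callable returning a scalar loss.
    """
    # detect_anomaly records the forward stack trace of every op, so a failing
    # backward pass also reports where the offending tensor was created.
    with torch.autograd.detect_anomaly():
        loss = loss_fn()
        loss.backward()
    return loss.detach()


# Toy usage: a healthy graph, so backward completes without raising.
w = torch.ones(3, requires_grad=True)
loss = backward_with_anomaly_detection(lambda: (w * 2.0).sum())
```

Anomaly detection slows training noticeably, so it is usually enabled only while debugging.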

impiga commented 2 years ago

Hi, @azhangmn. It seems that an inplace ReLU operation is causing this error.

Could you check where this operation is used? According to the error message, the affected tensor is the output of this ReLU layer and has a shape of [2, 512, 27, 27].

azhangmn commented 2 years ago

Hi @impiga, it seems that every part of the model except the Swin Transformer backbone uses inplace ReLU. However, if I have interpreted the code correctly, these settings come from the defaults in mmcv. Do you have any suggestions for fixing this problem? Thanks!

azhangmn commented 2 years ago

I found the solution, which was merged into the master branch of mmseg: https://github.com/open-mmlab/mmsegmentation/pull/1103. It turns out that the uper_head.py code in mmseg/models/decode_heads throws this error at line 104 because of the += operation, an in-place addition that overwrites a tensor autograd still needs for the backward pass. Similarly, the code in fpn.py under mmseg/models/necks throws the error because of the += operations on lines 178 and 181. Please update this code.
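The failure mode described above can be reproduced outside mmseg. The following is a minimal sketch, not the actual uper_head.py or fpn.py code: the function names `fuse_inplace`, `fuse_out_of_place`, and `try_backward` and the tensor shapes are made up for illustration. It shows that `+=` on a ReLU output breaks the backward pass, while an out-of-place addition (a common remedy for this error) does not.

```python
import torch


def fuse_inplace(lateral, top_down):
    # Problematic pattern: `+=` overwrites `lateral` in place, but ReLU
    # saved that same tensor to compute its gradient later.
    lateral += top_down
    return lateral


def fuse_out_of_place(lateral, top_down):
    # Allocates a new tensor, so the ReLU output saved by autograd is untouched.
    return lateral + top_down


def try_backward(fuse):
    """Return (succeeded, error_message) for one forward/backward pass."""
    torch.manual_seed(0)
    x = torch.randn(2, 4, 8, 8, requires_grad=True)
    lateral = torch.relu(x)  # autograd keeps this output for ReluBackward0
    top_down = torch.randn(2, 4, 8, 8)
    fused = fuse(lateral, top_down)
    try:
        fused.sum().backward()
    except RuntimeError as err:
        return False, str(err)
    return True, ""


ok_inplace, message = try_backward(fuse_inplace)
ok_out_of_place, _ = try_backward(fuse_out_of_place)
print(ok_inplace, ok_out_of_place)
```

Note that the in-place variant fails even though `torch.relu` here is itself not in-place: the problem is the later modification of the ReLU output, which matches the "output 0 of ReluBackward0" wording in the error message.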