open-mmlab / mmsegmentation

OpenMMLab Semantic Segmentation Toolbox and Benchmark.
https://mmsegmentation.readthedocs.io/en/main/
Apache License 2.0
8.36k stars · 2.63k forks

Question of the swin backbones in mmseg and mmdet #1877

Closed Li-Qingyun closed 2 years ago

Li-Qingyun commented 2 years ago

I conducted a simple exploratory experiment to check whether the backbones in mmseg and mmdet behave identically under the same backbone configuration, in order to better understand the various tasks in the OpenMMLab series.

I found that backbones such as Swin and ResNet are implemented with varying degrees of similarity in mmseg and mmdet. Only a few lines differ between the Swin files in mmdet and mmseg (the `convert_weights` argument). It therefore seems the backbones should produce the same results.

I compared ResNet-50 and Swin-T (patch 4, window 7) under the same state_dict, in the same environment. The ResNet-50 outputs are aligned, but the Swin outputs are not.

I am submitting this issue to ask whether my usage or understanding is wrong, and to discuss the differences among the Swin implementations in mmdet, mmseg, and other downstream OpenMMLab tasks.

I hope OpenMMLab can provide a more unified interface for backbones shared across multiple tasks in the future, which would be easier to understand and would reduce code redundancy.

Code

import torch

from mmdet.apis import set_random_seed as mmdet_set_random_seed
from mmseg.apis import set_random_seed as mmseg_set_random_seed
from mmdet.models.builder import build_backbone as mmdet_build_backbone
from mmseg.models.builder import build_backbone as mmseg_build_backbone

USE_GPU = False

r50_config = dict(
    type='ResNet',
    depth=50,
    num_stages=4,
    out_indices=(0, 1, 2, 3),
    frozen_stages=1,
    norm_cfg=dict(type='BN', requires_grad=False),
    norm_eval=True,
    style='pytorch',
    init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))

swin_t_config = dict(
    type='SwinTransformer',
    embed_dims=96,
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
    window_size=7,
    mlp_ratio=4,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=0.,
    attn_drop_rate=0.,
    drop_path_rate=0.2,
    patch_norm=True,
    out_indices=(0, 1, 2, 3),
    with_cp=False,
    convert_weights=True,
    init_cfg=dict(type='Pretrained', checkpoint=
    'https://github.com/SwinTransformer/storage/releases/download/v1.0.0'
    '/swin_tiny_patch4_window7_224.pth'))

def compare(a, b, message: str):
    # Accepts either state_dicts (via .values()) or tuples/lists of tensors.
    a = a.values() if hasattr(a, 'values') else a
    b = b.values() if hasattr(b, 'values') else b
    for i, (a_, b_) in enumerate(zip(a, b)):
        if torch.equal(a_, b_):
            comparison_result = 'equal'
        else:
            # Not equal: report the summed element-wise difference.
            comparison_result = torch.sum(a_ - b_).item()
        print(message.format(layer=i, result=comparison_result))

if __name__ == '__main__':
    backbone = torch.nn.ModuleDict()

    mmdet_set_random_seed(2022, True)
    backbone['mmdet_r50'] = mmdet_build_backbone(r50_config)

    mmseg_set_random_seed(2022, True)
    backbone['mmseg_r50'] = mmseg_build_backbone(r50_config)

    mmdet_set_random_seed(2022, True)
    backbone['mmdet_swin_t'] = mmdet_build_backbone(swin_t_config)

    # mmseg's SwinTransformer has no `convert_weights` argument, so drop it
    swin_t_config.pop('convert_weights')
    mmseg_set_random_seed(2022, True)
    backbone['mmseg_swin_t'] = mmseg_build_backbone(swin_t_config)

    sample = torch.rand(1, 3, 800, 800)
    if USE_GPU:
        sample, backbone = sample.cuda(), backbone.cuda()

    mmdet_r50_output = backbone['mmdet_r50'](sample)
    mmseg_r50_output = backbone['mmseg_r50'](sample)
    compare(mmdet_r50_output, mmseg_r50_output,
            'ResNet backbone output[{layer}]: {result}')

    mmdet_swin_output = backbone['mmdet_swin_t'](sample)
    mmseg_swin_output = backbone['mmseg_swin_t'](sample)
    compare(mmdet_swin_output, mmseg_swin_output,
            'Swin-T backbone output[{layer}]: {result}')

    backbone_sd = {k: m.state_dict() for k, m in backbone.items()}
    compare(*[sd for k, sd in backbone_sd.items() if 'r50' in k],
            'ResNet backbone state_dict: {result}')
    compare(*[sd for k, sd in backbone_sd.items() if 'swin_t' in k],
            'Swin-T backbone state_dict: {result}')

Environment

Python 3.8.13 | torch 1.11.0+cu113 | torchvision 0.12.0
mmcv-full 1.5.3 | mmdet 2.25.0 | mmsegmentation 0.27.0
GPU: RTX 3070

Results

USE_GPU = False

ResNet backbone output[0]: equal
ResNet backbone output[1]: equal
ResNet backbone output[2]: equal
ResNet backbone output[3]: equal
Swin-T backbone output[0]: equal
Swin-T backbone output[1]: equal
Swin-T backbone output[2]: -0.00146484375
Swin-T backbone output[3]: 0.000244140625
ResNet backbone state_dict: equal
......
Swin-T backbone state_dict: equal
......


USE_GPU = True

ResNet backbone output[0]: equal
ResNet backbone output[1]: equal
ResNet backbone output[2]: equal
ResNet backbone output[3]: equal
Swin-T backbone output[0]: equal
Swin-T backbone output[1]: 0.0009765625
Swin-T backbone output[2]: 0.0009765625
Swin-T backbone output[3]: 0.0
ResNet backbone state_dict: equal
......
Swin-T backbone state_dict: equal
......


MengzhangLI commented 2 years ago

I am not sure whether the mmdet and mmseg Swin backbones are strictly identical. Due to a lack of development resources, those refactoring tasks were not carried out last year, but I am confident that at least the models are aligned with the original paper, as are the other models. The Swin backbones in mmdet and mmseg should each be strictly aligned with the official Swin object detection and semantic segmentation repos. A potential cause is that certain parameters in the mmdet and mmseg config files differ, but I have not checked.

However, supporting backbones in a unified way is being considered going forward; that is why, in 2022, ConvNeXt was implemented as an MMClassification backbone and reused in mmdet and mmseg, as will other potential models such as PoolFormer (CVPR 2022). The older models have not been refactored, however, due to the same lack of human resources.
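For reference, that unified pattern looks roughly like the ConvNeXt configs in mmdet, which pull the backbone class from MMClassification through `custom_imports` and the `mmcls.` registry scope. The fragment below is an illustrative sketch based on the published configs; the field values and the checkpoint path are placeholders, not authoritative:

```python
# Illustrative config fragment: reuse the MMClassification ConvNeXt
# backbone inside an mmdet (or mmseg) config.
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False)

model = dict(
    backbone=dict(
        type='mmcls.ConvNeXt',   # 'mmcls.' scope resolves the class in mmcls
        arch='tiny',
        out_indices=[0, 1, 2, 3],
        drop_path_rate=0.4,
        layer_scale_init_value=1.0,
        gap_before_final_norm=False,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='path/to/convnext-tiny_pretrained.pth',  # placeholder
            prefix='backbone.')))
```

With a single backbone implementation in MMClassification, mmdet and mmseg no longer need to carry near-duplicate copies that can drift apart.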

Beyond opening feature-request issues, PRs that refactor the older models into MMClassification are even more appreciated: OpenMMLab is supported and shared by the entire community, and the maintainers are already very busy with their regular work. Related PRs from the community are really needed.

Li-Qingyun commented 2 years ago

@MengzhangLI Hi~ Thanks for your quick response.

I have solved it myself: the mismatch was caused by the randomness of the dropout operations in the compared modules. When the models are in eval mode, or the seed is reset before each forward pass, the outputs are aligned. It was my fault for submitting the issue without further verification.
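The effect is easy to reproduce with a plain `torch.nn.Dropout` layer, independent of mmdet/mmseg (a minimal sketch; Swin's stochastic-depth drop-path layers draw from the same RNG during training and are likewise disabled in eval mode):

```python
import torch

drop = torch.nn.Dropout(p=0.5)
x = torch.ones(1, 256)

# Train mode: each forward draws a fresh random mask,
# so two passes over the same input almost surely differ.
drop.train()
y1, y2 = drop(x), drop(x)
print(torch.equal(y1, y2))  # almost surely False

# Resetting the seed before each forward replays the same mask.
torch.manual_seed(2022)
y1 = drop(x)
torch.manual_seed(2022)
y2 = drop(x)
print(torch.equal(y1, y2))  # True

# Eval mode disables dropout entirely, so outputs are deterministic.
drop.eval()
print(torch.equal(drop(x), drop(x)))  # True
```

This is why the state_dicts compared as equal while the training-mode forward outputs did not: the weights were identical, but each module consumed the global RNG independently.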

Beyond opening feature-request issues, PRs that refactor the older models into MMClassification are even more appreciated: OpenMMLab is supported and shared by the entire community, and the maintainers are already very busy with their regular work. Related PRs from the community are really needed.

I'm currently working on my first PR and look forward to being a long-term contributor to the OpenMMLab family of open-source frameworks. After finishing some urgent matters (completing the current PR and my academic requirements on campus), I will try to contribute myself and encourage people around me to do the same.

Thank you very much for the efficient and easy-to-use mmsegmentation framework.

The final exploratory code

import torch

from mmdet.apis import set_random_seed as mmdet_set_random_seed
from mmseg.apis import set_random_seed as mmseg_set_random_seed
from mmdet.models.builder import build_backbone as mmdet_build_backbone
from mmseg.models.builder import build_backbone as mmseg_build_backbone

IMG_SIZE = [(1, 3, 224, 224), (1, 3, 800, 800)][0]
USE_GPU = True
CONVERT_WEIGHT = False
SEED = 2022
DETERMINISTIC = False
EVAL_MODE = False
INIT_SET_SEED = True
FORWARD_SET_SEED = True

r50_config = dict(
    type='ResNet',
    depth=50,
    num_stages=4,
    out_indices=(0, 1, 2, 3),
    frozen_stages=1,
    norm_cfg=dict(type='BN', requires_grad=False),
    norm_eval=True,
    style='pytorch',
    init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))

swin_t_config = dict(
    mmdet=dict(
        type='SwinTransformer',
        pretrain_img_size=224,
        embed_dims=96,
        patch_size=4,
        window_size=7,
        mlp_ratio=4,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        strides=(4, 2, 2, 2),
        out_indices=(0, 1, 2, 3),
        drop_path_rate=0.2,  # difference
        convert_weights=CONVERT_WEIGHT,  # difference
        init_cfg=dict(type='Pretrained', checkpoint=
        'https://github.com/SwinTransformer/storage/releases/download/v1.0.0'
        '/swin_tiny_patch4_window7_224.pth')),
    mmseg=dict(
        type='SwinTransformer',
        pretrain_img_size=224,
        embed_dims=96,
        patch_size=4,
        window_size=7,
        mlp_ratio=4,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        strides=(4, 2, 2, 2),
        out_indices=(0, 1, 2, 3),
        drop_path_rate=0.3,
        init_cfg=dict(type='Pretrained', checkpoint=
        'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/'
        'swin_tiny_patch4_window7_224_20220317-1cdeb081.pth')))

def compare(a, b, message: str):
    # Accepts either state_dicts (via .values()) or tuples/lists of tensors.
    a = a.values() if hasattr(a, 'values') else a
    b = b.values() if hasattr(b, 'values') else b
    for i, (a_, b_) in enumerate(zip(a, b)):
        if torch.equal(a_, b_):
            comparison_result = 'equal'
        else:
            # Not equal: report the summed element-wise difference.
            comparison_result = torch.sum(a_ - b_).item()
        print(message.format(layer=i, result=comparison_result))

if __name__ == '__main__':
    backbone = torch.nn.ModuleDict()

    if INIT_SET_SEED:
        mmdet_set_random_seed(SEED, DETERMINISTIC)
    backbone['mmdet_r50'] = mmdet_build_backbone(r50_config)

    if INIT_SET_SEED:
        mmseg_set_random_seed(SEED, DETERMINISTIC)
    backbone['mmseg_r50'] = mmseg_build_backbone(r50_config)

    if INIT_SET_SEED:
        mmdet_set_random_seed(SEED, DETERMINISTIC)
    backbone['mmdet_swin_t'] = mmdet_build_backbone(swin_t_config['mmdet'])

    # mmseg's SwinTransformer has no `convert_weights` argument, so drop it
    swin_t_config['mmdet'].pop('convert_weights')
    if INIT_SET_SEED:
        mmseg_set_random_seed(SEED, DETERMINISTIC)
    backbone['mmseg_swin_t'] = mmseg_build_backbone(swin_t_config['mmdet'])

    if EVAL_MODE:
        backbone.eval()

    sample = torch.rand(*IMG_SIZE)
    if USE_GPU:
        sample, backbone = sample.cuda(), backbone.cuda()

    if FORWARD_SET_SEED:
        mmdet_set_random_seed(SEED, DETERMINISTIC)
    mmdet_r50_output = backbone['mmdet_r50'](sample)
    if FORWARD_SET_SEED:
        mmseg_set_random_seed(SEED, DETERMINISTIC)
    mmseg_r50_output = backbone['mmseg_r50'](sample)
    compare(mmdet_r50_output, mmseg_r50_output,
            'ResNet backbone output[{layer}]: {result}')

    if FORWARD_SET_SEED:
        mmdet_set_random_seed(SEED, DETERMINISTIC)
    mmdet_swin_output = backbone['mmdet_swin_t'](sample)
    if FORWARD_SET_SEED:
        mmseg_set_random_seed(SEED, DETERMINISTIC)
    mmseg_swin_output = backbone['mmseg_swin_t'](sample)
    compare(mmdet_swin_output, mmseg_swin_output,
            'Swin-T backbone output[{layer}]: {result}')

    backbone_sd = {k: m.state_dict() for k, m in backbone.items()}
    compare(*[sd for k, sd in backbone_sd.items() if 'r50' in k],
            'ResNet backbone state_dict: {result}')
    compare(*[sd for k, sd in backbone_sd.items() if 'swin_t' in k],
            'Swin-T backbone state_dict: {result}')