microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License
1.37k stars 93 forks source link

Does mup support Swin Transformer v2 model? #21

Open shiyf129 opened 2 years ago

shiyf129 commented 2 years ago

Hi, we are trying to use the mup tool to tune a Swin Transformer v2 model. I modified the Swin Transformer v2 code to support muP and ran the "save base shape" and "coordinate check" steps. The coordinate-check results show that the model does not meet muP's requirements.

Does mup support the Swin Transformer v2 model?

In "swin_transformer_v2.py", I made the following changes (Swin Transformer v2 does not use 1/sqrt(d) attention scaling, so I left the attention scaling unchanged):

  1. replaced the output layer `nn.Linear` with `MuReadout`
  2. replaced the constant-std normal init with muP's `normal_` init (from `mup.init`)
# Excerpt from the tail of the model's `__init__` in swin_transformer_v2.py:
# final norm, pooling, classifier head, and weight initialization.
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
# Original SP head, kept for reference:
# self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
### muP: replace nn.Linear with MuReadout
# (num_classes <= 0 still yields an identity head)
self.head = MuReadout(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

# At this point no parameter carries `infshape` yet (set_base_shapes has not
# run), so `_init_weights` takes its non-muP trunc_normal_ fallback here.
# For muP it must be re-applied after `set_base_shapes(model, base)`.
# Because `_init_respostnorm` runs after `_init_weights` here, it presumably
# has to be re-run in the same order after that re-init too (TODO confirm).
self.apply(self._init_weights)
for bly in self.layers:
    bly._init_respostnorm()
def _init_weights(self, m, readout_zero_init=False, query_zero_init=False):
    """Initialize a single submodule; meant to be passed to ``nn.Module.apply``.

    muP: the constant-std normal init is replaced by ``normal_`` from
    ``mup.init`` whenever a weight carries ``infshape``. Because this method is
    first called from ``__init__``, before ``infshape`` is set, it must be
    re-applied with ``model.apply(model._init_weights)`` after calling
    ``set_base_shapes(model, base)``.

    Args:
        m: the submodule currently visited by ``apply``.
        readout_zero_init: if True, zero-initialize ``MuReadout`` weights.
        query_zero_init: accepted for signature compatibility; unused here.
    """
    if isinstance(m, nn.Linear):
        # MuReadout is a Linear subclass, so it is handled in this branch too.
        if isinstance(m, MuReadout) and readout_zero_init:
            m.weight.data.zero_()
        elif hasattr(m.weight, 'infshape'):
            # muP path: `infshape` is attached by set_base_shapes.
            normal_(m.weight, mean=0.0, std=.02)
        else:
            # Non-muP fallback; also the path taken during __init__.
            trunc_normal_(m.weight, std=.02)
        # Zero the bias for every Linear, whichever weight init ran above.
        # Nesting this under the fallback branch left biases untouched on
        # muP re-initialization and on readout zero-init.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        # Non-affine LayerNorms have weight/bias set to None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)

In Swin Transformer's "main.py", I added "save base shape" and "coordinate check" functionality.

def main(config, args):
    """Build data loaders and the model, then attach muP base shapes.

    With ``args.save_base_shapes`` set, write base shapes to that file and exit.
    The delta model is built with every scaled embed dim doubled. Otherwise,
    attach base shapes loaded from ``args.load_base_shapes``, or the model's own
    shapes (SP) when none are given.
    """
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    logger.info(str(model))

    ### muP
    if args.save_base_shapes:
        print(f'saving base shapes at {args.save_base_shapes}')
        base_shapes = get_shapes(model)
        delta_config = copy.deepcopy(config)
        delta_config.defrost()
        # Only the dimension(s) being scaled may differ between base and delta.
        delta_config.MODEL.SWINV2.EMBED_DIM *= 2  # Modify SwinV2 embed dim
        delta_config.MODEL.SWIN.EMBED_DIM *= 2  # Modify Swin embed dim
        # delta_config.MODEL.SWIN_MOE.EMBED_DIM *= 2  # Modify Swin_moe embed dim
        delta_config.MODEL.SWIN_MLP.EMBED_DIM *= 2  # Modify Swin_mlp embed dim
        delta_shapes = get_shapes(build_model(delta_config))
        make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes)
        print('done and exit')
        import sys
        sys.exit()

    if args.load_base_shapes:
        print(f'loading base shapes from {args.load_base_shapes}')
        set_base_shapes(model, args.load_base_shapes)
        # `_init_weights` first ran in __init__, before `infshape` existed, so it
        # used the width-independent fallback init. Re-apply it now so the
        # muP-aware `normal_` scaling actually takes effect; then re-run the
        # post-norm init that originally followed it.
        init_fn = getattr(model, '_init_weights', None)
        if init_fn is not None:
            model.apply(init_fn)
            for layer in getattr(model, 'layers', ()):
                if hasattr(layer, '_init_respostnorm'):
                    layer._init_respostnorm()
        print('done')
    else:
        print('using own shapes')
        set_base_shapes(model, None)
        print('done')
### muP
def coord_check(mup, config, lr, optimizer, nsteps, nseeds, args, plotdir='', legend=False,
                widths=(12, 24, 48, 96, 192)):
    """Run a coordinate check over Swin models of several embed dims and plot it.

    Args:
        mup: True for muP (base shapes from ``args.load_base_shapes``),
            False for standard parametrization.
        config: base config; each model copies it and overrides the embed dims.
        lr: learning rate passed to ``get_coord_data``.
        optimizer: optimizer name; any ``'mu'`` substring is stripped first.
        nsteps: number of training steps recorded per model.
        nseeds: number of random seeds.
        args: parsed CLI args; ``args.load_base_shapes`` is required when ``mup``.
        plotdir: directory the plot is saved into.
        legend: forwarded to ``plot_coord_data``.
        widths: embed dims to sweep. The default keeps the original values,
            which are small for a Transformer; larger sweeps are recommended.
            Each width presumably must be divisible by every stage's head
            count (TODO confirm against the config).

    Returns:
        Whatever ``plot_coord_data`` returns.
    """
    # `os` was previously only bound by the __main__ block; import it here so
    # this function also works when imported.
    import os

    _, _, data_loader_train, _, _ = build_loader(config)

    def gen(w, standparam=False):
        def f():
            delta_config = copy.deepcopy(config)
            delta_config.defrost()
            delta_config.MODEL.SWINV2.EMBED_DIM = w  # Modify SwinV2 embed dim
            delta_config.MODEL.SWIN.EMBED_DIM = w  # Modify Swin embed dim
            # delta_config.MODEL.SWIN_MOE.EMBED_DIM = w  # Modify Swin_moe embed dim
            delta_config.MODEL.SWIN_MLP.EMBED_DIM = w  # Modify Swin_mlp embed dim
            model = build_model(delta_config)

            if standparam:
                set_base_shapes(model, None)
            else:
                assert args.load_base_shapes, 'load_base_shapes needs to be nonempty'
                set_base_shapes(model, args.load_base_shapes)
                # `_init_weights` first ran in __init__, before `infshape`
                # existed. Without re-applying it, hidden weights keep a
                # width-independent std and muP looks like SP in the check.
                init_fn = getattr(model, '_init_weights', None)
                if init_fn is not None:
                    model.apply(init_fn)
                    for layer in getattr(model, 'layers', ()):
                        if hasattr(layer, '_init_respostnorm'):
                            layer._init_respostnorm()
            return model
        return f

    optimizer = optimizer.replace('mu', '')
    models = {w: gen(w, standparam=not mup) for w in widths}

    df = get_coord_data(models, data_loader_train, mup=mup, lr=lr, optimizer=optimizer, flatten_output=True,
                        nseeds=nseeds, nsteps=nsteps, lossfn='xent')

    prm = 'muP' if mup else 'SP'
    return plot_coord_data(df, legend=legend,
                           save_to=os.path.join(plotdir, f'{prm.lower()}_trsfmr_{optimizer}_coord.png'),
                           suptitle=f'{prm} Transformer {optimizer} lr={lr} nseeds={nseeds}',
                           face_color='xkcd:light grey' if not mup else None)
if __name__ == '__main__':
    args, config = parse_option()

    # ... (remaining setup from the original Swin main.py is elided here) ...

    ### muP
    if args.coord_check:
        import os
        import sys

        print('testing parametrization')
        plotdir = 'coord_checks'
        os.makedirs(plotdir, exist_ok=True)
        # Run muP first, then SP, with identical settings so the plots compare.
        for use_mup in (True, False):
            coord_check(mup=use_mup, config=config, lr=0.0001, optimizer='adamw',
                        nsteps=args.coord_check_nsteps, nseeds=args.coord_check_nseeds, args=args,
                        plotdir=plotdir, legend=False)
        sys.exit()

    main(config, args)

The coordinate-check results show only a small difference between muP and SP. Sorry, I can't upload pictures. Could you please help us check whether mup supports the Swin Transformer v2 model, or whether there is some other cause? Thanks a lot.

edwardjhu commented 2 years ago

Hi!

The snippets you included seem reasonable, except that the tested widths seem small if they are the d_model of a Transformer.

Can you try larger widths and attach the coord check plots?

QiyaoWei commented 1 year ago

[Attached images: sp_swin and mup_swin coordinate-check plots]

@shiyf129 I also think the snippets look reasonable. I have run coord checks on Swin as well and attached the plots here. Echoing Edward's suggestion, the widths tested are typically 256, 512, 1024, and 2048. Have you tried larger widths? Could you attach your coord check plots?