berniwal / swin-transformer-pytorch

Implementation of the Swin Transformer in PyTorch.
https://arxiv.org/pdf/2103.14030.pdf
MIT License

Training advice with swin_transformer - initialization with GELU, etc. #3

Open lessw2020 opened 3 years ago

lessw2020 commented 3 years ago

Hi, I've been setting up swin_transformer but am having a hard time getting it to actually train.
One immediate issue, I figured, is the lack of weight init, so I'm using the truncated normal init setup from rwightman/pytorch that he used in his ViT implementation, since that also uses GELU. Regardless, I'm not able to get it to learn at the moment, even after testing a range of learning rates.

So I'm wondering: has anyone found starting hyperparams and/or an init method that gets it up and training?
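For reference, a minimal sketch of a ViT-style init of the kind described above, meant to be applied with model.apply. It assumes the common timm ViT recipe (truncated normal with std 0.02 for Linear weights, zero biases, LayerNorm weight 1 and bias 0); the function name vit_style_init is hypothetical, and this is not the exact code from rwightman's repo.

```python
import torch.nn as nn


def vit_style_init(module: nn.Module) -> None:
    """Hypothetical ViT-style init for one module; use via model.apply(...).

    Assumed recipe: truncated normal (std=0.02) for Linear weights,
    zero biases, and LayerNorm weight=1 / bias=0.
    """
    if isinstance(module, nn.Linear):
        # Default truncation bounds of trunc_normal_ are a=-2.0, b=2.0.
        nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
```

Usage would be model.apply(vit_style_init) right after constructing the model, since apply visits every submodule recursively.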

lessw2020 commented 3 years ago

After some fiddling, I've finally got one up and training (lr = 3e-3, ViT-style weight init, AdamW optimizer)! Whether this is optimal is hard to say, but compared to the earlier flatlines it's a big improvement. [Image: swin_training]

habibian commented 3 years ago

Cool! Please keep us posted on whether you can replicate the results!

lessw2020 commented 3 years ago

I had overlooked that they add a global average pooling layer, which is a key difference, and that they had published their training params (AdamW, 1e-3 learning rate, etc.). Will add and test: [Image: ic_swin_transformer]

lessw2020 commented 3 years ago

Never mind, global average pooling is already in the implementation, in this line: [Image: swin_gap]
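For illustration, global average pooling over a (batch, channels, height, width) feature map is simply a mean over the two spatial dimensions, which is what the x.mean(dim=[2, 3]) line quoted later in this thread does. The sketch below uses a hypothetical random tensor in place of the last stage's output (assuming that output is laid out as batch, channels, height, width) and checks that the mean matches nn.AdaptiveAvgPool2d(1) followed by a flatten.

```python
import torch
import torch.nn as nn

# Hypothetical final-stage feature map: (batch, channels, height, width).
x = torch.randn(2, 768, 7, 7)

# Global average pooling written as a mean over the spatial dimensions.
gap_mean = x.mean(dim=[2, 3])                     # shape (2, 768)

# The same operation written with a pooling layer, for comparison.
gap_pool = nn.AdaptiveAvgPool2d(1)(x).flatten(1)  # shape (2, 768)
```

Either form yields one feature vector per image, which is then fed to the classification head.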

lessw2020 commented 3 years ago

OK, after adding in their warmup I'm seeing pretty good results (slow but steady, which is typical of transformers). I'm adding gradient clipping now as a final test. Here are the latest curves on a small subset of dog breeds from ImageNet: [Image: swin_training_100e]
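A minimal plain-Python sketch of a per-epoch learning-rate schedule with linear warmup followed by cosine decay, the combination the Swin paper describes. The function name warmup_cosine_lr, the exact warmup shape (ramping linearly up to base_lr over warmup_epochs) and the optional min_lr floor are assumptions for illustration.

```python
import math


def warmup_cosine_lr(epoch: int, base_lr: float, warmup_epochs: int,
                     total_epochs: int, min_lr: float = 0.0) -> float:
    """Hypothetical schedule: linear warmup, then half-cosine decay to min_lr."""
    if warmup_epochs > 0 and epoch < warmup_epochs:
        # Linear ramp: base_lr / warmup_epochs at epoch 0, base_lr at the end.
        return base_lr * (epoch + 1) / warmup_epochs
    # Fraction of the post-warmup phase completed, from 0.0 to 1.0.
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

One way to use it in PyTorch is torch.optim.lr_scheduler.LambdaLR: pass base_lr=1.0 so the function returns a multiplier of the optimizer's initial lr, and call scheduler.step() once per epoch.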

lessw2020 commented 3 years ago

I tested with both gradient clipping as in the paper and with adaptive gradient clipping.
Results were nearly identical in validation loss. Technically, hard clipping at 1.0 as in the paper was the winner, but since I only ran each setup once, the difference is within run-to-run noise. As a final test, I'm running with the new FB AI MADGRAD optimizer, which seems to outperform on transformers.
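To make the two clipping styles concrete, here is a plain-Python sketch on flat lists of floats. Hard clipping rescales all gradients so their combined L2 norm is at most max_norm (1.0 in the paper); adaptive gradient clipping (AGC, from the NFNets work) instead limits each gradient relative to the norm of its own weights. Assumptions: the AGC here is a simplified per-tensor variant (the original applies it unit-wise), and clip_factor=0.01 and eps=1e-3 are illustrative defaults.

```python
import math
from typing import List

Vec = List[float]


def clip_by_global_norm(grads: List[Vec], max_norm: float = 1.0) -> List[Vec]:
    """Hard clipping: scale all grads so their joint L2 norm is <= max_norm."""
    total = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = min(1.0, max_norm / (total + 1e-6))
    return [[g * scale for g in grad] for grad in grads]


def adaptive_clip(grads: List[Vec], params: List[Vec],
                  clip_factor: float = 0.01, eps: float = 1e-3) -> List[Vec]:
    """Simplified AGC: keep ||g|| / max(||w||, eps) <= clip_factor per tensor."""
    out = []
    for grad, param in zip(grads, params):
        g_norm = math.sqrt(sum(g * g for g in grad))
        w_norm = max(math.sqrt(sum(w * w for w in param)), eps)
        limit = clip_factor * w_norm
        if g_norm > limit:
            # Rescale this gradient so its norm equals the allowed limit.
            grad = [g * limit / g_norm for g in grad]
        out.append(list(grad))
    return out
```

In PyTorch, the hard-clipping case is what torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) does when called between loss.backward() and optimizer.step().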

YuhsiHu commented 3 years ago

Hi, I want to use Swin Transformer to replace a feature pyramid network. How should I modify the code?

lessw2020 commented 3 years ago

So madgrad blew away all my previous results: nearly an 18% improvement in validation accuracy for the same limited run time (22 epochs). My friend also tested it on tabular data and had similar results, with madgrad beating his previous benchmarks. Their weight decay is not implemented AdamW style, though, so I've adjusted it and am testing that now to see if it helps at all.
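To show what "AdamW style" weight decay means, here is a sketch of a single Adam step on one scalar weight, with the decay either folded into the gradient (L2 style) or applied directly to the weight (decoupled, AdamW style). It uses Adam rather than MADGRAD purely for brevity; MADGRAD's own update rule is different and is not reproduced here. The function adam_first_step is hypothetical and only covers step t=1 with zero-initialized moments.

```python
import math


def adam_first_step(p: float, grad: float, lr: float = 1e-3, wd: float = 0.05,
                    decoupled: bool = True, beta1: float = 0.9,
                    beta2: float = 0.999, eps: float = 1e-8) -> float:
    """Hypothetical helper: one Adam step at t=1 on a scalar weight p."""
    if decoupled:
        # AdamW style: shrink the weight directly, outside the adaptive update.
        p = p * (1 - lr * wd)
    else:
        # L2 style: the decay term joins the gradient, so it is also
        # divided by the adaptive denominator below.
        grad = grad + wd * p
    m = (1 - beta1) * grad         # first moment after one step
    v = (1 - beta2) * grad * grad  # second moment after one step
    m_hat = m / (1 - beta1)        # bias correction for t=1
    v_hat = v / (1 - beta2)
    return p - lr * m_hat / (math.sqrt(v_hat) + eps)
```

With a zero gradient, the decoupled version shrinks the weight by exactly lr * wd * p, while the L2 version moves it by roughly lr regardless of wd, because the adaptive denominator normalizes the decay term away. That is the usual argument for decoupling weight decay in adaptive optimizers.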

habibian commented 3 years ago

Exciting! Thanks for keeping us posted ...

lessw2020 commented 3 years ago

So far, running madgrad with AdamW-style decay, at a decay value similar to what I used for AdamW, has given the best results (slightly better accuracy and loss than no weight decay or minimal weight decay).

yueming-zhang commented 3 years ago

Thanks for sharing the experiments. I also tried a simple classification task but couldn't get it to work. I would appreciate any advice:

I noticed the avg pooling comment; it seems "x = x.mean(dim=[2, 3])" in the forward pass already does that, so all I did was a simple training run:

import torch.nn as nn
import torch.optim as optim
model_ft = swin_l(hidden_dim=192, layers=(2, 2, 18, 2), heads=(6, 12, 24, 48), channels=3, num_classes=2, head_dim=32, window_size=7, downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model_ft.parameters(), lr=3e-3)

[Image: training curves]

lessw2020 commented 3 years ago

Hi @yueming-zhang, did you use any kind of LR warmup and schedule? I used a linear warmup, as the Swin authors did per their paper, followed by a cosine-type decay. That might dampen the wild swings you are showing in the first 15 epochs. Also, if you go more than, say, 10 epochs with a flat val loss, you should stop training and save your GPU time :) It's unlikely to get better, so it's faster to stop and reset params. Secondly, I would recommend trying madgrad instead of AdamW. I just posted madgrad with AdamW-style decay here: https://github.com/lessw2020/Best-Deep-Learning-Optimizers/tree/master/madgrad but you could start with weight_decay=0 to keep it simple. Just see if that helps. So my first two recommendations would be:
1 - do a warmup and then either a cosine or step decay for the LR. Just running a flat LR probably won't go well.
2 - maybe try madgrad, mostly because my val accuracy jumped 18%+ from that change alone.
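The "stop after about 10 epochs of flat val loss" rule of thumb above can be written as a small patience counter. A minimal sketch; the class name EarlyStopping and the min_delta threshold are assumptions, not part of this repo or of madgrad.

```python
class EarlyStopping:
    """Hypothetical helper: stop when val loss hasn't improved for `patience` epochs.

    An epoch counts as an improvement only if the loss drops below the best
    value seen so far by more than `min_delta`.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

In a training loop, call stopper.step(val_loss) after each validation pass and break out of the epoch loop when it returns True.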

yueming-zhang commented 3 years ago

Appreciate your response @lessw2020. I tested AdamW with my custom scheduler and got the following results (note the very small LR). There seems to be a negative correlation between LR and accuracy. I briefly tested madgrad, but the result was not ideal.

[Images: results]

jm-R152 commented 3 years ago

@lessw2020 Could you share how you implemented the above in code?

jm-R152 commented 3 years ago

@yueming-zhang I ran into the same situation as above. Could you tell me how you solved it, and whether you used any pre-trained weights?

yueming-zhang commented 3 years ago

Hi jm-R152, all I did was use a cosine LR schedule starting from a very low rate. I didn't use pre-trained weights.

tanveer-hussain commented 3 years ago

Thank you @lessw2020, the AdamW optimizer has a very good effect on training. Do you have any further suggestions for improvements?