google-research / big_vision

Official codebase used to develop Vision Transformer, SigLIP, MLP-Mixer, LiT and more.
Apache License 2.0

bfloat16 Training #41

Open LeoXinhaoLee opened 12 months ago

LeoXinhaoLee commented 12 months ago

Thank you for releasing code for these inspiring works!

I tried using bfloat16 for the model parameters, and manually converted images and labels from float32 to bfloat16 before feeding them in for training, but training slowed down by roughly 3x. Performance also became noticeably worse. Am I using bfloat16 the wrong way?
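
Roughly, this is what I am doing (a minimal sketch; `params` and `batch` are placeholders for my actual model state and input pipeline):

```python
import jax
import jax.numpy as jnp

# Cast all float32 model parameters to bfloat16.
params_bf16 = jax.tree_util.tree_map(
    lambda p: p.astype(jnp.bfloat16) if p.dtype == jnp.float32 else p,
    params)

# Cast images and labels to bfloat16 before each training step.
images = batch['image'].astype(jnp.bfloat16)
labels = batch['labels'].astype(jnp.bfloat16)
```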

Thank you very much for your help.

andsteing commented 11 months ago

Do you mean a 3x slower step time, or 3x slower to reach the target accuracy? The former would be unexpected, but I couldn't say why without examining the training with a profiler.

In general, you cannot expect the same performance when going from float32 to bfloat16. With ViTs we found that the first Adam momentum can safely be kept in bfloat16 (example config), but the second moment and the model weights need to be kept in float32.
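
For illustration, a hedged sketch of that setup with optax (not the exact big_vision config; `mu_dtype` is optax's option for the first-moment dtype, and the learning rate here is arbitrary):

```python
import jax.numpy as jnp
import optax

# Keep the first Adam moment (momentum) in bfloat16 via `mu_dtype`;
# the second moment and the model weights stay in float32.
tx = optax.adam(learning_rate=1e-3, b1=0.9, b2=0.999,
                mu_dtype=jnp.bfloat16)
```

The weights themselves are created and updated in float32; only the first-moment state is stored in lower precision in this sketch.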