young-geng / EasyLM

Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, finetuning, evaluating, and serving LLMs in JAX/Flax.
Apache License 2.0

What is the full batch size if mesh_dim is set to 1,1,-1, on TPUv3-8? #94

Open TPFRL opened 10 months ago

TPFRL commented 10 months ago

Hi, thanks for this amazing repo. I was wondering how I should set the batch size to get a desired full (global) batch size.

For example, if I set train_dataset.huggingface_dataset.batch_size to 1 on a TPU v3-8, what is the full batch size for mesh_dim 1,1,-1 / 1,-1,1 / -1,1,1? Is it 8 in all cases, or 1?

Thanks!

young-geng commented 10 months ago

Different mesh dims correspond to different sharding strategies. They do not define the batch size themselves, but they do impose constraints on which batch sizes are possible.
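For intuition, here is a minimal sketch (not EasyLM's actual trainer code) assuming EasyLM's (dp, fsdp, mp) mesh axis convention, where the batch dimension is sharded over the dp and fsdp axes and -1 fills in the remaining devices. The constraint is that the global batch size must be divisible by dp * fsdp:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as PS

def make_mesh(mesh_dim):
    # '-1' fills in the remaining devices, mirroring the mesh_dim flag.
    dims = np.array(mesh_dim)
    dims[dims == -1] = jax.device_count() // np.prod(dims[dims != -1])
    devices = np.array(jax.devices()).reshape(tuple(dims))
    return Mesh(devices, axis_names=('dp', 'fsdp', 'mp'))

# On a v3-8 (8 devices):
mesh = make_mesh([1, 1, -1])   # dp=1, fsdp=1, mp=8
with mesh:
    # batch_size=1 works: the batch axis is split dp*fsdp = 1 way.
    batch = jax.device_put(
        np.zeros((1, 2048), dtype=np.int32),
        NamedSharding(mesh, PS(('dp', 'fsdp'), None)))

mesh = make_mesh([-1, 1, 1])   # dp=8, fsdp=1, mp=1
with mesh:
    # Here the batch axis is split dp*fsdp = 8 ways, so the global
    # batch size must be a multiple of 8; batch_size=1 would fail.
    batch = jax.device_put(
        np.zeros((8, 2048), dtype=np.int32),
        NamedSharding(mesh, PS(('dp', 'fsdp'), None)))
```

So with mesh_dim 1,1,-1 or 1,-1,1 a batch size of 1 is valid, while -1,1,1 requires the batch size to be a multiple of 8; in all cases the batch size you configure is the global batch size.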

jcole75 commented 9 months ago


Did you get this to run on a v3? I always seem to get HLO out-of-memory errors.

young-geng commented 9 months ago

A single v3-8 has only 128GB of HBM in total (16GB per core across 8 cores), which might not be sufficient for training a 7B model.
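A rough back-of-envelope estimate shows why (assuming full-precision Adam training, not EasyLM's exact memory layout):

```python
# Memory needed just for training state of a 7B model with Adam in fp32
# (an illustration, not a precise accounting of EasyLM's memory use).
params = 7e9
bytes_weights   = params * 4   # fp32 parameters
bytes_grads     = params * 4   # fp32 gradients
bytes_optimizer = params * 8   # Adam first + second moments (fp32 each)
total_gb = (bytes_weights + bytes_grads + bytes_optimizer) / 1e9
print(f"{total_gb:.0f} GB before activations")  # ~112 GB vs 128 GB on a v3-8
```

That leaves almost no headroom for activations, so HLO out-of-memory errors are expected at this scale without lower-precision state or a larger TPU slice.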