Open TPFRL opened 10 months ago
Hi, thanks for this amazing repo. I was wondering how I should set the batch size to get a desired full batch size.

For example, if I set train_dataset.huggingface_dataset.batch_size to 1 on a TPU v3-8, what is the full batch size given mesh_dim 1,1,-1 / 1,-1,1 / -1,1,1? Are all of them 8, or 1?

Thanks!

Different mesh dims correspond to different sharding strategies. While they do not define a batch size, they do impose certain constraints on the possible batch sizes:

- 1,1,-1 corresponds to tensor parallelism only, so you can use any batch size you want.
- 1,-1,1 corresponds to full FSDP, which means your batch size needs to be a multiple of the number of devices (8 here).
- -1,1,1 corresponds to full DP, which also means your batch size needs to be a multiple of the number of devices (8 here).
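To make the constraint concrete, here is a minimal sketch (not the repo's actual code; the axis names "dp"/"fsdp"/"mp", the make_mesh helper, and the numpy-style -1 reshape convention are assumptions) of how a mesh_dim like 1,-1,1 maps onto a JAX device mesh, and why the batch dimension must divide evenly across the data-parallel axes:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical helper: resolve a mesh_dim tuple (one entry may be -1,
# meaning "all remaining devices", like numpy's reshape) into a Mesh.
def make_mesh(mesh_dim):
    devices = np.asarray(jax.devices())  # 8 devices on a v3-8
    mesh_shape = np.zeros(len(devices)).reshape(mesh_dim).shape
    return Mesh(devices.reshape(mesh_shape), axis_names=("dp", "fsdp", "mp"))

mesh = make_mesh((1, -1, 1))  # full FSDP -> mesh shape (1, 8, 1)
data_axes = mesh.shape["dp"] * mesh.shape["fsdp"]
print(f"batch size must be a multiple of {data_axes}")

# The batch is sharded along the combined (dp, fsdp) axes, so its
# leading dimension has to split evenly across those 8 devices.
batch = np.zeros((8, 2048), dtype=np.float32)  # batch_size=8 works; 1 would not
sharded = jax.device_put(batch, NamedSharding(mesh, PartitionSpec(("dp", "fsdp"))))
print(sharded.sharding)
```

To try this without a TPU, you can run with XLA_FLAGS=--xla_force_host_platform_device_count=8 to simulate 8 CPU devices.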
Did you get this to run with a v3? I seem to always get HLO out of memory errors.
A single v3-8 only has 128GB of memory in total, which might not be sufficient for training a 7B model.
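As a rough back-of-envelope check (assuming fp32 weights, gradients, and Adam optimizer state, and ignoring activations entirely):

```python
# Rough memory estimate for training a 7B-parameter model with Adam.
params = 7e9
per_param = 4 + 4 + 4 + 4  # fp32 weights + grads + Adam first/second moments
print(f"{params * per_param / 2**30:.0f} GiB")  # ~104 GiB before activations
```

That already consumes most of the 128 GB on a v3-8, so once activations are added, out-of-memory errors are expected unless you use lower precision, gradient checkpointing, or a smaller model.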