young-geng / EasyLM

Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, finetuning, evaluating and serving LLMs in JAX/Flax.
Apache License 2.0

What is the logic behind the partitions? #65

Closed by gianlucadetommaso 1 year ago.

gianlucadetommaso commented 1 year ago

I have been looking into the partition rules you use for llama, gptj and roberta.

Would you mind explaining the idea behind these specific partitioning schemes? Is there a particular heuristic you are following? For example, I have noticed that you often swap the "mp" and "fsdp" axes for consecutive layers. Is this to minimize communication costs, or is there another reason?
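The pattern described in the question can be illustrated with a small rule table that maps parameter paths to `PartitionSpec`s via regular expressions, similar in spirit to the rule tables being asked about. The parameter paths, the `PARTITION_RULES` list and the `match_rule` helper below are hypothetical illustrations, not EasyLM's actual API; they only show a first kernel sharded as ("fsdp", "mp") followed by a second kernel sharded as ("mp", "fsdp").

```python
import re

from jax.sharding import PartitionSpec as PS

# Hypothetical rule table: (regex over the parameter path, PartitionSpec).
# Kernels are assumed to be stored as (in_features, out_features).
PARTITION_RULES = [
    # Projections into the attention heads: output (heads) dim goes on "mp".
    (r"attention/(wq|wk|wv)/kernel", PS("fsdp", "mp")),
    # Attention output projection: input (heads) dim goes on "mp".
    (r"attention/wo/kernel", PS("mp", "fsdp")),
    # MLP up/gate projections: hidden (output) dim goes on "mp".
    (r"feed_forward/(w1|w3)/kernel", PS("fsdp", "mp")),
    # MLP down projection: hidden (input) dim goes on "mp".
    (r"feed_forward/w2/kernel", PS("mp", "fsdp")),
    # Everything else (norms, biases, ...) is replicated.
    (r".*", PS()),
]


def match_rule(path: str) -> PS:
    """Return the PartitionSpec of the first rule whose regex matches `path`."""
    for pattern, spec in PARTITION_RULES:
        if re.search(pattern, path) is not None:
            return spec
    raise ValueError(f"no partition rule matches {path!r}")


for path in [
    "transformer/h/0/attention/wq/kernel",
    "transformer/h/0/attention/wo/kernel",
    "transformer/h/0/feed_forward/w1/kernel",
    "transformer/h/0/feed_forward/w2/kernel",
    "transformer/h/0/attention_norm/kernel",
]:
    print(path, "->", match_rule(path))
```

In this sketch the "mp" entry always sits on the dimension shared by a layer pair (the heads or the MLP hidden units), while "fsdp" simply takes the remaining dimension.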

Thanks a lot!

young-geng commented 1 year ago

For FSDP, it really doesn't matter which axis we shard along, because the operation is always an all-gather. Regarding tensor parallelism, I strongly recommend this post to get a sense of how tensor parallelism works for Transformers.
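To see why the FSDP axis choice is irrelevant, here is a NumPy simulation with no real devices (the shard count and helper names are assumptions for illustration). Each simulated device stores one slice of a weight matrix, and an all-gather, modelled here as a concatenation, rebuilds the full matrix before the matmul, so slicing along rows or along columns gives exactly the same output.

```python
import numpy as np

N_DEVICES = 4  # hypothetical size of the "fsdp" mesh axis


def shard(w: np.ndarray, axis: int, n: int) -> list[np.ndarray]:
    """Split a weight into n equal shards along `axis` (one per device)."""
    return np.split(w, n, axis=axis)


def all_gather(shards: list[np.ndarray], axis: int) -> np.ndarray:
    """Simulated all-gather: every device receives the concatenation of all shards."""
    return np.concatenate(shards, axis=axis)


def fsdp_layer(x: np.ndarray, shards: list[np.ndarray], axis: int) -> np.ndarray:
    """FSDP forward pass: gather the full weight, then do an ordinary matmul."""
    w_full = all_gather(shards, axis)
    return x @ w_full


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))
w = rng.standard_normal((8, 12))

y_rows = fsdp_layer(x, shard(w, axis=0, n=N_DEVICES), axis=0)
y_cols = fsdp_layer(x, shard(w, axis=1, n=N_DEVICES), axis=1)

print(np.allclose(y_rows, x @ w), np.allclose(y_cols, x @ w))  # True True
```

The only thing the axis choice changes in this sketch is which slice each device keeps between layers; the gathered weight, and therefore the layer output, is identical.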

gianlucadetommaso commented 1 year ago

Thanks for the link! A few follow-up questions.

  1. I guess this means it also doesn't really matter which axis we put "mp" on?

  2. Do you have any general recommendation on how to set the mesh axis dimensions for "dp", "fsdp" and "mp" given a certain number of devices?

  3. Throughout the code, I couldn't find any different treatment of the different mesh axes. The only difference is that "mp" is used only to partition model parameters and "dp" only to partition batches, while "fsdp" is used for both. Is this correct?

young-geng commented 1 year ago
  1. For MP it matters a lot. FSDP performs an all-gather on the layer weights before computing the output, so it doesn't matter how the weights are partitioned. Tensor parallelism, on the other hand, does not gather the weights, and each device performs only part of the computation of a layer. Therefore, placing the MP axis incorrectly results in extra computation and communication.

  2. For smaller models (<20B parameters), we recommend full FSDP, so you can set the mesh dim (dp, fsdp, mp) to 1,-1,1. For larger models, you might want to use a mixture of FSDP and MP. In general, you'll need to try out a few configurations on your setup to determine the best one.

  3. This is mostly correct, except that the mp axis needs to be placed correctly to partition the ops optimally. JAX treats all types of parallelism as tensor sharding, so we don't need to handle them separately.
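Why MP placement matters can be illustrated with a NumPy simulation of a Megatron-style tensor-parallel MLP. This is a sketch under stated assumptions (simulated devices, a ReLU MLP, hypothetical helper names such as `Comm`, `mlp_good` and `mlp_bad`), not EasyLM code. In `mlp_good` the first kernel is split along its output dimension, like ("fsdp", "mp"), and the second along its input dimension, like ("mp", "fsdp"), so each device computes its slice of both matmuls locally and one all-reduce (a sum) combines the results. In `mlp_bad` the first kernel is split along its input dimension instead, so the partial pre-activations must be summed across devices before the nonlinearity, which costs an extra all-reduce.

```python
import numpy as np

N_MP = 4  # hypothetical size of the "mp" mesh axis


class Comm:
    """Counts simulated collective ops so the two layouts can be compared."""

    def __init__(self):
        self.all_reduces = 0

    def all_reduce(self, partials):
        self.all_reduces += 1
        return sum(partials)  # every device ends up with the elementwise sum


def relu(z):
    return np.maximum(z, 0.0)


def mlp_good(x, w1, w2, n, comm):
    """Column-parallel w1 ("fsdp", "mp"), row-parallel w2 ("mp", "fsdp")."""
    w1_shards = np.split(w1, n, axis=1)  # device i: a slice of the hidden units
    w2_shards = np.split(w2, n, axis=0)  # device i: the rows for those same units
    # relu is elementwise, so it can be applied to each hidden slice locally.
    partials = [relu(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]
    return comm.all_reduce(partials)


def mlp_bad(x, w1, w2, n, comm):
    """w1 split on its input dim ("mp", ...): pre-activations are only partial sums."""
    x_shards = np.split(x, n, axis=1)  # activations assumed split along features
    w1_shards = np.split(w1, n, axis=0)
    # Extra all-reduce: relu needs the full pre-activation, not a partial sum.
    h = relu(comm.all_reduce([xs @ a for xs, a in zip(x_shards, w1_shards)]))
    h_shards = np.split(h, n, axis=1)
    w2_shards = np.split(w2, n, axis=0)
    return comm.all_reduce([hs @ b for hs, b in zip(h_shards, w2_shards)])


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))
w1 = rng.standard_normal((8, 16))
w2 = rng.standard_normal((16, 8))
reference = relu(x @ w1) @ w2

good, bad = Comm(), Comm()
y_good = mlp_good(x, w1, w2, N_MP, good)
y_bad = mlp_bad(x, w1, w2, N_MP, bad)
print(np.allclose(y_good, reference), good.all_reduces)  # True 1
print(np.allclose(y_bad, reference), bad.all_reduces)  # True 2
```

Both layouts compute the same function; only the number of collectives differs, which is the kind of cost an incorrectly placed MP axis adds.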
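Items 2 and 3 can be combined into one JAX sketch. The `resolve_mesh_dims` helper is hypothetical: it mimics the "-1 means use the remaining devices" reading of the `1,-1,1` setting, and is not EasyLM's actual function. The script also assumes it may force 8 CPU devices through `XLA_FLAGS` so it runs without accelerators. Each kind of parallelism is expressed only as a `PartitionSpec` on a named ("dp", "fsdp", "mp") mesh: the batch is sharded over ("dp", "fsdp"), the weight over ("fsdp", "mp"), and XLA inserts whatever collectives are needed when the function is jitted.

```python
import os

# Simulate 8 devices on CPU; this must be set before JAX is imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import math

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as PS


def resolve_mesh_dims(spec: str, n_devices: int) -> tuple[int, ...]:
    """Hypothetical helper: parse "dp,fsdp,mp" and fill in a single -1."""
    dims = [int(d) for d in spec.split(",")]
    if dims.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    known = math.prod(d for d in dims if d != -1)
    if -1 in dims:
        dims[dims.index(-1)] = n_devices // known
    if math.prod(dims) != n_devices:
        raise ValueError(f"{spec} does not fit {n_devices} devices")
    return tuple(dims)


n = len(jax.devices())
# Mix FSDP with 2-way MP when the device count allows it, else full FSDP.
spec = "1,-1,2" if n % 2 == 0 else "1,-1,1"
dp, fsdp, mp = resolve_mesh_dims(spec, n)
mesh = Mesh(np.array(jax.devices()).reshape(dp, fsdp, mp), ("dp", "fsdp", "mp"))

rng = np.random.default_rng(0)
x_np = rng.standard_normal((4 * n, 4 * n)).astype(np.float32)  # (batch, d_in)
w_np = rng.standard_normal((4 * n, 4 * n)).astype(np.float32)  # (d_in, d_out)

# "dp" and "fsdp" both shard the batch; "fsdp" and "mp" shard the weight.
x = jax.device_put(x_np, NamedSharding(mesh, PS(("dp", "fsdp"), None)))
w = jax.device_put(w_np, NamedSharding(mesh, PS("fsdp", "mp")))

# The same jitted function works for any mesh shape; XLA partitions it.
y = jax.jit(lambda a, b: a @ b)(x, w)
print((dp, fsdp, mp), y.sharding)
```

Switching `spec` to another factorization (for example `"2,-1,1"`) only changes the mesh shape; the model code stays the same, which is the sense in which JAX handles every kind of parallelism as tensor sharding.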

gianlucadetommaso commented 1 year ago

Thanks, this helps!