young-geng / EasyLM

Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, finetuning, evaluating, and serving LLMs in JAX/Flax.
Apache License 2.0

FSDP vs Model Parallelism #42

Closed: jeromeku closed this issue 1 year ago

jeromeku commented 1 year ago

@young-geng

Great project and documentation.

Can you further elucidate the difference between FSDP and Model Parallelism? Isn't FSDP already a form of model parallelism? Trying to understand the nuanced differences between 3-stage DeepSpeed ZeRO, FSDP, and "model parallelism".

Thanks!

young-geng commented 1 year ago

Model parallelism in this library refers to the tensor parallelism variant of model parallelism (the other variant is pipeline parallelism, which is used less frequently), and FSDP in this library refers to stage 3 of ZeRO. Consider a linear layer where you multiply the input with a big matrix. FSDP shards the matrix along the contracting axis (input axis), so during the forward pass the whole matrix needs to be all-gathered on every accelerator to perform the matrix multiplication. The amount of computation is the same as in normal data parallelism, but instead of replicating the model weights across all accelerators, we store them sharded and only gather the full weights for the layer currently being computed.

On the other hand, model parallelism (tensor parallelism) shards the matrix along the non-contracting axis (output axis), so each accelerator only needs to multiply a shard of the full matrix. If the input activation is replicated across all accelerators, the output activation comes out sharded, and may need to be gathered for the next layer.
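
To make the contrast concrete, here is a minimal JAX sketch (not EasyLM code; the one-axis mesh, the axis name `"shard"`, and the shapes are made up for illustration) that shards the same linear-layer weight both ways:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional device mesh; its single axis stands in for the FSDP or MP axis.
mesh = Mesh(np.array(jax.devices()), ("shard",))

batch, d_in, d_out = 8, 1024, 4096
x = jnp.ones((batch, d_in))
w = jnp.zeros((d_in, d_out))

@jax.jit
def linear(x, w):
    return x @ w

# FSDP / ZeRO-3 style: weight sharded along the contracting (input) axis,
# activations sharded along batch as in plain data parallelism. The compiler
# (GSPMD) typically all-gathers the weight for the matmul and then discards it.
w_fsdp = jax.device_put(w, NamedSharding(mesh, P("shard", None)))
x_dp = jax.device_put(x, NamedSharding(mesh, P("shard", None)))
y_fsdp = linear(x_dp, w_fsdp)

# Tensor (model) parallelism: weight sharded along the non-contracting (output)
# axis, input replicated. Each device multiplies its own column shard, so the
# output activation comes out sharded along the hidden/output dimension.
w_mp = jax.device_put(w, NamedSharding(mesh, P(None, "shard")))
x_rep = jax.device_put(x, NamedSharding(mesh, P()))
y_mp = linear(x_rep, w_mp)

print(y_fsdp.sharding)  # batch-sharded output
print(y_mp.sharding)    # hidden-sharded output
```

The math is identical in both cases; what differs is which axis of the weight lives on each device and, consequently, how the output activation ends up sharded.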

jeromeku commented 1 year ago

@young-geng thanks for the response!

So if I'm understanding you correctly:

Is this along the right lines?

young-geng commented 1 year ago

Yeah, that seems correct.

jeromeku commented 1 year ago

@young-geng thanks for the responses.

I see that dp, fsdp, and mp are used to denote sharding constraints in llama_model.py.
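
Roughly, the pattern seems to be something like the following (a hedged sketch of how such axis names are commonly used with sharding constraints, not the actual llama_model.py code; it would have to run under jax.jit on a mesh that defines the dp/fsdp/mp axes):

```python
import jax.numpy as jnp
from jax.lax import with_sharding_constraint
from jax.sharding import PartitionSpec as PS

def mlp_block(x, w1, w2):
    # Hidden states of shape (batch, seq, hidden): batch sharded over the
    # ("dp", "fsdp") mesh axes, sequence replicated, hidden sharded over "mp".
    x = with_sharding_constraint(x, PS(("dp", "fsdp"), None, "mp"))
    h = jnp.einsum("bsd,df->bsf", x, w1)
    h = with_sharding_constraint(h, PS(("dp", "fsdp"), None, "mp"))
    return jnp.einsum("bsf,fd->bsd", h, w2)
```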

young-geng commented 1 year ago

For those details I strongly recommend JAX's "Distributed arrays and automatic parallelization" documentation.

jeromeku commented 1 year ago

Thanks for the reference -- I've already gone through the JAX and Flax guides on distributed computation / parallelism.

Some clarifying questions:

Appreciate the patience!

young-geng commented 1 year ago

The major difference between FSDP and tensor parallelism is that FSDP does not shard activations along the hidden dimension. The activations are sharded only along the batch dimension, exactly as in normal DP. You can think of FSDP as a more memory-efficient DP where the params are stored sharded across accelerators rather than replicated.
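
Expressed as PartitionSpecs for an activation of shape (batch, seq, hidden) (axis names illustrative, in the spirit of the dp/fsdp/mp names above), the difference is roughly:

```python
from jax.sharding import PartitionSpec as PS

# FSDP / plain DP: only the batch dimension of the activations is sharded.
fsdp_activations = PS(("dp", "fsdp"), None, None)

# Tensor (model) parallelism: the hidden dimension is sharded as well.
mp_activations = PS(("dp", "fsdp"), None, "mp")
```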

As for how JAX infers sharding of activations, I'd recommend reading the GSPMD paper for the details.