Closed: jeromeku closed this issue 1 year ago
Model parallelism in this library refers to the tensor parallelism variant of model parallelism (the other variant is pipeline parallelism, which is used less frequently), and FSDP in this library refers to stage 3 of ZeRO. Consider the setting where you have a linear layer where you multiply the input with a big matrix. FSDP would shard the matrix along the contracting axis (input axis), so during the forward pass, the whole matrix needs to be all-gathered on all accelerators to perform the matrix multiplication. The amount of computation is the same as in normal data parallelism, but instead of replicating the model weights across all accelerators, we store them in a sharded way and only gather the full weights for the layer we are currently computing.
On the other hand, model parallelism (tensor parallelism) shards the matrix along the non-contracting axis (output axis), so each accelerator only needs to multiply a shard of the full matrix. If the input activation is replicated across all accelerators, the output activation is then sharded, and might need to be gathered for the next layer.
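Here is a minimal sketch of the difference in JAX. This is not EasyLM code; the shapes, names, and the 8-device mesh are made up for illustration:

```python
# Hypothetical sketch, not EasyLM code: one linear-layer weight of shape
# (M, N) = (output, input), placed on a device mesh in two different ways.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

M, N = 1024, 512  # M: non-contracting / output axis, N: contracting / input axis
mesh = Mesh(np.array(jax.devices()), axis_names=("d",))  # assumes 8 devices
W = jnp.zeros((M, N))

# FSDP-style: shard the contracting (input) axis N. Each device stores an
# M x (N/8) shard, and the full W is all-gathered right before the matmul.
w_fsdp = jax.device_put(W, NamedSharding(mesh, P(None, "d")))

# Tensor-parallel-style: shard the non-contracting (output) axis M. Each
# device multiplies only its (M/8) x N shard with the (replicated) input,
# so W is never gathered; the output activation comes out sharded instead.
w_mp = jax.device_put(W, NamedSharding(mesh, P("d", None)))

jax.debug.visualize_array_sharding(w_fsdp)
jax.debug.visualize_array_sharding(w_mp)
```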
@young-geng thanks for the response!

So if I'm understanding you correctly:

- Say the linear layer's weight is an `M x N` matrix: `M` is the output or non-contracting axis, and therefore `N` is the input or contracting axis.
- `FSDP` will shard along `N`, and `Model Parallelism` (in the context of this library) would shard along `M`.
- `FSDP` would shard the linear weights into 2 `M x (N / 2)` sized matrices, while tensor parallelism would shard into 2 `(M / 2) x N` sized matrices.
- In `FSDP`, an all-gather would take place so that each device now has the full `M x N` matrix and computes the output on its respective data batch (per data + model parallelism).
- In tensor parallelism, each device computes `(M / 2) x N` times `N x 1` (assuming batch size of 1 for simplicity), such that each device now has an output activation of size `(M / 2) x 1`. A gather would then need to take place to assemble the full output of size `M x 1` for the next layer.

Is this along the right lines? (I sketched the shapes in code below as a sanity check.)
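To check the arithmetic concretely, here's a quick NumPy sketch with made-up sizes, simulating 2 devices with plain array splits:

```python
# Made-up sizes; 2 "devices" simulated with NumPy splits, batch size 1.
import numpy as np

M, N = 8, 6
W = np.random.randn(M, N)
x = np.random.randn(N, 1)

# FSDP: each device holds an M x (N/2) shard; an all-gather rebuilds the
# full M x N weight before the matmul, then every device computes M x 1.
fsdp_shards = np.split(W, 2, axis=1)            # two (M, N/2) shards
W_full = np.concatenate(fsdp_shards, axis=1)    # all-gather
y_fsdp = W_full @ x                             # (M, 1)

# Tensor parallelism: each device holds an (M/2) x N shard and multiplies it
# with the full input, producing an (M/2) x 1 partial output; a gather along
# the output axis assembles the full M x 1 activation for the next layer.
mp_shards = np.split(W, 2, axis=0)              # two (M/2, N) shards
y_parts = [w @ x for w in mp_shards]            # each (M/2, 1)
y_mp = np.concatenate(y_parts, axis=0)          # gather -> (M, 1)

assert np.allclose(y_fsdp, y_mp)
```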
Yeah, that seems correct.
@young-geng thanks for the responses.

I see that `dp`, `fsdp`, and `mp` are used to denote sharding constraints in `llama_model.py`.

- How are the `parameters` and `activations` for a transformer block / layer sharded for a simple config of, say, 2 hosts with 8 accelerators each? Specifically, how does the assignment of `fsdp` vs `mp` result in FSDP vs tensor parallelism?
- How does the sharding of a layer's `parameters` affect the sharding of its input / output `activations`?

For those details I strongly recommend the distributed arrays and automatic parallelism documentation of JAX.
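To give a rough picture of how those axis names can be used, here is a hypothetical sketch (not the actual `llama_model.py` code), assuming 16 devices, e.g. 2 hosts x 8 accelerators, arranged into a `(dp=2, fsdp=4, mp=2)` mesh:

```python
# Hypothetical sketch, not llama_model.py: sharding constraints on a
# (dp, fsdp, mp) mesh. Assumes 16 devices reshaped to 2 x 4 x 2.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(2, 4, 2),
            axis_names=("dp", "fsdp", "mp"))

@jax.jit
def linear(x, w):
    # Weight: contracting (input) axis sharded over 'fsdp',
    # non-contracting (output) axis sharded over 'mp'.
    w = jax.lax.with_sharding_constraint(w, NamedSharding(mesh, P("fsdp", "mp")))
    y = x @ w
    # Activations: batch sharded over ('dp', 'fsdp'), hidden dim over 'mp'.
    # XLA inserts the necessary all-gathers / reduce-scatters automatically.
    return jax.lax.with_sharding_constraint(
        y, NamedSharding(mesh, P(("dp", "fsdp"), "mp")))

x = jnp.zeros((16, 512))    # (batch, in_features), made-up sizes
w = jnp.zeros((512, 256))   # (in_features, out_features)
print(linear(x, w).sharding)
```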
Thanks for the reference -- I've already gone through the JAX and Flax guides on distributed / parallelism.

Some clarifying questions:

- `FSDP` as it's used in EasyLM: how is it sharding parameters / activations? I guess I'm still not clear how this differs from model parallelism as it's used in this project, since both split at the tensor level (as opposed to the layer level, i.e., pipeline parallelism) and so both look like forms of tensor parallelism.

Appreciate the patience!
The major difference between FSDP and tensor parallelism is that FSDP does not shard activations along the hidden dimension. The activations are sharded only along the batch dimension, which is identical to normal DP. You can think of FSDP as a more efficient DP where the params are stored in a distributed way to save memory, rather than being replicated.
As for how JAX infers sharding of activations, I'd recommend reading the GSPMD paper for the details.
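To make the activation point concrete, here's a small sketch (made-up shapes; assumes 8 devices arranged into a `(fsdp=4, mp=2)` mesh):

```python
# Made-up shapes; just to visualize the activation-sharding difference.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(4, 2), axis_names=("fsdp", "mp"))
acts = jnp.zeros((32, 1024))  # (batch, hidden)

# FSDP / DP: activations sharded only along the batch axis;
# the hidden dimension stays replicated on every device.
fsdp_acts = jax.device_put(acts, NamedSharding(mesh, P("fsdp", None)))

# Tensor parallelism: activations additionally sharded along the hidden axis.
tp_acts = jax.device_put(acts, NamedSharding(mesh, P("fsdp", "mp")))

jax.debug.visualize_array_sharding(fsdp_acts)
jax.debug.visualize_array_sharding(tp_acts)
```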
@young-geng
Great project and documentation.
Can you further elucidate the difference between FSDP and Model Parallelism? Isn't FSDP already a form of model parallelism? Trying to understand the nuanced differences between 3-stage DeepSpeed ZeRO, FSDP, and "model parallelism".
Thanks!