huggingface / nanotron

Minimalistic large language model 3D-parallelism training
Apache License 2.0

[Question] Async Tensor Parallel #48

Open woshiyyya opened 9 months ago

woshiyyya commented 9 months ago

I am a bit curious about the following example in 3d_parallelism.md:

Q: Can you give a concrete example illustrating how asynchronous tensor parallelism works? (6 steps)

A:

Step 1: Let's look at an example with 4 GPU ranks:
    Input X sharded across ranks as [X0, X1, X2, X3]
    Weight matrix W sharded as [W0, W1, W2, W3]
Step 2: Rank 2 kicks off async all-gather to get [X0, X1, X2, X3]
Step 3: While gathering, rank 2 computes: local_output = X2 * W2
Step 4: All-gather completes, rank 2 has [X0, X1, X2, X3]
Step 5: Rank 2 computes: before_local_output = X0 * W0 + X1 * W1, after_local_output = X3 * W3
Step 6: Rank 2's output = before_local_output + local_output + after_local_output

So each rank computes the full output using the locally gathered X and its shard of W.
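The overlap in steps 2 to 4 can be mimicked in a single process. In this minimal sketch a thread stands in for the communication backend. `fake_all_gather` and `overlapped_forward` are hypothetical names, not nanotron's API. The per-shard computation is passed in as a function, and the step 6 combination is left to the caller because it depends on how X and W are sharded.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_all_gather(shards):
    """Hypothetical stand-in for an async all-gather: waits, then returns every rank's shard."""
    time.sleep(0.05)  # pretend the transfer over the network takes a while
    return list(shards)


def overlapped_forward(rank, shards, compute):
    """Overlap the gather with local work, then process the remote shards.

    Returns the per-shard results in rank order and the order in which shards were computed.
    """
    order = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        handle = pool.submit(fake_all_gather, shards)  # step 2: kick off the gather
        local = compute(shards[rank])                  # step 3: compute on the local shard meanwhile
        order.append(rank)
        gathered = handle.result()                     # step 4: block until the gather completes
    results = {rank: local}
    for j, shard in enumerate(gathered):               # step 5: the remaining (non-local) shards
        if j != rank:
            results[j] = compute(shard)
            order.append(j)
    return [results[j] for j in range(len(gathered))], order
```

For example, `overlapped_forward(2, [[0, 1], [2, 3], [4, 5], [6, 7]], lambda s: [10 * v for v in s])` processes the shards in the order [2, 0, 1, 3] and returns [[0, 10], [20, 30], [40, 50], [60, 70]].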

Since W is also sharded across ranks (rank i holds Wi), at which step are the weight shards gathered?

I just want to confirm that my understanding is correct. Suppose TP=4, with an input of shape [B, M] and a weight of shape [M, N]. Then would the sharded input Xi have shape [B, M/4] and Wi have shape [M/4, N]? Is it row-based tensor parallelism?
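For reference, the sharding described in this question ([B, M/4] input shards, [M/4, N] weight shards) is the standard row-parallel split. A minimal plain-Python sketch of it: every partial product Xi * Wi already has the full output shape [B, N], and summing the partials (the job of an all-reduce across the TP ranks) recovers X * W. The function names are hypothetical.

```python
def matmul(a, b):
    """Plain-Python product of an [r x k] matrix and a [k x c] matrix."""
    assert all(len(row) == len(b) for row in a), "inner dimensions must match"
    return [[sum(row[t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))] for row in a]


def row_parallel_matmul(x, w, tp):
    """Row-parallel X @ W: X [B x M] is split by columns, W [M x N] by rows, into tp shards.

    Every shard product Xi @ Wi already has the full output shape [B x N]; summing them,
    which is what an all-reduce across the tp ranks would do, gives X @ W.
    """
    m = len(w)
    assert m % tp == 0, "M must be divisible by tp"
    step = m // tp
    total = [[0] * len(w[0]) for _ in x]
    for i in range(tp):
        x_i = [row[i * step:(i + 1) * step] for row in x]  # [B x M/tp], held by rank i
        w_i = w[i * step:(i + 1) * step]                   # [M/tp x N], held by rank i
        partial = matmul(x_i, w_i)                         # [B x N] partial sum
        total = [[s + p for s, p in zip(rs, rp)] for rs, rp in zip(total, partial)]
    return total
```

With x = [[1, 2, 3, 4], [5, 6, 7, 8]] and w = [[1, 0, 2], [0, 1, 1], [1, 1, 0], [2, 0, 1]], `row_parallel_matmul(x, w, 4)` returns [[12, 5, 8], [28, 13, 24]], the same as `matmul(x, w)`.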

xrsrke commented 9 months ago

@woshiyyya Hi. Sorry for the delayed response :)

Is it row-based tensor parallelism?

Yes. Currently we only use async communication in row-based tensor parallelism [link].

Then would the sharded input Xi have shape [B, M/4] and Wi have shape [M/4, N]?

Nope. The sharded input Xi is [B/4, M], and Wi is [M, N/4].
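Taking the shapes in this reply at face value, each product Xi * Wi is a [B/4, N/4] block. Below is a sketch of one way such shards can still produce a full-height output. It assumes, without confirmation from this thread, that after the all-gather rank i multiplies every input shard Xj by its own column shard Wi and stacks the blocks along the batch dimension. The helpers are hypothetical, not nanotron's API.

```python
def matmul(a, b):
    """Plain-Python product of an [r x k] matrix and a [k x c] matrix."""
    assert all(len(row) == len(b) for row in a), "inner dimensions must match"
    return [[sum(row[t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))] for row in a]


def batch_shards(x, tp):
    """Split X [B x M] along the batch dimension into tp shards Xi of shape [B/tp x M]."""
    rows = len(x) // tp
    return [x[i * rows:(i + 1) * rows] for i in range(tp)]


def column_shard(w, i, tp):
    """Rank i's shard Wi of W [M x N]: a column slice of shape [M x N/tp]."""
    cols = len(w[0]) // tp
    return [row[i * cols:(i + 1) * cols] for row in w]


def gathered_output(x_shards, w_i):
    """After gathering every Xj, stack the blocks Xj @ Wi along the batch dimension: [B x N/tp]."""
    out = []
    for x_j in x_shards:
        out.extend(matmul(x_j, w_i))  # one [B/tp x N/tp] block per input shard
    return out
```

With X = [[1, 2], [3, 4], [5, 6], [7, 8]] ([4, 2]), W = [[1, 0, 2, 1], [0, 1, 1, 3]] ([2, 4]) and TP=4, rank 2's result is [[4], [10], [16], [22]]: shape [B, N/4], equal to X * W2 computed directly.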

woshiyyya commented 9 months ago

Thanks for the reply~ I am still confused, because from the code, the row-based linear layer seems to have weights of shape [M/4, N]:

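class TensorParallelRowLinear(nn.Linear):
    def __init__(self, in_features, out_features, ...):  # excerpt; remaining arguments and body elided
        ...
        self.in_features = in_features // self.world_size  # input dim M is split across the TP ranks
        self.out_features = out_features  # output dim N is kept whole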
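The excerpt divides in_features by the world size and keeps out_features whole, so each rank's weight is [M/4, N] in this thread's X * W notation. A small shape-only sketch follows. It assumes the weight is allocated from these two attributes, and the helper name is hypothetical. Note that torch.nn.Linear stores its weight as [out_features, in_features], i.e. transposed relative to that notation.

```python
def local_row_linear_weight_shapes(in_features, out_features, world_size):
    """Hypothetical helper mirroring the excerpt: split in_features across ranks, keep out_features whole.

    Returns the per-rank weight shape twice: as [in, out] (the X * W convention used in this
    thread) and as [out, in] (the layout torch.nn.Linear uses for its weight parameter).
    """
    if in_features % world_size != 0:
        raise ValueError("in_features must be divisible by world_size")
    local_in = in_features // world_size
    return (local_in, out_features), (out_features, local_in)
```

With M = 1024, N = 4096 and a world size of 4, each rank holds a [256, 4096] shard in the thread's notation, which torch.nn.Linear would store as a [4096, 256] tensor.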

And if rank 2's output is X0 * W0 + X1 * W1 + X2 * W2 + X3 * W3, wouldn't that have shape [B/4, N/4], which is only 1/16 of the full output?
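The shape argument in this question can be checked with a small sketch. It assumes the shapes from the earlier reply (Xi is [B/4, M], Wi is [M, N/4]), and the helper names are hypothetical. Each Xi * Wi is [B/4, N/4], and adding four such blocks keeps that shape.

```python
from fractions import Fraction


def matmul(a, b):
    """Plain-Python product of an [r x k] matrix and a [k x c] matrix."""
    assert all(len(row) == len(b) for row in a), "inner dimensions must match"
    return [[sum(row[t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))] for row in a]


def summed_matching_blocks(x, w, tp):
    """X0 @ W0 + ... + X(tp-1) @ W(tp-1), with Xi a batch shard [B/tp x M] and Wi a column shard [M x N/tp]."""
    rows, cols = len(x) // tp, len(w[0]) // tp
    total = [[0] * cols for _ in range(rows)]
    for i in range(tp):
        x_i = x[i * rows:(i + 1) * rows]                   # [B/tp x M]
        w_i = [row[i * cols:(i + 1) * cols] for row in w]  # [M x N/tp]
        block = matmul(x_i, w_i)                           # [B/tp x N/tp]
        total = [[s + v for s, v in zip(rs, rb)] for rs, rb in zip(total, block)]
    return total


def output_fraction(result, b, n):
    """Share of the full [B x N] output covered by a result matrix."""
    return Fraction(len(result) * len(result[0]), b * n)
```

With B = 8, M = 3, N = 16 and TP=4, the sum is a [2, 4] matrix, which covers 1/16 of the [8, 16] output.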