woshiyyya opened this issue 9 months ago
@woshiyyya Hi. Sorry for the delayed response :)
> Is it row-based tensor parallel?

Yes. Currently we only use async communication in the row-based case [link].
> Then the shape of sharded input X_i be [B, M/4] and W_i be [M/4, N]?

Nope. The sharded input X_i is [B/4, M], and W_i is [M, N/4].
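(A shape-only illustration of the sharding described in this reply; the sizes and variable names are assumptions for the example, not code from this repo.)

```python
import torch

# Per the reply above: X_i is [B/4, M] (input sharded over the batch dim)
# and W_i is [M, N/4] (weight sharded over the output-feature dim).
B, M, N, TP = 8, 16, 12, 4
X_i = torch.randn(B // TP, M)   # [2, 16]
W_i = torch.randn(M, N // TP)   # [16, 3]
Y_i = X_i @ W_i                 # each rank holds a [B/4, N/4] block
print(Y_i.shape)                # torch.Size([2, 3])
```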
Thanks for the reply~ I am still confused, since from the code I saw that the row-based linear layer should have weights of shape [M/4, N]:
```python
class TensorParallelRowLinear(nn.Linear):
    ...
    self.in_features = in_features // self.world_size
    self.out_features = out_features
```
And if rank 2's output is X_0 * W_0 + X_1 * W_1 + X_2 * W_2 + X_3 * W_3, that will be of shape [B/4, N/4], which is only 1/16 of the output?
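(For reference, here is a minimal sketch of the standard Megatron-style row-parallel linear that this shape argument is reasoning about, simulated on a single process; the names and sizes are illustrative and not taken from this repo.)

```python
import torch

# Full problem: Y = X @ W with X [B, M], W [M, N], TP = 4.
B, M, N, TP = 8, 16, 12, 4
X = torch.randn(B, M, dtype=torch.double)
W = torch.randn(M, N, dtype=torch.double)

# Row-parallel sharding: X_i is [B, M/4] (input features split),
# W_i is [M/4, N] (weight rows split).
X_shards = X.chunk(TP, dim=1)
W_shards = W.chunk(TP, dim=0)

# Each rank computes a partial output of the *full* shape [B, N] ...
partials = [X_i @ W_i for X_i, W_i in zip(X_shards, W_shards)]
# ... and an all-reduce (sum over ranks) recovers the exact result.
Y = torch.stack(partials).sum(dim=0)

assert torch.allclose(Y, X @ W)
print(Y.shape)  # torch.Size([8, 12]) == [B, N]
```

In this scheme each partial X_i @ W_i already has the full [B, N] shape, and the reduction sums contributions across ranks rather than concatenating [B/4, N/4] blocks.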
I am a bit curious about the below example in 3d_parallelism.md. Since W is also sharded across different ranks (rank i possesses W_i), in which step do we gather the sharded weights?
Just want to confirm if my understanding is correct. Suppose TP=4 and you have an input of shape [B x M] and a weight of shape [M x N]. Then would the shape of the sharded input X_i be [B, M/4] and W_i be [M/4, N]? Is it row-based tensor parallel?