microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI

Q) Tensor parallel for magneto #13

Closed taehwakkwon closed 1 year ago

taehwakkwon commented 1 year ago

When Magneto is used, it is hard to apply tensor parallelism (TP). Gathering tensors before the subln LayerNorm and scattering them again afterwards incurs a lot of communication cost. Do you have any code or ideas for how to solve this?

shumingma commented 1 year ago

To speed up the model parallelism of subln, one idea is to compute (and only sync) the mean/variance of LNs before gathers/scatters. This should significantly reduce the communication cost.
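To make the saving concrete (just a sketch of the standard LayerNorm statistics, assuming each rank holds an $h/\mathrm{TP}$ slice of the normalized dimension): per token, LayerNorm only needs

$$\mu = \frac{1}{h}\sum_{i=1}^{h} x_i, \qquad \sigma^2 = \frac{1}{h}\sum_{i=1}^{h} x_i^2 - \mu^2,$$

so each rank can compute the partial sums $\sum_i x_i$ and $\sum_i x_i^2$ over its own slice and sync only those two values per token, instead of gathering the full $h$-dimensional activation.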

taehwakkwon commented 1 year ago

Yeah, I have thought about that too. Although communicating the mean of the LNs would reduce the cost, sharing the variance is not that simple. If we parallelize the tensor into N parts, we have to compute C(N, 2) covariance terms because the parallelized tensors are not independent, e.g. $V(x+y+z) = V(x) + V(y) + V(z) + 2\,cov(x,y) + 2\,cov(y,z) + 2\,cov(z,x)$.

shumingma commented 1 year ago

> Yeah, I have thought about that too. Although communicating the mean of the LNs would reduce the cost, sharing the variance is not that simple. If we parallelize the tensor into N parts, we have to compute C(N, 2) covariance terms because the parallelized tensors are not independent, e.g. $V(x+y+z) = V(x) + V(y) + V(z) + 2\,cov(x,y) + 2\,cov(y,z) + 2\,cov(z,x)$.

We can split the parameter matrix along its columns rather than its rows (see Sec. 3 of the Megatron-LM paper). Then the parallelized tensors are concatenated rather than added, and they are independent.
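A minimal single-process sketch of this (plain PyTorch, with toy shapes and names of my own choosing, not torchscale code): with a column split, the shard outputs are concatenated and together recover the full pre-LN activation exactly.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
h, ffn_dim = 8, 32            # toy sizes, just for illustration
x = torch.randn(4, h)         # (tokens, hidden)
w = torch.randn(h, ffn_dim)   # full fc1 weight

# Column-parallel split: each "rank" owns half of the output features.
w1, w2 = w.chunk(2, dim=1)

z_full = F.gelu(x @ w)                                      # no tensor parallelism
z_tp = torch.cat([F.gelu(x @ w1), F.gelu(x @ w2)], dim=-1)  # concatenation of shard outputs

print(torch.allclose(z_full, z_tp, atol=1e-6))  # True: shards are concatenated, not summed
```

With a row split, each shard would instead produce a partial sum over the hidden dimension, which is where the covariance terms above would come in.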

taehwakkwon commented 1 year ago

We can split it along the rows, but the result is not identical unless we also train with tensor parallelism. The normalized dimension of the LayerNorm shrinks to 'h/TP', which can cause each shard to have a different distribution. This can also cause problems when we run inference with the model after training.

shumingma commented 1 year ago

> We can split it along the rows, but the result is not identical unless we also train with tensor parallelism. The normalized dimension of the LayerNorm shrinks to 'h/TP', which can cause each shard to have a different distribution. This can also cause problems when we run inference with the model after training.

I may have missed something. Could you explain a little bit more about "the normalized dimension of the LayerNorm shrinks to 'h/TP', which can cause each shard to have a different distribution"?

taehwakkwon commented 1 year ago

If we parallelize the whole model with TP (2, 4, or 8), the normalized_shape of the subln LayerNorms has to shrink accordingly, as below:

```python
self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps)  # ffn_dim -> ffn_dim / TP

self.inner_attn_ln = (
    MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
    if subln and self.self_attention
    else None
)  # self.embed_dim -> self.embed_dim / TP
```

Therefore, the parallelized tensors cannot share their mean and variance. Does this cause any degradation?

shumingma commented 1 year ago

> Therefore, the parallelized tensors cannot share their mean and variance. Does this cause any degradation?

For $Y=LN(GeLU(AX))$, tensor parallelism turns it into $Y=LN(GeLU([A_1X_1, A_2X_2]))=LN([GeLU(A_1X_1), GeLU(A_2X_2)])$.

Denoting $Z_1=GeLU(A_1X_1)$ and $Z_2=GeLU(A_2X_2)$, then

$Y=[\frac{Z_1-mean([Z_1,Z_2])}{\sqrt{Var([Z_1,Z_2])}}, \frac{Z_2-mean([Z_1,Z_2])}{\sqrt{Var([Z_1,Z_2])}}]$ (omitting $\gamma, \beta, \epsilon$ here)

Here, $Z_1$ and $Z_2$ are independent and concatenated rather than summed (because we split the columns rather than the rows), so no covariance terms are needed. With equal shard sizes, $mean([Z_1,Z_2])=\frac{1}{2}\left(mean(Z_1)+mean(Z_2)\right)$ and $Var([Z_1,Z_2])=\frac{1}{2}\left(Var(Z_1)+Var(Z_2)\right)+\frac{1}{4}\left(mean(Z_1)-mean(Z_2)\right)^2$, both of which can be computed from per-shard sums and sums of squares.

Therefore, we can sync the mean and variance before computing the LN, and the result exactly matches the case without tensor parallelism.
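A minimal sketch of this sync (my own illustration, not torchscale code; it assumes an initialized torch.distributed process group whose ranks each hold an even slice of the hidden dimension): each rank computes the sum and sum of squares over its slice, all-reduces those two values per token, and then normalizes its slice with the global statistics.

```python
import torch
import torch.distributed as dist

def tp_layer_norm(z_local: torch.Tensor, full_dim: int, eps: float = 1e-5) -> torch.Tensor:
    """LayerNorm over a hidden dimension sharded across TP ranks (sketch).

    Assumes torch.distributed is initialized and the default group is the TP group.
    z_local: the (..., full_dim // world_size) slice owned by this rank.
    full_dim: size of the un-sharded hidden dimension.
    """
    # Per-token partial statistics over the local slice.
    stats = torch.stack(
        [z_local.sum(dim=-1), (z_local * z_local).sum(dim=-1)], dim=-1
    )  # (..., 2)

    # Sync only two values per token instead of gathering the full activation.
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)

    mean = stats[..., 0] / full_dim
    var = stats[..., 1] / full_dim - mean * mean  # E[z^2] - E[z]^2

    # Normalize the local slice with the global statistics
    # (the elementwise gamma/beta are omitted; each rank would keep its own slice of them).
    return (z_local - mean.unsqueeze(-1)) / torch.sqrt(var.unsqueeze(-1) + eps)
```

Because the concatenated $[Z_1, Z_2]$ is normalized with exactly the same mean and variance as in the non-parallel case, the output matches the non-TP model up to floating-point error.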

shumingma commented 1 year ago

@taehwakkwon Sequence parallelism is another option to reduce the communication cost of the LayerNorms. See https://arxiv.org/pdf/2205.05198.pdf

taehwakkwon commented 1 year ago

Thanks for sharing!