bigscience-workshop / Megatron-DeepSpeed

Ongoing research training transformer language models at scale, including: BERT & GPT-2

Fix tflops glu computation #283

Closed Muennighoff closed 2 years ago

Muennighoff commented 2 years ago

From Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM:

An A_{m×k} × X_{k×n} matrix multiplication requires 2 × m × k × n FLOPs

The feed-forward network increases the hidden size to 4h and then reduces it back to h; this requires 16Bsh^2 FLOPs
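As a quick numeric check of those two statements (the variable names below are just for illustration, not the ones used in the repo):

```python
# Check that the 2*m*k*n matmul rule reproduces the quoted 16*B*s*h^2
# figure for the feed-forward block (h -> 4h -> h).
B, s, h = 4, 2048, 4096  # arbitrary example sizes

def matmul_flops(m, k, n):
    # an (m, k) x (k, n) matmul costs 2*m*k*n FLOPs
    return 2 * m * k * n

up   = matmul_flops(B * s, h, 4 * h)   # h -> 4h
down = matmul_flops(B * s, 4 * h, h)   # 4h -> h
assert up + down == 16 * B * s * h**2
```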

We know that

Normal:
(B, s, h) * (h, 4h)
SwiGLU:
(B, s, h) * (h, 8h)

Hence

B * (2 * s * h * 4h) = 8Bsh^2 for the upscaling -> × 2 for the downscaling -> 16Bsh^2 (with some additional operations this then becomes 24Bsh^2 + 4Bs^2h + ..., see the paper)

B * (2 * s * h * 8h) = 16Bsh^2 (upscaling) + SwiGLU + 8Bsh^2 (downscaling) -> 24Bsh^2 (-> we need to add 8, thus 32Bsh^2 + 4Bs^2h + ...)

I.e. we need to increase the coefficient by 8.

The SwiGLU operation itself also adds another 0.5 * 8 * Bsh FLOPs I think, but I am ignoring that here for simplicity. Kudos to @DanielHesslow!
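A small sanity check of the coefficients above (again with illustrative variable names, ignoring the 4Bs^2h attention-score term and the elementwise SwiGLU FLOPs):

```python
# Per-layer coefficient with and without GLU, matching the derivation above:
# attention projections contribute 8*B*s*h^2 (QKV: 6, output projection: 2),
# the MLP contributes 16*B*s*h^2 normally and 24*B*s*h^2 with SwiGLU
# (the up-projection doubles to (h, 8h), the down-projection stays (4h, h)).
B, s, h = 4, 2048, 4096

def matmul_flops(m, k, n):
    return 2 * m * k * n

attn_proj    = 8 * B * s * h**2
mlp_standard = matmul_flops(B * s, h, 4 * h) + matmul_flops(B * s, 4 * h, h)
mlp_swiglu   = matmul_flops(B * s, h, 8 * h) + matmul_flops(B * s, 4 * h, h)

assert attn_proj + mlp_standard == 24 * B * s * h**2
assert attn_proj + mlp_swiglu   == 32 * B * s * h**2
```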

thomasw21 commented 2 years ago

This does indeed seem correct, though I would suggest we have a more transparent computation of FLOPs (i.e. decompose 24 and 32 into the individual components: query/key/value, attention output, MLP, ...). WDYT @stas00?

(I also think we should drop the SwiGLU term, since the previous version did not take the ReLU activation into account either, which seems reasonable)
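For reference, a decomposed version along those lines could look roughly like the sketch below (made-up names, not the actual utility in the repo, and ignoring the elementwise activation FLOPs as suggested):

```python
def transformer_flops_per_iteration(batch_size, seq_len, num_layers, hidden_size,
                                    vocab_size, glu=False, fwd_bwd_factor=3):
    """Sketch of a decomposed FLOPs estimate.

    fwd_bwd_factor: 3 for forward + backward, 4 when activations are recomputed.
    """
    B, s, h = batch_size, seq_len, hidden_size
    qkv_proj    = 6 * B * s * h**2      # three (h, h) projections
    attn_scores = 4 * B * s**2 * h      # Q @ K^T and scores @ V
    attn_out    = 2 * B * s * h**2      # attention output projection
    mlp         = (24 if glu else 16) * B * s * h**2
    per_layer   = qkv_proj + attn_scores + attn_out + mlp   # (24 or 32)*B*s*h^2 + 4*B*s^2*h
    lm_head     = 2 * B * s * h * vocab_size
    # the LM head is usually not recomputed, so it only gets the forward + backward 3x
    return fwd_bwd_factor * num_layers * per_layer + 3 * lm_head
```

With glu=False and fwd_bwd_factor=4 this reduces to the paper's 96 * B * s * l * h^2 * (1 + s/(6h) + V/(16 * l * h)) formula.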

stas00 commented 2 years ago

Thank you, @Muennighoff - could you please adjust the math in the OP - I think there are a few h2 and 2h when h^2 was probably meant?

> I would suggest we have a more transparent computation of FLOPs (i.e. decompose 24 and 32 into the individual components: query/key/value, attention output, MLP, ...).

Sure, if it's not too much trouble. And adding comments on where the numbers come from, similar to what the OP did?

Muennighoff commented 2 years ago

> Thank you, @Muennighoff - could you please adjust the math in the OP - I think there are a few h2 and 2h when h^2 was probably meant?
>
> > I would suggest we have a more transparent computation of FLOPs (i.e. decompose 24 and 32 into the individual components: query/key/value, attention output, MLP, ...).
>
> Sure, if it's not too much trouble. And adding comments on where the numbers come from, similar to what the OP did?

Standardized to h^2 everywhere 👍 I added a link to this PR & a brief explanation in the code. Given that it also links to the paper, I think it is fine to leave it as is? I don't think I can explain it better than the paper, so people are best off reading it there.