Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Reduce training time? #642

Open zhoumengbo opened 1 year ago

zhoumengbo commented 1 year ago

Hello author, I noticed in the paper that when training a GPT-style model, FlashAttention-2 improves training speed by about 1.3×. I'm curious: does this imply that if I pre-train a model from scratch, with all other conditions the same, using FlashAttention-2 reduces the overall training time by a factor of about 1.3 compared to FlashAttention-1?

tridao commented 1 year ago

Yes, given the same setting as in the paper (e.g. context length 8k, 3B model, 8xA100). Ofc for other settings (model size, context length, hardware) the speedup will be different.
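(For anyone who wants to see how the kernel time scales on their own hardware, here is a minimal timing sketch. It assumes flash-attn 2.x is installed and a recent CUDA GPU; the shapes are illustrative, not the paper's exact training configuration.)

```python
# Minimal micro-benchmark sketch: time flash_attn_func across context lengths.
# Assumes flash-attn 2.x and a CUDA GPU; shapes are illustrative only.
import torch
from flash_attn import flash_attn_func

def time_attn(seqlen, batch=2, nheads=16, headdim=64, iters=20):
    q, k, v = (torch.randn(batch, seqlen, nheads, headdim,
                           device="cuda", dtype=torch.float16)
               for _ in range(3))
    # Warm up so one-time setup costs do not skew the measurement.
    for _ in range(3):
        flash_attn_func(q, k, v, causal=True)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        flash_attn_func(q, k, v, causal=True)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per forward call

for seqlen in (1024, 2048, 4096, 8192):
    print(f"seqlen={seqlen}: {time_attn(seqlen):.2f} ms")
```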

zhoumengbo commented 1 year ago

> Yes, given the same setting as in the paper (e.g. context length 8k, 3B model, 8xA100). Ofc for other settings (model size, context length, hardware) the speedup will be different.

I would like to ask for your opinion on what proportion of total pre-training time is taken up by the attention mechanism. Suppose attention accounts for 50% and the remaining factors contribute the other 50%. Taking FlashAttention-1 as the baseline: 1 × 50% + 1 × 50% = 1. If FlashAttention-2's attention time is a fraction x of FlashAttention-1's, then x × 50% + 1 × 50% = 1/1.3, from which x ≈ 53.85%. An attention time of about 53.85% of the baseline corresponds to roughly a 2× kernel speedup, which aligns with the approximately twofold speedup mentioned in the paper, so my hypothesis seems consistent. Do you think this reasoning is correct? (If it is, I am curious whether you have verified this with a profiler, or studied the proportional time contribution of different modules during pre-training.)
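(As a concrete check of the arithmetic above, here is a small sketch. The 50% attention share is my assumption, not a profiled measurement; only the 1.3× end-to-end figure comes from the paper's setting.)

```python
# Back-of-the-envelope check of the reasoning above.
attention_share = 0.50    # assumed fraction of step time spent in attention
end_to_end_speedup = 1.3  # FA2 vs FA1, from the paper's setting

# Total time with FA2, relative to FA1 = 1:
#   attention_share * x + (1 - attention_share) * 1 = 1 / end_to_end_speedup
x = (1 / end_to_end_speedup - (1 - attention_share)) / attention_share
print(f"FA2 attention time relative to FA1: {x:.4f}")       # ~0.5385
print(f"Implied attention-kernel speedup:   {1 / x:.2f}x")  # ~1.86x, i.e. roughly 2x
```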