Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Question regarding overlapping #1160

Open wookjeHan opened 1 month ago

wookjeHan commented 1 month ago

First of all, thank you for sharing your excellent work!

I have a question about overlapping (pingpong design). From my understanding:

1) With FP8 precision and a head dimension of 128, the exponential function seems to take about the same amount of time as GEMM0 + GEMM1. This is because FP8 matmul throughput is roughly 512 times higher than exponential (MUFU) throughput, while there are also roughly 512 times more matmul FLOPs than exponential operations (see the sketch after this list).

2) The softmax involves two MUFU operations per element: one for the exponential and another for the floating-point division. This suggests that softmax should be about twice as slow as GEMM0 + GEMM1.
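For concreteness, here is the back-of-the-envelope arithmetic behind point 1. The hardware numbers are my assumptions (roughly the H100 SXM5 peak rates of ~1978 TFLOPS for FP8 matmul and ~3.9 TFLOPS for MUFU special functions), not figures stated in this thread:

```python
# Rough timing comparison for one element of the score matrix S = Q @ K^T,
# head dimension d = 128. Peak rates below are assumed H100 SXM5 numbers.

d = 128                      # head dimension
matmul_tflops = 1978e12      # assumed FP8 tensor-core peak (FLOPs/s)
mufu_tflops   = 3.9e12       # assumed MUFU special-function peak (exp ops/s)

# Per element of S: GEMM0 (Q @ K^T) contributes 2*d FLOPs, GEMM1 (P @ V)
# another 2*d FLOPs, while softmax needs 1 exponential for that element.
matmul_flops_per_elem = 2 * d + 2 * d    # = 512
exp_ops_per_elem      = 1

time_matmul = matmul_flops_per_elem / matmul_tflops
time_exp    = exp_ops_per_elem / mufu_tflops

print(f"matmul FLOPs per exp   : {matmul_flops_per_elem}")            # 512
print(f"throughput ratio       : {matmul_tflops / mufu_tflops:.0f}")  # ~507
print(f"time_exp / time_matmul : {time_exp / time_matmul:.2f}")       # ~0.99
```

Under these assumptions the exponentials alone take roughly as long as GEMM0 + GEMM1 combined at FP8, which is why hiding the MUFU work behind the tensor-core work (the ping-pong schedule) matters so much at this precision.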

However, the figure provided shows softmax taking only half as much time as GEMM0 + GEMM1. If softmax actually takes twice as much time as the GEMMs (GEMM0 + GEMM1), the GEMM units might be idle for half of the time.

[Figure: scheduling diagram showing the relative timing of GEMM0/GEMM1 and softmax]

So my questions are:

1) Is my understanding of the time consumption for GEMM and softmax correct? Specifically, is the runtime really bound by MUFU throughput for FP8 precision and a head dimension of 128?

2) If my understanding is correct, does this mean that the GEMM units execute for only half the time and remain idle for the other half?

Thank you!

tridao commented 1 month ago

The figure is not drawn to scale; it's just an illustration.

The way we do it, softmax only uses one MUFU operation (the exponential). There's no floating-point division in the inner loop; the division is done once at the very end.
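The "division at the very end" refers to the standard online-softmax normalization: accumulate the unnormalized output and the running row sum across key/value blocks, and divide only once after the last block. A minimal NumPy sketch of the idea (block size chosen arbitrarily; this is not the actual CUDA kernel, and the 1/sqrt(d) score scaling is omitted for brevity):

```python
import numpy as np

def attention_deferred_division(Q, K, V, block_size=64):
    """Online softmax over K/V blocks: the only special-function op in the
    inner loop is exp(); the division by the row sum happens once at the end."""
    n, d = Q.shape
    O = np.zeros((n, d))              # unnormalized output accumulator
    row_max = np.full(n, -np.inf)     # running row-wise max of scores
    row_sum = np.zeros(n)             # running sum of exp(scores - max)

    for start in range(0, K.shape[0], block_size):
        Kb, Vb = K[start:start + block_size], V[start:start + block_size]
        S = Q @ Kb.T                                    # GEMM0
        new_max = np.maximum(row_max, S.max(axis=1))
        scale = np.exp(row_max - new_max)               # rescale old accumulators
        P = np.exp(S - new_max[:, None])                # exponential (MUFU)
        row_sum = row_sum * scale + P.sum(axis=1)
        O = O * scale[:, None] + P @ Vb                 # GEMM1
        row_max = new_max

    return O / row_sum[:, None]       # single division, done at the very end


# Quick check against a naive softmax(Q K^T) V reference
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 128)) for _ in range(3))
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(attention_deferred_division(Q, K, V), ref)
```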