Dao-AILab / flash-attention

Fast and memory-efficient exact attention

FP8 for flash attention 3 and possible concerns #1169

Open TheTinyTeddy opened 2 months ago

TheTinyTeddy commented 2 months ago

Thank you for this amazing work!

I was wondering whether the fp8 implementation of flash attention 3 will be available for public use? My main concerns are accuracy (block quantization may have alleviated this issue) and the performance impact of the quantization scaling along the pipeline input fp16 -> MMA fp8 -> softmax fp32 -> MMA fp8 -> output fp16, as well as the wait time between OPs.
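To make the pipeline I have in mind concrete, here is a rough PyTorch emulation of the numerics (my own sketch with per-tensor scaling; the helper names are made up and this is not code from this repo):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3

def quantize_fp8(x):
    """Per-tensor symmetric scaling into float8_e4m3fn; returns the fp8 tensor and its descale factor."""
    amax = x.float().abs().amax().clamp(min=1e-12)
    scale = FP8_MAX / amax
    return (x.float() * scale).to(torch.float8_e4m3fn), 1.0 / scale

def fp8_attention_reference(q, k, v):
    """fp16 in -> fp8 MMA (QK^T) -> fp32 softmax -> fp8 MMA (PV) -> fp16 out.
    Layout assumed (batch, heads, seqlen, headdim). The fp8 matmuls are emulated
    by upcasting, so this only mimics the quantization error, not the speed."""
    softmax_scale = q.shape[-1] ** -0.5
    q8, dq = quantize_fp8(q)
    k8, dk = quantize_fp8(k)
    s = (q8.float() @ k8.float().transpose(-2, -1)) * (dq * dk * softmax_scale)
    p = torch.softmax(s, dim=-1)               # softmax kept in fp32
    p8, dp = quantize_fp8(p)
    v8, dv = quantize_fp8(v)
    o = (p8.float() @ v8.float()) * (dp * dv)  # second fp8 MMA, then descale
    return o.to(q.dtype)                        # back to fp16
```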

I am eager to hear your proposed solution.

Many thanks

tridao commented 2 months ago

FA3 FP8 code is already public in this repo. Accuracy is an open problem; I don't think the community has a consensus on the best way to quantize. One can also study that independently of FA3: just do standard attention with quantization and measure accuracy.
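For example, something along these lines (just a sketch of that experiment, using PyTorch's SDPA as the baseline and per-tensor scaling as one arbitrary choice of quantization):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads, seqlen, headdim) for scaled_dot_product_attention
q, k, v = (torch.randn(2, 8, 512, 64, dtype=torch.float16) for _ in range(3))

def fake_quant_fp8(x):
    # Round-trip through e4m3 with a per-tensor scale: keeps the dtype fp16 but
    # injects the quantization error the fp8 matmuls would see.
    scale = torch.finfo(torch.float8_e4m3fn).max / x.float().abs().amax()
    return ((x.float() * scale).to(torch.float8_e4m3fn).float() / scale).to(x.dtype)

out_ref = F.scaled_dot_product_attention(q.float(), k.float(), v.float())
out_q = F.scaled_dot_product_attention(*(fake_quant_fp8(t).float() for t in (q, k, v)))
rel_err = (out_q - out_ref).abs().max() / out_ref.abs().max()
print(f"max relative error with fp8-quantized q/k/v: {rel_err.item():.2e}")
```

Swapping the per-tensor scale for per-block or per-head scaling is the kind of comparison one could run there.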

What's the "wait time between OPs"?

TheTinyTeddy commented 2 months ago

> FA3 FP8 code is already public in this repo. Accuracy is an open problem; I don't think the community has a consensus on the best way to quantize. One can also study that independently of FA3: just do standard attention with quantization and measure accuracy.
>
> What's the "wait time between OPs"?

Thank you for the reply. I can see the fp8 implementation now. I have a few questions, if you don't mind answering them.

1) I can see that if the input is fp8 then flash_attn_func will produce a result (albeit with lower accuracy), but if the input is fp16, is it possible to do online scaling to fp8 when needed? Do you use any quantization scaling calls when the input is fp16? (A rough sketch of what I mean by online scaling is after question 2.)

2) By "wait time between OPs" I mean that when the tensor cores perform fp8 MMA, it finishes much faster than the vector OPs (mainly the exp, performed in fp32). Did you use a larger block size, or can the multi-stage pipeline (such as the 3-stage method) alleviate this fp8 pipeline latency issue (if the issue exists at all; I don't know whether there are any profiling experiments)?
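For (1), this is roughly what I had in mind (my own sketch; I'm assuming the hopper build's flash_attn_func accepts e4m3 inputs, and I don't know how, or whether, the descale factors are supposed to be passed in the current interface):

```python
import torch
from flash_attn_interface import flash_attn_func  # hopper/ build of this repo

def to_fp8_with_descale(x):
    # Hypothetical helper: online per-tensor scaling of an fp16 tensor into e4m3,
    # returning the fp8 tensor plus the descale factor a kernel would need.
    amax = x.float().abs().amax().clamp(min=1e-12)
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    return (x.float() * scale).to(torch.float8_e4m3fn), 1.0 / scale

# (batch, seqlen, heads, headdim), the layout flash_attn_func expects
q16, k16, v16 = (torch.randn(2, 512, 8, 64, dtype=torch.float16, device="cuda")
                 for _ in range(3))
(q8, dq), (k8, dk), (v8, dv) = (to_fp8_with_descale(t) for t in (q16, k16, v16))

# Assumption: fp8 q/k/v can be passed directly; whether and how the descale
# factors (dq, dk, dv) go into the call depends on the interface version,
# and the return may also include the softmax LSE.
out = flash_attn_func(q8, k8, v8, causal=False)
```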

Many thanks

tridao commented 2 months ago

You can try out the online scaling you suggest (input in fp16 but cast to fp8 for the matmuls) and measure accuracy. This can be done independently of FA3. I don't think we've tried that.

TheTinyTeddy commented 2 months ago

> You can try out the online scaling you suggest (input in fp16 but cast to fp8 for the matmuls) and measure accuracy. This can be done independently of FA3. I don't think we've tried that.

I see. So you didn't include any quantization scaling within the fp8 flash attention 3 kernel?

yatorho commented 3 weeks ago

Yes, I have the same question. I browsed the code under the hopper folder in the repository but did not find anything related to block quantization or the incoherent processing technique. Could the authors give more information? Thank you very much.
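For reference, by incoherent processing I mean something like the following (my own rough sketch of the idea from the paper, not code from this repo; the paper uses a random-sign Hadamard transform for speed, the dense random orthogonal matrix here is just to show the invariance):

```python
import torch

def incoherent_process(q, k):
    # Multiply Q and K by the same random orthogonal matrix M. Since
    # (Q M)(K M)^T = Q K^T, the attention scores are unchanged in exact
    # arithmetic, but outliers get spread out, which should make the
    # subsequent fp8 quantization of Q and K less lossy.
    d = q.shape[-1]
    m, _ = torch.linalg.qr(torch.randn(d, d))
    return (q.float() @ m).to(q.dtype), (k.float() @ m).to(k.dtype)

q = torch.randn(2, 8, 512, 64, dtype=torch.float16)
k = torch.randn(2, 8, 512, 64, dtype=torch.float16)
q_rot, k_rot = incoherent_process(q, k)  # then quantize q_rot / k_rot to fp8 as usual
```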

Nalilik commented 1 week ago

Same question too... looking forward to your reply. Thanks.