Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

INT8 versions of FMHA and Flash-Attention (Forward) #122

Open jundaf2 opened 1 year ago

jundaf2 commented 1 year ago

Hi @tridao, we recently implemented INT8 forward FMHA (8-bit Flash-Attention) with both static and dynamic quantization for Softmax on our GPGPU card, and achieved good results with acceptable accuracy. For confidentiality reasons we cannot open-source our CPP/CUDA code (we use our own hardware and compilers that are compatible with CUDA), but the basic method is laid out in the equations and Python simulations in this GitHub repo: https://github.com/jundaf2/INT8-Flash-Attention-FMHA-Quantization. Anyone familiar with GPU programming can write the CUDA code by following these equations, even as a beginner in 8-bit quantization. The drawback is that it requires several additional parameters for quantization and de-quantization, but carrying extra scale parameters seems to be common practice in the INT8 community.
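For anyone who wants to experiment with the idea before reading the repo, below is a minimal NumPy simulation of the general technique: symmetric per-tensor INT8 quantization of Q, K, V and the softmax probabilities P, with explicit scale parameters for quantization and de-quantization. This is only an illustrative sketch, not the kernel described above; the function names, the dynamically computed per-tensor scales, and the static 1/127 scale for P are assumptions, not taken from the linked repo.

```python
# Minimal NumPy sketch of INT8 attention with explicit (de)quantization scales.
# Illustrates the general technique (symmetric per-tensor INT8 quantization),
# NOT the authors' kernel; names and scale choices are assumptions.
import numpy as np

def quantize_sym(x, scale):
    """Symmetric INT8 quantization: real value x ~= scale * int8."""
    return np.clip(np.rint(x / scale), -127, 127).astype(np.int8)

def int8_attention_forward(Q, K, V, sm_scale):
    # Per-tensor scales, computed dynamically here; a static variant would
    # plug in calibrated constants instead.
    q_s = np.abs(Q).max() / 127.0
    k_s = np.abs(K).max() / 127.0
    v_s = np.abs(V).max() / 127.0
    Qq, Kq, Vq = quantize_sym(Q, q_s), quantize_sym(K, k_s), quantize_sym(V, v_s)

    # INT8 x INT8 -> INT32 matmul (what the tensor cores would do),
    # then dequantize with the product of the input scales.
    S = (Qq.astype(np.int32) @ Kq.astype(np.int32).T).astype(np.float32)
    S *= q_s * k_s * sm_scale

    # Softmax in higher precision, then requantize P. Since P lies in [0, 1],
    # a static scale of 1/127 covers the range (an assumption, not the repo's).
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    p_s = 1.0 / 127.0
    Pq = quantize_sym(P, p_s)

    # Second INT8 matmul, dequantized with p_s * v_s.
    O = Pq.astype(np.int32) @ Vq.astype(np.int32)
    return O.astype(np.float32) * (p_s * v_s)

# Example usage: rough comparison against an FP32 reference.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((64, 64), dtype=np.float32) for _ in range(3))
sm = 1.0 / np.sqrt(64)
P_ref = np.exp((Q @ K.T) * sm)
P_ref /= P_ref.sum(axis=-1, keepdims=True)
O_ref = P_ref @ V
O_q = int8_attention_forward(Q, K, V, sm)
print("max abs error:", np.abs(O_q - O_ref).max())
```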

Update on June 20th, 2023: we recently rewrote the 8-bit flash attention using Nvidia CUDA based on the WMMA API and updated the original repo. Further performance optimizations will come in the next phase. Sorry for the delay. Thanks.

tridao commented 1 year ago

This is such great work! Thanks @jundaf2 for writing it up and running the simulation!

lingffff commented 1 year ago

Any plans to open-source the INT8 forward FMHA (8-bit Flash-Attention)?

jundaf2 commented 1 year ago

We implemented it using our own instruction set and built-in functions; you can easily write one for an Nvidia GPU if you know how to quantize and dequantize the data. Thanks!

jundaf2 commented 1 year ago

Hi, we recently rewrote the 8-bit flash attention using Nvidia CUDA based on the WMMA API and updated the original repo. Further performance optimizations will come in the next phase.

goodluckcwl commented 1 year ago

Have you tried 4-bit flash attention using Nvidia CUDA based on the WMMA API?

jundaf2 commented 1 year ago

Hi @goodluckcwl. Using 4-bit precision is a good idea, but I haven't tried it. I didn't have hardware that supports 4-bit tensor core operations at the time I worked on flash-attention inference~

manupak commented 8 months ago

Hi @jundaf2,

It's cool to see that an INT8 version is (/was?) being worked on. I'm curious how you obtain the scales for the "P" tensor, if that's something you can share, since it represents a distribution within softmax.
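For reference, one common way to pick scales for P (not necessarily what the linked repo does) relies on softmax outputs being bounded in [0, 1], so a static scale of 1/127 covers the full range without any calibration; a dynamic alternative takes a per-row maximum and keeps one scale per row. A minimal sketch under those assumptions:

```python
# Two illustrative ways to pick a quantization scale for P = softmax(S):
# static (the output range is known a priori) vs. dynamic (per-row amax).
# This is an assumption-laden sketch, not the linked repo's calibration.
import numpy as np

def softmax(S):
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return P / P.sum(axis=-1, keepdims=True)

def quantize_p_static(P):
    # Softmax outputs always lie in [0, 1], so a fixed scale of 1/127
    # covers the whole range without calibration.
    scale = 1.0 / 127.0
    return np.clip(np.rint(P / scale), 0, 127).astype(np.int8), scale

def quantize_p_dynamic(P):
    # Per-row scale from the row maximum; rows dominated by small
    # probabilities keep more resolution, at the cost of one scale per row.
    scale = P.max(axis=-1, keepdims=True) / 127.0
    return np.clip(np.rint(P / scale), 0, 127).astype(np.int8), scale

# Example usage: compare reconstruction error of the two choices.
S = np.random.default_rng(1).standard_normal((4, 8)).astype(np.float32)
P = softmax(S)
Pq_s, s_s = quantize_p_static(P)
Pq_d, s_d = quantize_p_dynamic(P)
print("static error:", np.abs(Pq_s * s_s - P).max())
print("dynamic error:", np.abs(Pq_d * s_d - P).max())
```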