Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

[Questions] Triton blocksparse flashattention & quantization #82

Open fxmarty opened 1 year ago

fxmarty commented 1 year ago

Hello @tridao ,

Congrats on the work on FlashAttention. It already seems to be having a huge impact, being integrated in kernl, PyTorch nightlies' BetterTransformer, and I'm sure others.

I had two questions:

  • Are you aware of an implementation of blocksparse flashattention in OpenAI Triton? Are there any benchmarks available of possible speedups / loss in prediction quality depending on the sparsity ratio of the mask matrices? Currently, as pointed out in another issue, it seems like DeepSpeed's FixedSparsityConfig is used.

  • Is there any effort to implement flashattention / blocksparse flashattention with integer arithmetic (e.g. int8 GEMM)? Do you think it could be worthwhile throughput-wise?

Pinging @pommedeterresautee in case you work on these topics or have some insights!

Thanks a lot!

tridao commented 1 year ago

Congrats on the work on FlashAttention. It already seems to be having a huge impact, being integrated in kernl, PyTorch nightlies' BetterTransformer, and I'm sure others.

Thanks for the kind words, we've been very happy to see FlashAttention being used in many places.

I had two questions:

  • Are you aware of an implementation of blocksparse flashattention in OpenAI Triton? Are there any benchmarks available of possible speedups / loss in prediction quality depending on the sparsity ratio of the mask matrices? Currently, as pointed out in another issue, it seems like DeepSpeed's FixedSparsityConfig is used.

I'm not aware of a blocksparse FlashAttention implementation in Triton, but that seems like a good idea! Triton has blocksparse matrix multiply implemented.

In terms of speedup, we observed speedup roughly inversely proportional to the density (e.g. if 20% of the blocks are nonzero, then attention goes about 5x faster).
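For reference, here is a minimal PyTorch sketch of block-sparse attention semantics (not the fused kernel): the block layout is expanded to a dense mask and masked positions are excluded from the softmax. The function name, block size, and density below are illustrative assumptions; the fused block-sparse kernel instead skips the zero blocks entirely, which is why runtime scales with block density.

```python
import math
import torch

def blocksparse_attention_reference(q, k, v, block_layout, block_size):
    # q, k, v: (seqlen, head_dim); block_layout: (seqlen // block_size)^2 bool
    seqlen, head_dim = q.shape
    # Expand the (n_blk, n_blk) boolean layout to a (seqlen, seqlen) mask.
    mask = block_layout.repeat_interleave(block_size, dim=0) \
                       .repeat_interleave(block_size, dim=1)
    scores = (q @ k.T) / math.sqrt(head_dim)
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

seqlen, head_dim, block_size = 256, 64, 64
q, k, v = (torch.randn(seqlen, head_dim) for _ in range(3))
# ~20% of the blocks nonzero; keep the diagonal so every query attends to something.
layout = torch.rand(seqlen // block_size, seqlen // block_size) < 0.2
layout.fill_diagonal_(True)
out = blocksparse_attention_reference(q, k, v, layout, block_size)
```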

In terms of quality, it's hard to say. For simpler tasks (e.g. Long Range Arena), block-sparse attention seems to do about as well as dense attention. For language modeling, the GPT-3 paper says they alternate dense and block-sparse attention, and I think GPT-J alternates dense and local (a form of sparsity) attention. However, in general sparse attention hasn't been as widely used as dense attention. Maybe really long sequences could be a good use case?

  • Is there any effort to implement flashattention / blocksparse flashattention with integer arithmetic (e.g. int8 GEMM)? Do you think it could be worthwhile throughput-wise?

This is a good idea! I think it would make attention go twice as fast, since you would only need to load half as many bytes (from both global memory and shared memory).
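As a hedged sketch of the int8 idea for the first GEMM (S = QK^T): quantize Q and K to int8 with symmetric per-tensor scales, do an integer GEMM, and dequantize the scores. The helper `quantize_int8` is illustrative, and the GEMM is emulated with wider tensors here; a real kernel would use int8 inputs with int32 accumulators, and the halved byte traffic is where the hoped-for ~2x speedup comes from.

```python
import torch

def quantize_int8(x):
    # Symmetric per-tensor quantization to the int8 range.
    scale = x.abs().max() / 127.0
    return torch.clamp((x / scale).round(), -127, 127).to(torch.int8), scale

q = torch.randn(128, 64)
k = torch.randn(128, 64)

q_i8, q_scale = quantize_int8(q)
k_i8, k_scale = quantize_int8(k)

# Integer GEMM, emulated by casting up; a real int8 kernel accumulates in int32.
s_int = q_i8.to(torch.int64) @ k_i8.to(torch.int64).T
s = s_int.to(torch.float32) * (q_scale * k_scale)

print((s - q @ k.T).abs().max())  # quantization error of the scores
```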

kurumuz commented 1 year ago

For language modeling, the GPT-3 paper says they alternate dense and block-sparse attention, and I think GPT-J alternates dense and local (a form of sparsity) attention

GPT-J 6B uses dense attention only. GPT-Neo used alternating local attention every other layer, but perhaps we can just get rid of attention every other layer, or do very tiny attention, and get similar results.

tridao commented 1 year ago

Ha, interesting idea! The overhead of creating the sparse mask should be pretty small, so we can probably change the mask between iterations.
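To illustrate how cheap that is: regenerating the block layout is a single boolean tensor op per step, negligible next to the attention compute. `random_block_layout` and the commented-out attention call below are hypothetical placeholders, not an existing API.

```python
import torch

def random_block_layout(n_blocks, density=0.2):
    # Keep roughly `density` of the blocks; always keep the diagonal blocks.
    layout = torch.rand(n_blocks, n_blocks) < density
    layout.fill_diagonal_(True)
    return layout

for step in range(3):  # stand-in for a training loop
    layout = random_block_layout(n_blocks=32)
    # out = blocksparse_attention(q, k, v, layout)  # placeholder call
```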

pzheng16 commented 1 year ago

Is there any effort to implement flashattention / blocksparse flashattention with integer arithmetic (e.g. int8 GEMM)? Do you think it could be worthwhile throughput-wise?

This is a good idea! I think it would make attention go twice as fast, since you would only need to load half as many bytes (from both global memory and shared memory).

How should we maintain numerical stability while quantizing and dequantizing P in each step of flash attention? Normally, the softmax results would be quantized for the second GEMM. Using fp16 for the second GEMM might be an option, but I am unsure whether there is a way to keep int8 GEMMs for both QK and PV.

tridao commented 1 year ago

How should we maintain numerical stability while quantizing and dequantizing P in each step of flash attention? Normally, the softmax results would be quantized for the second GEMM. Using fp16 for the second GEMM might be an option, but I am unsure whether there is a way to keep int8 GEMMs for both QK and PV.

I haven't thought too much about this, but it's a general problem even if you implement it in PyTorch (you would have to decide how to quantize P). For training, I think fp16 could possibly work (I don't know if Nvidia's Transformer Engine converts P to fp8 on H100). For inference, int8 could work, as done in FasterTransformer I think. In the end one would probably have to try it out to see.
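One possible option, sketched below under the assumption of static quantization: the softmax probabilities (and the rescaled exponentials exp(S - m) inside FlashAttention) lie in [0, 1], so P can use a fixed scale of 1/127 for the second GEMM, while V uses a symmetric per-tensor scale. The helper names are illustrative, the GEMM is emulated with wider tensors, and this is not the FlashAttention kernel itself.

```python
import torch

def quantize_static_unit_interval(p):
    # p in [0, 1] -> int8 in [0, 127] with a fixed scale of 1/127.
    return (p * 127.0).round().to(torch.int8), 1.0 / 127.0

def quantize_symmetric(x):
    scale = x.abs().max() / 127.0
    return torch.clamp((x / scale).round(), -127, 127).to(torch.int8), scale

p = torch.softmax(torch.randn(128, 128), dim=-1)  # attention probabilities
v = torch.randn(128, 64)

p_i8, p_scale = quantize_static_unit_interval(p)
v_i8, v_scale = quantize_symmetric(v)

# Integer GEMM, emulated by casting up; a real int8 kernel accumulates in int32.
o_int = p_i8.to(torch.int64) @ v_i8.to(torch.int64)
o = o_int.to(torch.float32) * (p_scale * v_scale)

print((o - p @ v).abs().max())  # error from quantizing P and V
```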

jundaf2 commented 1 year ago

Hi @tridao, @pzheng16, and @fxmarty, we recently implemented an INT8 forward FMHA (8-bit Flash-Attention) with both static and dynamic quantization for the softmax on our GPGPU card, and achieved good results and reasonably good accuracy. For confidentiality reasons we cannot open-source our C++/CUDA code, but the basic method is shown in the equations and Python simulations in this GitHub repo: https://github.com/jundaf2/INT8-Flash-Attention-FMHA-Quantization

fxmarty commented 1 year ago

Super cool, thank you for sharing!