Would you be interested in considering INT8 attention too? #952 (https://github.com/INT-FlashAttention2024/INT-FlashAttention)
There are also other triton/cuda kernels for int8 attention floating around, but I haven't looked into them closely.
@gau-nernst Still working through this RFC; it's not nearly complete yet, but yes, I'm going to add a section on int8 attention.
Current State of OSS FP8 Operators
So far, all examples of fp8 ops (compute in fp8) are scaled matmuls that accumulate in a higher precision type. In fact, there are really only 2 classes of instructions that are supported in PTX:
The complexity of FP8 training is that we need to efficiently calculate scales that align the current distribution of values in a high-precision tensor with what is representable in fp8.
This is easier for inference because the weights are frozen and the scales can be pre-calculated.
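To make the scaling concrete, here is a minimal tensor-wise sketch: compute the amax of the high-precision tensor, map it onto the fp8 range, then hand the inverse scales to the scaled matmul so it can accumulate and de-quantize into a higher-precision type. Everything below is illustrative, not canonical: it assumes a CUDA device with fp8 support (e.g. H100) and a recent PyTorch, and `torch._scaled_mm` is a private op whose exact signature has moved around between releases.

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def to_fp8_tensorwise(x: torch.Tensor):
    # Scale maps the tensor's current amax onto the fp8 representable range.
    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = E4M3_MAX / amax
    x_fp8 = (x.float() * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale

a = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
b = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)

a_fp8, a_scale = to_fp8_tensorwise(a)
b_fp8, b_scale = to_fp8_tensorwise(b.t())  # second operand must be column-major

# fp8 compute, higher-precision accumulate/output; the reciprocal scales
# de-quantize the accumulated product back to the original range.
out = torch._scaled_mm(
    a_fp8,
    b_fp8.t(),
    scale_a=a_scale.reciprocal(),
    scale_b=b_scale.reciprocal(),
    out_dtype=torch.bfloat16,
)
```

For inference, the weight's scale can be computed once offline; for training, that amax reduction has to happen on the fly for every tensor, which is where the efficiency concern comes from.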
Inference
Before we can walk, we must crawl. Let's look at what's available for inference, which is a strictly easier problem.
All of the kernels below use TensorWise scaling.
Kernels
1. FAv3
2. FlashInfer
Prefill
BatchedPrefill with KVCache
Decode
TLDR: Uses a neat strategy for fusing scaling into existing kernels.
3. vLLM
4. FlexAttention
Note: This currently fails since the input is expected to be on the host, but we can fix that or use score_mod (fixing is better; see the score_mod sketch after this list).
This is also idealized, since it does not account for casting overhead or the epilogue kernel.
5. Transformer Engine
6. TensorRT
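The FlexAttention note above mentions score_mod as an escape hatch; the sketch below shows what that could look like for tensor-wise scaling. It is hypothetical: the shapes and scale values are made up, and the inputs stay in bf16 purely so the snippet runs on stock PyTorch, while the fp8 variant is what the experiments below are after.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 8, 1024, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

q_scale = torch.tensor(0.023, device="cuda")  # stand-in per-tensor dequant scales
k_scale = torch.tensor(0.017, device="cuda")

def dequant_score_mod(score, b, h, q_idx, kv_idx):
    # q @ k was computed on quantized values, so rescale each score by the
    # per-tensor dequant scales to recover the high-precision dot product.
    return score * q_scale * k_scale

out = flex_attention(q, k, v, score_mod=dequant_score_mod)
# The V dequant scale still has to be applied to `out` afterwards, which is
# part of the casting/epilogue overhead noted above.
```

Since the per-tensor scales are constants, they could equivalently be folded into the softmax scale ahead of time rather than applied per score.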
Some Code Runs
Flex Experiments