pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

[RFC] LowBit Fused Attention #1071

Open drisspg opened 1 month ago

drisspg commented 1 month ago

Current State of OSS FP8 Operators

So far, all examples of fp8 ops (i.e., compute in fp8) are scaled matmuls that accumulate in a higher-precision type. In fact, PTX supports only two classes of fp8 instructions:

The main difficulty with FP8 training (inference is somewhat easier) is that we need to efficiently compute scales that map the current distribution of values in a high-precision tensor onto the range representable in fp8.

This is easier for inference because the weight is frozen and we can pre-calculate the scale.

Inference

Before we can walk, we must crawl. Let's look at what's available for inference, which is a strictly easier problem.

All of these use TensorWise (per-tensor) scaling.

Kernels

1. FAv3 (FlashAttention-3)

2. FlashInfer

   - Prefill
   - BatchedPrefill with KVCache
   - Decode

   TL;DR: uses a neat strategy for fusing scaling into existing kernels.

3. vLLM

4. FlexAttention

This is also idealized, since it doesn't account for the casting overhead or the epilogue kernel.

5. Transformer Engine

6. TensorRT

Some Code Runs

Flex Experiments

```python
from functools import partial
import math

import torch
from torch.nn.attention.flex_attention import flex_attention
from triton.testing import do_bench

torch.set_default_device("cuda")
torch.manual_seed(0)
torch._dynamo.config.cache_size_limit = 1000
torch._inductor.config.triton.unique_kernel_names = True


def main(do_fp16: bool = False, max_autotune: bool = False):
    try:
        from torchao.float8.float8_tensor import Float8Tensor, hp_tensor_and_scale_to_float8
        from torchao.float8.float8_utils import tensor_to_scale
    except ImportError:
        raise ImportError("Fp8 example needs torchao to run!")

    data_type = torch.float16
    make_tensor = partial(torch.rand, device="cuda", dtype=data_type)
    input_size = (4, 16, 8192, 128)  # (batch, heads, seq_len, head_dim)
    q, k, v = make_tensor(input_size), make_tensor(input_size), make_tensor(input_size)

    if max_autotune:
        flex = torch.compile(
            flex_attention, dynamic=True, mode="max-autotune-no-cudagraphs"
        )
    else:
        flex = torch.compile(flex_attention, dynamic=False)

    if do_fp16:
        float16_time = do_bench(lambda: flex(q, k, v))
        print(f"{float16_time=}")

    # Maximal perf time: quantization happens outside the timed region,
    # so casting overhead is not included in fp8_time
    q_fp8 = hp_tensor_and_scale_to_float8(q, tensor_to_scale(q, torch.float8_e4m3fn))
    k_fp8 = hp_tensor_and_scale_to_float8(k, tensor_to_scale(k, torch.float8_e4m3fn))
    v_fp8 = hp_tensor_and_scale_to_float8(v, tensor_to_scale(v, torch.float8_e4m3fn))
    # Default softmax scale is 1/sqrt(head_dim); head_dim is 128 here
    sm_scale = 1.0 / math.sqrt(q.size(-1))
    # Fold the Q and K dequantization factors (data / scale) into the softmax scale
    sm_scale *= q_fp8._scale.reciprocal() * k_fp8._scale.reciprocal()
    # Work around for now: pass a Python float rather than a tensor
    sm_scale = sm_scale.item()
    q_fp8_data = q_fp8._data
    k_fp8_data = k_fp8._data
    v_fp8_data = v_fp8._data

    # Warm-up call so compilation is not included in the benchmark
    flex(q_fp8_data, k_fp8_data, v_fp8_data, scale=sm_scale)
    fp8_time = do_bench(lambda: flex(q_fp8_data, k_fp8_data, v_fp8_data, scale=sm_scale))
    print(f"{fp8_time=}")
    # Note: the true output still needs dividing by v_fp8._scale
    # (an epilogue that is not timed here)


if __name__ == "__main__":
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .'[viz]'")
    CLI(main)
```

gau-nernst commented 1 month ago

Would you be interested in considering INT8 attention too? See #952 (https://github.com/INT-FlashAttention2024/INT-FlashAttention)

There are also other triton/cuda kernels for int8 attention floating around but I haven't looked into them closely.

drisspg commented 1 month ago

@gau-nernst I'm still working through this RFC and it's not nearly complete yet, but yes, I'm going to add a section on int8 attention.