turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

Triton-based flash attention 2 that supports Volta and up. #411

Closed Ph0rk0z closed 2 weeks ago

Ph0rk0z commented 2 months ago

So vllm has a triton based flash attention: https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_flash_attention.py

Supposedly it supports Volta and up. I'm not sure whether it's slower or faster; it requires a lot more arguments, or a wrapper around Torch attention. Has anyone gotten it working?
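
For context, the "wrap around Torch attention" route would look roughly like this: a minimal sketch using PyTorch's built-in scaled_dot_product_attention, with illustrative shapes, not exllamav2's actual code path. SDPA runs on pre-Ampere GPUs and picks a fused kernel when it can.

import torch
import torch.nn.functional as F

# Illustrative shapes only: (batch, heads, seq, head_dim).
bsz, n_heads, q_len, head_dim = 1, 32, 2048, 128
q = torch.randn(bsz, n_heads, q_len, head_dim, dtype=torch.half, device="cuda")
k = torch.randn(bsz, n_heads, q_len, head_dim, dtype=torch.half, device="cuda")
v = torch.randn(bsz, n_heads, q_len, head_dim, dtype=torch.half, device="cuda")

# is_causal=True applies the standard causal mask without materializing it;
# PyTorch dispatches to whichever fused backend the GPU supports.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)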

laoda513 commented 1 month ago

xformers also has an implementation called xops.memory_efficient_attention. The related code is commented out in exl2, not sure why, but it doesn't seem too hard to get it working:

B, M, H, K = q_states.shape
q = q_states.reshape([B, M, num_key_value_heads, num_key_value_groups, K])
k = k_states.reshape([B, k_states.shape[1], num_key_value_heads, 1, K]).expand([B, k_states.shape[1], num_key_value_heads, num_key_value_groups, K])
v = v_states.reshape([B, v_states.shape[1], num_key_value_heads, 1, K]).expand([B, k_states.shape[1], num_key_value_heads, num_key_value_groups, K])

k_states = None
q_states = None
v_states = None

attn_mask = attn_params.get_attn_mask(hidden_states.device)
print(attn_mask)

if attn_mask is None or attn_mask[0, 0, 0, 1] == 1:
    attn_output = xops.memory_efficient_attention(q, k, v)
else:
    attn_output = xops.memory_efficient_attention(q, k, v, attn_bias = xops.LowerTriangularMask())

attn_output = attn_output.reshape((batch_size, q_len, hidden_size))

Ph0rk0z commented 1 month ago

I got xformers to work but it would generate gibberish. Was faster though. I don't think I was reshaping it right.

turboderp commented 1 month ago

The main reason I moved away from xformers was that the causal attention mask was strictly top-left aligned. This was addressed in flash-attn, allowing causal attention when 1 < q_len < k_len, which is essential for a number of features like cache reuse and speculative decoding.

It might be time to give xformers another look though, see if they've caught up.
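
To make the alignment difference concrete, here is a small torch-only illustration (toy shapes, not library code) of the two mask conventions for q_len = 3, k_len = 6:

import torch

q_len, k_len = 3, 6

# Top-left aligned causal mask (the old xops.LowerTriangularMask behaviour):
# query row i may only attend to key columns <= i, even though the queries
# are really the last q_len positions of the sequence.
top_left = torch.ones(q_len, k_len, dtype=torch.int).tril(diagonal=0)

# Bottom-right aligned causal mask (flash-attn causal=True, or
# LowerTriangularFromBottomRightMask): query row i attends to all keys up to
# its true position i + (k_len - q_len), which is what cache reuse and
# speculative decoding need when 1 < q_len < k_len.
bottom_right = torch.ones(q_len, k_len, dtype=torch.int).tril(diagonal=k_len - q_len)

print(top_left)
print(bottom_right)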

laoda513 commented 1 month ago

@turboderp do you mean this one?

xops.fmha.LowerTriangularFromBottomRightMask()

laoda513 commented 1 month ago

I guess I got this working by using xops.fmha.LowerTriangularFromBottomRightMask(). In a quick, rough test with a very long context (100k) it seems 25~30% faster than Torch matmul attention on my 4x 2080 Ti server. I didn't see memory usage drop, though... not sure why? Isn't it supposed to be more memory efficient, given the name "memory_efficient_attention"?

Shall I make a PR for this?

turboderp commented 1 month ago

I know there are some traps with SDPA at least, where certain types of masking cause a fallback to matmul attention, and it's possible that's what you're seeing with xformers? Even if the implementation is faster because some operations end up getting fused, that doesn't mean it's actually memory efficient.
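
One way to surface that trap directly (a probe only, not exllamav2 code; assumes PyTorch 2.3+ for torch.nn.attention.sdpa_kernel and a GPU where the flash backend is available): restrict SDPA to the flash backend and pass an explicit mask, and what would otherwise be a silent fallback to another backend becomes a visible error.

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(1, 8, 1024, 64, dtype=torch.half, device="cuda")
k = torch.randn(1, 8, 1024, 64, dtype=torch.half, device="cuda")
v = torch.randn(1, 8, 1024, 64, dtype=torch.half, device="cuda")

# With is_causal=True the fused flash backend is eligible.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    F.scaled_dot_product_attention(q, k, v, is_causal=True)

# An explicit boolean mask is not supported by that backend, so with the other
# backends disabled the dispatch fails instead of silently falling back.
mask = torch.ones(1024, 1024, dtype=torch.bool, device="cuda").tril()
try:
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
except RuntimeError as e:
    print("flash backend refused the mask:", e)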

Another possibility is that chunking by default already prevents matmul attention from scaling too much. I.e. with a chunk size of 2048 and a max attn size of 2048^2, matmul attention will never use more memory than it uses for a 2048-token context. (It will become progressively slower instead, as chunks get shorter and shorter to keep q_len * k_len < 2048^2.)
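
A rough sketch of that arithmetic (illustrative only, not exllamav2's actual chunking code; total_len is just an example value):

# Keep q_len * k_len <= max_attn_size, where k_len = past_len + q_len, by
# shrinking the chunk as the cached context grows.
chunk_size = 2048
max_attn_size = 2048 ** 2

past_len = 0
total_len = 16384
while past_len < total_len:
    remaining = total_len - past_len
    q_len = min(chunk_size, remaining)
    while q_len > 1 and q_len * (past_len + q_len) > max_attn_size:
        q_len -= 1
    print(f"past {past_len:6d}  ->  chunk of {q_len:5d} tokens "
          f"(attn size {q_len * (past_len + q_len):>10,})")
    past_len += q_len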

I made this little test script to measure the actual memory usage of flash-attn and confirm that it scales with q and not k (usage in bytes is 4q^2 + 8192q + C). I guess you could plug in xformers to compare.

import torch
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
from flash_attn import flash_attn_func

nvmlInit()
nvml_handle = nvmlDeviceGetHandleByIndex(0)

bsz = 1
num_attention_heads = 64
num_key_value_heads = 8
head_dim = 128
max_ctx = 65536

q_states_f = torch.randn((bsz, max_ctx, num_attention_heads, head_dim), dtype = torch.half, device = "cuda:0")
k_states_f = torch.randn((bsz, max_ctx, num_key_value_heads, head_dim), dtype = torch.half, device = "cuda:0")
v_states_f = torch.randn((bsz, max_ctx, num_key_value_heads, head_dim), dtype = torch.half, device = "cuda:0")

info = nvmlDeviceGetMemoryInfo(nvml_handle)
max_mem = info.used
print(f"max_mem: {max_mem:,} bytes")

for past in range(0, 32768+1, 2048):
    for q in range(2048, 32768+1, 2048):
        k = q + past
        q_states = q_states_f[:, :q, :, :]
        k_states = k_states_f[:, :k, :, :]
        v_states = v_states_f[:, :k, :, :]
        attn_output = flash_attn_func(q_states, k_states, v_states, causal = True)
        torch.cuda.synchronize()
        info = nvmlDeviceGetMemoryInfo(nvml_handle)
        max_mem = info.used
        print(f"q x k = {q:6,} x {k:6,}      max_mem: {max_mem:13,} bytes")

A PR would be fine, though I'm not sure exactly when I'll have time to go over it. It would probably need some more integration.

laoda513 commented 1 month ago

I'll give it a try.

laoda513 commented 1 month ago

I have made a PR: https://github.com/turboderp/exllamav2/pull/452.

For SM >= 80, using your script, the memory cost and speed are quite balanced. Unfortunately, memory_efficient_attention does not support broadcasting, so we need to manually expand the tensor, which significantly slows down the speed. I wonder if there is a more efficient way to handle the matrix operations.
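
One possible alternative (untested here, and it depends on the xformers version): memory_efficient_attention also accepts the 5-D [batch, seq, kv_heads, queries_per_kv_head, head_dim] layout used in the earlier comment, where expand() stays a stride-0 view instead of being materialized by a reshape back to 4-D. A rough sketch with made-up shapes:

import torch
import xformers.ops as xops

# Sketch only, not the PR's code: broadcast the KV heads as views rather than
# expanding and reshaping into a full copy.
bsz, seq, head_dim = 1, 4096, 128
num_kv_heads, q_per_kv = 8, 8          # 64 query heads in total

q = torch.randn(bsz, seq, num_kv_heads, q_per_kv, head_dim, dtype=torch.half, device="cuda")
k = torch.randn(bsz, seq, num_kv_heads, 1, head_dim, dtype=torch.half, device="cuda")
v = torch.randn(bsz, seq, num_kv_heads, 1, head_dim, dtype=torch.half, device="cuda")

k = k.expand(bsz, seq, num_kv_heads, q_per_kv, head_dim)   # view, no allocation
v = v.expand(bsz, seq, num_kv_heads, q_per_kv, head_dim)   # view, no allocation

out = xops.memory_efficient_attention(
    q, k, v, attn_bias=xops.fmha.LowerTriangularFromBottomRightMask()
)
out = out.reshape(bsz, seq, num_kv_heads * q_per_kv * head_dim)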

In the PR:

For SM >= 80, we continue using flash_attn.

If flash_attn is not available, the xformers implementation is 30-50% slower than flash_attn but 30-50% faster than Torch's matmul attention at very long context (~100K).

For SM < 80, the xformers implementation reduces memory cost similarly to SM >= 80 (it actually uses even less memory than SM >= 80, but I'm not sure by how much, so I didn't modify the code related to model loading) and is 30-50% faster than Torch's matmul attention.

import torch
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
from flash_attn import flash_attn_func
import xformers.ops as xops
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # "0,1,2,3"

nvmlInit()
nvml_handle = nvmlDeviceGetHandleByIndex(0)
info = nvmlDeviceGetMemoryInfo(nvml_handle)
max_mem = info.used
print(f"max_mem: {max_mem / (1024**2):,.1f} MiB")

bsz = 1
num_attention_heads = 64
num_key_value_heads = 8
head_dim = 128
max_ctx = 65536

q_states_f = torch.randn((bsz, max_ctx, num_attention_heads, head_dim), dtype = torch.half, device = "cuda:0")
k_states_f = torch.randn((bsz, max_ctx, num_key_value_heads, head_dim), dtype = torch.half, device = "cuda:0")
v_states_f = torch.randn((bsz, max_ctx, num_key_value_heads, head_dim), dtype = torch.half, device = "cuda:0")

# memory_efficient_attention doesn't broadcast the KV heads, so expand them to
# the full head count up front (the final reshape materializes the copy).
q_states_x = q_states_f.clone()
k_states_x = k_states_f.reshape(bsz, max_ctx, num_key_value_heads, 1, head_dim).expand(bsz, max_ctx, num_key_value_heads, num_attention_heads // num_key_value_heads, head_dim).reshape(bsz, max_ctx, num_attention_heads, head_dim)
v_states_x = v_states_f.reshape(bsz, max_ctx, num_key_value_heads, 1, head_dim).expand(bsz, max_ctx, num_key_value_heads, num_attention_heads // num_key_value_heads, head_dim).reshape(bsz, max_ctx, num_attention_heads, head_dim)

info = nvmlDeviceGetMemoryInfo(nvml_handle)
max_mem = info.used
print(f"max_mem: {max_mem / (1024**2):,.1f} MiB")

for past in range(0, 32768+1, 2048):
    for q in range(2048, 32768+1, 2048):
        k = q + past
        q_states = q_states_x[:, :q, :, :]
        k_states = k_states_x[:, :k, :, :]
        v_states = v_states_x[:, :k, :, :]
        attn_output = xops.memory_efficient_attention(q_states, k_states, v_states, attn_bias = xops.fmha.LowerTriangularFromBottomRightMask()).reshape(bsz, q, num_attention_heads * head_dim)

        # Uncomment to compare against flash_attn instead:
        # q_states = q_states_f[:, :q, :, :]
        # k_states = k_states_f[:, :k, :, :]
        # v_states = v_states_f[:, :k, :, :]
        # attn_output = flash_attn_func(q_states, k_states, v_states, causal = True).reshape(bsz, q, num_attention_heads * head_dim)

        torch.cuda.synchronize()
        info = nvmlDeviceGetMemoryInfo(nvml_handle)
        max_mem = info.used
        print(f"q x k = {q:6,} x {k:6,}      max_mem: {max_mem / (1024**2):,.1f} MiB")

Beinsezii commented 1 month ago

FWIW, vllm's Triton attention also works on AMD Navi if you edit the autotune configs.

Something to consider, @turboderp, while AMD fumbles the CK-based flash attention on Navi cards.
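
For anyone curious what "editing the autotune configs" means in practice, Triton autotune configs are just a list of triton.Config entries attached to the kernel, and retuning for a different GPU mostly means changing block sizes and num_warps. A generic, minimal sketch with a toy kernel; these names and values are illustrative, not vllm's actual attention kernel or its configs:

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        # Edit these entries (block size, num_warps) to retune for a given GPU.
        triton.Config({"BLOCK": 1024}, num_warps=4),
        triton.Config({"BLOCK": 2048}, num_warps=8),
    ],
    key=["n_elements"],
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

x = torch.randn(10000, device="cuda")
y = torch.randn(10000, device="cuda")
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)
add_kernel[grid](x, y, out, x.numel())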

turboderp commented 1 month ago

I'm slightly allergic to Triton, so I'm not sure. Also, this implementation doesn't seem to support paged attention, which is kind of a shame.

Ph0rk0z commented 1 month ago

xformers kicks ass, but it isn't as fast on longer contexts. It even works on my P100 and allowed me to run a bigger quant of wizard; not having anything like it was a HUGE drawback on non-Ampere cards. I don't know if anyone has gotten it to work on AMD with HIP yet, but it could be an option there.