Closed Ph0rk0z closed 2 weeks ago
xformers also has an implementation, xops.memory_efficient_attention. The related code is commented out in exl2. Not sure why, but it doesn't seem too hard to get it working:
B, M, H, K = q_states.shape
q = q_states.reshape([B, M, num_key_value_heads, num_key_value_groups, K])
k = k_states.reshape([B, k_states.shape[1], num_key_value_heads, 1, K]).expand([B, k_states.shape[1], num_key_value_heads, num_key_value_groups, K])
v = v_states.reshape([B, v_states.shape[1], num_key_value_heads, 1, K]).expand([B, k_states.shape[1], num_key_value_heads, num_key_value_groups, K])
k_states = None
q_states = None
v_states = None
attn_mask = attn_params.get_attn_mask(hidden_states.device)
print(attn_mask)
if attn_mask is None or attn_mask[0,0,0,1]==1:
    attn_output = xops.memory_efficient_attention(q, k, v)
else:
    attn_output = xops.memory_efficient_attention(q, k, v, attn_bias = xops.LowerTriangularMask())
attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
I got xformers to work but it would generate gibberish. Was faster though. I don't think I was reshaping it right.
The main reason I moved away from xformers was that the causal attention mask was strictly top-left aligned. This was addressed in flash-attn, allowing attention when 1 < q_len < k_len, which is essential for a number of features like cache reuse and speculative decoding.
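To make the alignment issue concrete, here is a minimal pure-Python sketch (the helper name causal_mask is made up for illustration, it is not an xformers or flash-attn function). It builds both variants of a causal mask for q_len = 2 and k_len = 5, where the queries are assumed to be the last two tokens of the sequence:

```python
def causal_mask(q_len, k_len, bottom_right):
    """Return a q_len x k_len mask; 1 means query i may attend to key j."""
    # With bottom-right alignment, query i sits at absolute position
    # i + (k_len - q_len), i.e. the queries are the last q_len tokens.
    offset = k_len - q_len if bottom_right else 0
    return [[1 if j <= i + offset else 0 for j in range(k_len)]
            for i in range(q_len)]


top_left = causal_mask(2, 5, bottom_right=False)
bottom_right = causal_mask(2, 5, bottom_right=True)

for name, mask in (("top-left", top_left), ("bottom-right", bottom_right)):
    print(name)
    for row in mask:
        print("  ", row)
```

With top-left alignment the first new query sees only the first cached key, which is wrong when there is a populated cache. With bottom-right alignment each new query sees the whole cache plus the new tokens up to and including itself.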
It might be time to give xformers another look though, see if they've caught up.
@turboderp do you mean this one?
xops.fmha.LowerTriangularFromBottomRightMask()
I think I got this working by using xops.fmha.LowerTriangularFromBottomRightMask(). In a quick, rough test with a very long context (100k), it seems 25~30% faster than Torch matmul attention on my 4x 2080 Ti server, but I didn't see memory usage go down. Not sure why. Isn't "memory_efficient_attention" supposed to be more memory efficient?
Shall I make a PR for this?
I know there are some traps with SDPA at least, where certain types of masking cause a fallback to matmul attention, and it's possible that's what you're seeing with xformers? Even if the implementation is faster because some operations end up getting fused, that doesn't mean it's actually memory efficient.
Another possibility is that chunking by default already prevents matmul attention from scaling too much. I.e. with a chunk size of 2048 and a max attn size of 2048^2, matmul attention will never use more memory than it uses for a 2048-token context. (It will become progressively slower instead, as chunks get shorter and shorter to keep q_len * k_len < 2048^2.)
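As a rough illustration of that constraint, the sketch below computes the largest chunk length q_len that keeps q_len * k_len within the attention-size budget, with k_len = past_len + q_len. This is a simplified model of the rule described above, not exllamav2's actual chunking code, and the function and parameter names are made up:

```python
import math


def max_chunk_len(past_len, chunk_size=2048, max_attn_size=2048 ** 2):
    """Largest q_len <= chunk_size with q_len * (past_len + q_len) <= max_attn_size."""
    # Positive root of q^2 + past_len*q - max_attn_size = 0, rounded down.
    disc = past_len ** 2 + 4 * max_attn_size
    q = (math.isqrt(disc) - past_len) // 2
    # Never return less than one token, never more than the chunk size.
    return max(1, min(chunk_size, q))


for past in (0, 2048, 8192, 32768, 98304):
    q = max_chunk_len(past)
    print(f"past = {past:6,}  q_len = {q:5,}  q_len * k_len = {q * (past + q):,}")
```

Under this model the attention matrix never grows past 2048^2 elements, so matmul attention memory stays flat while the number of chunks (and therefore the time) grows with context length.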
I made this little test script to measure the actual memory usage of flash-attn and confirm that it scales with q and not k (usage in bytes is 4q^2 + 8192q + C). I guess you could plug in xformers to compare.
import torch
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
from flash_attn import flash_attn_func
nvmlInit()
nvml_handle = nvmlDeviceGetHandleByIndex(0)
bsz = 1
num_attention_heads = 64
num_key_value_heads = 8
head_dim = 128
max_ctx = 65536
q_states_f = torch.randn((bsz, max_ctx, num_attention_heads, head_dim), dtype = torch.half, device = "cuda:0")
k_states_f = torch.randn((bsz, max_ctx, num_key_value_heads, head_dim), dtype = torch.half, device = "cuda:0")
v_states_f = torch.randn((bsz, max_ctx, num_key_value_heads, head_dim), dtype = torch.half, device = "cuda:0")
info = nvmlDeviceGetMemoryInfo(nvml_handle)
max_mem = info.used
print(f"max_mem: {max_mem:,} bytes")
for past in range(0, 32768+1, 2048):
    for q in range(2048, 32768+1, 2048):
        k = q + past
        q_states = q_states_f[:, :q, :, :]
        k_states = k_states_f[:, :k, :, :]
        v_states = v_states_f[:, :k, :, :]
        attn_output = flash_attn_func(q_states, k_states, v_states, causal = True)
        torch.cuda.synchronize()
        info = nvmlDeviceGetMemoryInfo(nvml_handle)
        max_mem = info.used
        print(f"q x k = {q:6,} x {k:6,} max_mem: {max_mem:13,} bytes")
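To get a feel for the fitted model quoted above (4q^2 + 8192q + C bytes), here is a small sketch that evaluates it for a few query lengths. The constant C is not given in the thread, so it is left at 0 here and the numbers are only the q-dependent part; predicted_bytes is a made-up helper name:

```python
def predicted_bytes(q, c=0):
    """Evaluate the fitted model 4*q^2 + 8192*q + C (C is unknown, default 0)."""
    return 4 * q * q + 8192 * q + c


for q in (2048, 8192, 32768):
    mib = predicted_bytes(q) / 2 ** 20
    print(f"q = {q:6,}: {mib:8,.0f} MiB on top of the constant C")
```

Note that k does not appear anywhere in the model, which is the point of the measurement: growing the cache alone does not grow the working memory.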
A PR would be fine, though I'm not sure exactly when I'll have time to go over it. It would probably need some more integration.
I'll give it a try.
I have made a PR: https://github.com/turboderp/exllamav2/pull/452.
For SM >= 80, using your script, the memory cost and speed are quite balanced. Unfortunately, memory_efficient_attention does not support broadcasting, so we need to expand the K/V tensors manually, which slows things down significantly. I wonder if there is a more efficient way to handle the matrix operations.
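For a sense of scale, the sketch below estimates the size of one K (or V) tensor before and after expanding 8 KV heads to 64 query heads, using the shapes from the test script (fp16, head_dim 128, 65536 tokens). It assumes the expanded tensor is actually materialized as a copy; whether a particular code path copies depends on how the tensor is reshaped afterwards, so treat this as an upper bound. kv_bytes is a made-up helper name:

```python
def kv_bytes(seq_len, num_heads, head_dim=128, bytes_per_elem=2):
    """Bytes for one K (or V) tensor of shape (1, seq_len, num_heads, head_dim) in fp16."""
    return seq_len * num_heads * head_dim * bytes_per_elem


seq_len = 65536
grouped = kv_bytes(seq_len, num_heads=8)    # as stored: 8 KV heads
expanded = kv_bytes(seq_len, num_heads=64)  # repeated to match 64 query heads

print(f"grouped K:  {grouped / 2 ** 20:,.0f} MiB")
print(f"expanded K: {expanded / 2 ** 20:,.0f} MiB ({expanded // grouped}x)")
```

If the copy does happen, both the extra memory traffic and the extra allocation scale with the number of query heads per KV head, which would be consistent with the slowdown described above.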
In the PR:
For SM >= 80, we continue using flash_attn. If flash_attn is not available, the xformers implementation is 30-50% slower than flash_attn but 30-50% faster than Torch's matmul attention with very long context (~100K).
For SM < 80, the xformers implementation reduces memory cost similarly to SM >= 80 (it actually uses even less memory than on SM >= 80, but I'm not sure by how much, so I didn't modify the code related to model loading) and is 30-50% faster than Torch's matmul attention.
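The selection order described here can be sketched roughly as follows. This is only an illustration of the rule as stated in this comment, not the code from the PR; the function name and the returned labels are made up:

```python
def pick_attention_backend(compute_capability, has_flash_attn, has_xformers):
    """compute_capability is a (major, minor) tuple, e.g. (7, 5) for a 2080 Ti."""
    sm = compute_capability[0] * 10 + compute_capability[1]
    # flash_attn is preferred, but only on SM >= 80 and only if installed.
    if sm >= 80 and has_flash_attn:
        return "flash_attn"
    # Otherwise fall back to xformers on any architecture where it is available.
    if has_xformers:
        return "xformers"
    # Last resort: plain Torch matmul attention.
    return "torch_matmul"


print(pick_attention_backend((8, 0), has_flash_attn=True, has_xformers=True))
print(pick_attention_backend((7, 5), has_flash_attn=False, has_xformers=True))
print(pick_attention_backend((6, 0), has_flash_attn=False, has_xformers=False))
```

In a real setup the compute capability would come from something like torch.cuda.get_device_capability(), and the two flags from whether the imports succeed.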
import torch
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
from flash_attn import flash_attn_func
import xformers.ops as xops
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"#"0,1,2,3"
nvmlInit()
nvml_handle = nvmlDeviceGetHandleByIndex(5)
info = nvmlDeviceGetMemoryInfo(nvml_handle)
max_mem = info.used
print(f"max_mem: {max_mem/ (1024**2)} m")
bsz = 1
num_attention_heads = 64
num_key_value_heads = 8
head_dim = 128
max_ctx = 65536
q_states_f = torch.randn((bsz, max_ctx, num_attention_heads, head_dim), dtype = torch.half, device = "cuda:0")
k_states_f = torch.randn((bsz, max_ctx, num_key_value_heads, head_dim), dtype = torch.half, device = "cuda:0")
v_states_f = torch.randn((bsz, max_ctx, num_key_value_heads, head_dim), dtype = torch.half, device = "cuda:0")
q_states_x = q_states_f.clone()
k_states_x=k_states_f.reshape(bsz, max_ctx, num_key_value_heads, 1, head_dim).expand(bsz, max_ctx, num_key_value_heads, num_attention_heads//num_key_value_heads,head_dim).reshape(bsz, max_ctx, num_attention_heads, head_dim)
v_states_x=v_states_f.reshape(bsz, max_ctx, num_key_value_heads, 1, head_dim).expand(bsz, max_ctx, num_key_value_heads, num_attention_heads//num_key_value_heads,head_dim).reshape(bsz, max_ctx, num_attention_heads, head_dim)
info = nvmlDeviceGetMemoryInfo(nvml_handle)
max_mem = info.used
print(f"max_mem: {max_mem/ (1024**2)} m")
for past in range(0, 32768+1, 2048):
    for q in range(2048, 32768+1, 2048):
        k = q + past
        q_states = q_states_x[:, :q, :, :]
        k_states = k_states_x[:, :k, :, :]
        v_states = v_states_x[:, :k, :, :]
        attn_output = xops.memory_efficient_attention(q_states, k_states, v_states, attn_bias= xops.fmha.LowerTriangularFromBottomRightMask()).reshape(bsz,q , num_attention_heads* head_dim)
        # q_states = q_states_f[:, :q, :, :]
        # k_states = k_states_f[:, :k, :, :]
        # v_states = v_states_f[:, :k, :, :]
        # attn_output = flash_attn_func(q_states, k_states, v_states, causal = True).reshape(bsz,q , num_attention_heads* head_dim)
        # print(torch.all(attn_output==attn_output1))
        torch.cuda.synchronize()
        info = nvmlDeviceGetMemoryInfo(nvml_handle)
        # if attn_output[0,0,0]==100:
        #     print(1)
        max_mem = info.used
        print(f"q x k = {q:6,} x {k:6,} max_mem: {max_mem/ (1024**2)} m")
FWIW, vllm's Triton attention also works on AMD Navi if you edit the autotune configs.
Something to consider @turboderp while AMD fumbles the CK based flash attention on Navi cards.
I'm slightly allergic to Triton, so I'm not sure. Also this implementation doesn't seem to support paged attention which is kind of a shame.
xformers kicks ass but isn't as fast on longer contexts. It even works on my P100. Allowed me to run a bigger quant of wizard. This was a HUGE drawback for non-Ampere cards. I don't know if anyone has gotten it to work on AMD with HIP yet, but it could be an option.
So vllm has a triton based flash attention: https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_flash_attention.py
Supposedly it supports volta+. Not sure if it's slower or faster, it requires a lot more arguments or to wrap around torch attention. Has anyone gotten it working?