Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

`flash_attn_with_kvcache` API bug with kv_cache seq_len=4096 #772

Open maaquib opened 10 months ago

maaquib commented 10 months ago

The following script simulates two back-to-back prefills: the first with batch_size=1, seq_len=2 and the second with batch_size=4, seq_len=3. I'm seeing this issue on both A10G and A100.

*Note: kv_cache.max_seq_len=4096 is where the batch stride becomes 2**30.*
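The 2**30 figure can be reproduced with a little arithmetic. The plain-Python sketch below (no GPU needed) assumes the kernel addresses the cache in elements, and relies on the fact that slicing `kv_cache[:, :, layer_id, 0, :, :]` keeps the parent tensor's batch stride. It also checks whether the largest cache_batch_idx used in the repro (4) times that stride reaches 2**32, the point where a 32-bit unsigned offset wraps. The helper names are hypothetical.

```python
import math

# Sizes from the repro script; the combined cache has shape
# (batch, max_seq_len, layers, 2 [K/V], num_heads, head_size).
LAYERS, NUM_HEADS, HEAD_SIZE = 32, 32, 128


def contiguous_strides(shape):
    """Element strides of a contiguous (row-major) tensor with `shape`."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def batch_stride(max_seq_len):
    # The per-layer K/V views inherit the batch stride of the whole 6-D tensor.
    shape = (8, max_seq_len, LAYERS, 2, NUM_HEADS, HEAD_SIZE)
    return contiguous_strides(shape)[0]


for max_seq_len in (1024, 4096):
    s = batch_stride(max_seq_len)
    # max_seq_len, batch stride, log2(batch stride), does 4 * stride reach 2**32?
    print(max_seq_len, s, math.log2(s), 4 * s >= 2**32)
# 1024 268435456 28.0 False
# 4096 1073741824 30.0 True
```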

Case 1 - kv_cache with max_seq_len=1024:

Everything works as expected, i.e. during the first prefill kv_cache.batch_idx=0 is updated with the correct KVs, and during the second prefill kv_cache.batch_idx=[1, 2, 3, 4] are updated with the correct KVs. After both iterations, kv_cache is in the expected state with the correct values.

Case 2 - kv_cache with max_seq_len=4096:

Everything works as expected for the first prefill, i.e. kv_cache.batch_idx=0 is updated with the correct KVs. During the second prefill, however, kv_cache.batch_idx=0 is incorrectly overwritten, and (per the logs below) kv_cache.batch_idx=4 is never written. The expectation for the second prefill is that only kv_cache.batch_idx=[1, 2, 3, 4] are updated.

It is also reproducible with both num_splits=0 and num_splits=1.

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_with_kvcache

layers, num_heads, head_size = 32, 32, 128
d_type = torch.float16

def run_test(k_cache, v_cache):
    # Prefill 1: batch_size=1 written to cache slot 0.
    # Prefill 2: batch_size=4 written to cache slots 1-4.
    for i, (prompt_batch_size, batch_idx, seq_len) in enumerate(zip([1, 4], [[0], [1, 2, 3, 4]], [2, 3])):
        q = torch.zeros((prompt_batch_size, seq_len, num_heads, head_size), device='cuda', dtype=d_type)
        k = torch.full((prompt_batch_size, seq_len, num_heads, head_size), 2.3 * (i + 1), device='cuda', dtype=d_type)
        v = torch.full((prompt_batch_size, seq_len, num_heads, head_size), 3.4 * (i + 1), device='cuda', dtype=d_type)
        # Every sequence starts at position 0 of its cache slot (prefill).
        cache_seqlens = torch.zeros(prompt_batch_size, device='cuda', dtype=torch.int32)
        # Maps each prompt in the batch to its row in the cache.
        cache_batch_idx = torch.tensor(batch_idx, device='cuda', dtype=torch.int32)
        flash_attn_with_kvcache(
            q=q,
            k_cache=k_cache,
            v_cache=v_cache,
            k=k,
            v=v,
            rotary_cos=None,
            rotary_sin=None,
            cache_seqlens=cache_seqlens,
            cache_batch_idx=cache_batch_idx,
            # softmax_scale=None,
            # causal=True,
            # window_size=(-1, -1),
            # rotary_interleaved=True,
            # alibi_slopes=None,
            # num_splits=0
            )
        # print(f"### PREFILL {i + 1}:\n{k_cache[0:6, 0:3, :, :]}\n\n")
        print(f"### PREFILL {i + 1}:\n{k_cache[0:1, 0:3, :, :]}\n\n")

if __name__ == "__main__":
    assert torch.cuda.is_available()
    layer_id = 0
    for idx, max_seq_len in enumerate([1024, 4096]):
        print(f"## CASE {idx + 1}: max_seq_len={max_seq_len}")
        # One combined cache for all layers: (batch, seq, layers, K/V, heads, head_size).
        kv_cache = torch.zeros(8, max_seq_len, layers, 2, num_heads, head_size, device='cuda', dtype=d_type)
        # k_cache.shape: (batch_size_cache, seqlen_cache, nheads_k, headdim)
        # These views are non-contiguous and keep kv_cache's batch stride.
        key_cache = kv_cache[:, :, layer_id, 0, :, :]
        val_cache = kv_cache[:, :, layer_id, 1, :, :]
        run_test(key_cache, val_cache)
```

```
$ pip show flash_attn
Name: flash-attn
Version: 2.4.2
Summary: Flash Attention: Fast and Memory-Efficient Exact Attention
Home-page: https://github.com/Dao-AILab/flash-attention
Author: Tri Dao
Author-email: trid@cs.stanford.edu
License:
Location: /home/ubuntu/miniconda3/envs/vllm/lib/python3.10/site-packages
Requires: einops, ninja, packaging, torch
Required-by:
```

Fix

Changing `using index_t = uint32_t` to `using index_t = uint64_t` fixes the issue. (Thanks to @davidthomas426 for investigating this and proposing the fix.)
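Why widening index_t helps can be shown with a simplified model: the base offset of a cache slot is cache_batch_idx * batch_stride, evaluated in an unsigned integer of a given width. The plain-Python sketch below ignores the row and head terms the kernel also adds, and it is an illustration rather than the flash-attention source; BATCH_STRIDE is the 2**30 stride of the max_seq_len=4096 cache, and the function names are hypothetical.

```python
BATCH_STRIDE = 2**30  # batch stride of the max_seq_len=4096 cache, in elements


def base_offset(batch_idx, stride, index_bits):
    """batch_idx * stride computed in an unsigned integer of `index_bits` bits (wraps on overflow)."""
    return (batch_idx * stride) % (2**index_bits)


def slot_written(batch_idx, stride, index_bits):
    """Cache batch slot that the (possibly wrapped) offset actually points into."""
    return base_offset(batch_idx, stride, index_bits) // stride


for bits in (32, 64):
    print(bits, [slot_written(b, BATCH_STRIDE, bits) for b in [1, 2, 3, 4]])
# 32 [1, 2, 3, 0]
# 64 [1, 2, 3, 4]
```

With 32 bits, index 4 wraps to slot 0, which matches Case 2 (slot 0 overwritten, slot 4 left at zero); with 64 bits each index lands in its own slot.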

Logs

```
$ cat -n out1 | grep -e CASE -e bidb -e PREFILL -e tensor -e device -e '\[\[\[' -e '---'

     1  ------------
     2  #### CASE 1: max_seq_len=1024
     3  ------------
     4  PREFILL 1:
     5  bidb = 0, bidb_cache = 0
     6  tensor([[[[2.3008, 2.3008, 2.3008,  ..., 2.3008, 2.3008, 2.3008],
    39          [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
    72          [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
   105          [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
   138          [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
   171          [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
   202         device='cuda:0', dtype=torch.float16)
   205  ------------
   206  PREFILL 2:
   207  bidb = 1, bidb_cache = 2
   208  bidb = 2, bidb_cache = 3
   209  bidb = 3, bidb_cache = 4
   210  bidb = 0, bidb_cache = 1
   211  tensor([[[[2.3008, 2.3008, 2.3008,  ..., 2.3008, 2.3008, 2.3008],
   244          [[[4.6016, 4.6016, 4.6016,  ..., 4.6016, 4.6016, 4.6016],
   277          [[[4.6016, 4.6016, 4.6016,  ..., 4.6016, 4.6016, 4.6016],
   310          [[[4.6016, 4.6016, 4.6016,  ..., 4.6016, 4.6016, 4.6016],
   343          [[[4.6016, 4.6016, 4.6016,  ..., 4.6016, 4.6016, 4.6016],
   376          [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
   407         device='cuda:0', dtype=torch.float16)
   410  ------------
   411  #### CASE 2: max_seq_len=4096
   412  ------------
   413  PREFILL 1:
   414  bidb = 0, bidb_cache = 0
   415  tensor([[[[2.3008, 2.3008, 2.3008,  ..., 2.3008, 2.3008, 2.3008],
   448          [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
   481          [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
   514          [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
   547          [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
   580          [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
   611         device='cuda:0', dtype=torch.float16)
   614  ------------
   615  PREFILL 2:
   616  bidb = 2, bidb_cache = 3
   617  bidb = 1, bidb_cache = 2
   618  bidb = 3, bidb_cache = 4
   619  bidb = 0, bidb_cache = 1
   620  tensor([[[[4.6016, 4.6016, 4.6016,  ..., 4.6016, 4.6016, 4.6016],
   653          [[[4.6016, 4.6016, 4.6016,  ..., 4.6016, 4.6016, 4.6016],
   686          [[[4.6016, 4.6016, 4.6016,  ..., 4.6016, 4.6016, 4.6016],
   719          [[[4.6016, 4.6016, 4.6016,  ..., 4.6016, 4.6016, 4.6016],
   752          [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
   785          [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
   816         device='cuda:0', dtype=torch.float16)
```

davidthomas426 commented 10 months ago

How reasonable is it to simply use uint64_t for index_t instead of uint32_t? Would that hurt performance, e.g. by reducing occupancy through higher register usage? I ran the above script both ways (with 32-bit and 64-bit index_t) and saw the same number of registers per thread in both cases, for this kernel on this particular GPU (A10G on an AWS g5 instance).

@tridao Any thoughts?

Alternative kv cache layouts with a layer dimension "inside" the batch dimension (i.e., a particular sequence's kv cache token data for all layers is stored together, before the next sequence's data) are likely to run into this, since the batch stride will be multiplied by the number of layers.
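To make the layout point concrete, here is a minimal, conservative check in plain Python: it computes the largest element offset reachable through a (shape, strides) view and compares it with 2**32 - 1. The sizes reuse the repro's values with max_seq_len=4096; the helper names are hypothetical, and with a real tensor the inputs would come from `k_cache.shape` and `k_cache.stride()`. Because every term is non-negative, if the total fits then no partial sum of batch, row and head offsets can wrap either.

```python
def contiguous_strides(shape):
    """Element strides of a contiguous (row-major) tensor with `shape`."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def max_element_offset(shape, strides):
    """Offset of the last element reachable through a strided view."""
    return sum((size - 1) * stride for size, stride in zip(shape, strides))


def fits_in_uint32(shape, strides):
    return max_element_offset(shape, strides) <= 2**32 - 1


batch, seq, layers, heads, hd = 8, 4096, 32, 32, 128
view_shape = (batch, seq, heads, hd)

# Layer dimension "inside" batch: (batch, seq, layers, 2, heads, hd), then the
# K view of one layer, which keeps the parent's batch/seq/head strides.
parent = contiguous_strides((batch, seq, layers, 2, heads, hd))
inner_view_strides = (parent[0], parent[1], parent[4], parent[5])

# Layer dimension "outside" batch: a separate contiguous K tensor per layer.
outer_strides = contiguous_strides(view_shape)

print(fits_in_uint32(view_shape, inner_view_strides))  # False
print(fits_in_uint32(view_shape, outer_strides))       # True
```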

davidthomas426 commented 10 months ago

I'll also point out that this issue can be hit today, even with a standard single-layer cache layout and on a currently supported 80 GiB GPU such as the A100. Here's an example:

Consider a model like llama-2-7b but with half the layers: 16 layers, num_heads=32, headdim=128. Assume it supports sequence lengths up to 16k.

On an A100 with 80 GiB of GPU RAM, I should be able to fit:

We still have about 10 GiB to spare for activations. But we'll run into this same issue.

I ran a modified version of the above repro script with these parameters and got the following output (notice that the last row of the final tensor is wrong with batch size 64):

```
------------
#### CASE 1: max_seq_len=8448, max_batch_size=32
------------
PREFILL 1:
log2(strides): [26, 13, 7, 0]
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0',
       dtype=torch.float16)

------------
PREFILL 2:
log2(strides): [26, 13, 7, 0]
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2.]], device='cuda:0',
       dtype=torch.float16)

------------
#### CASE 2: max_seq_len=8448, max_batch_size=64
------------
PREFILL 1:
log2(strides): [26, 13, 7, 0]
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0',
       dtype=torch.float16)

------------
PREFILL 2:
log2(strides): [26, 13, 7, 0]
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2.],
        [0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0',
       dtype=torch.float16)
```

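A back-of-the-envelope check of this configuration, in plain Python. The layout is an assumption inferred from the printed log2(strides) = [26, 13, 7, 0]: a per-token stride of 2**13 = 2 * 32 * 128 suggests K and V interleaved per token, which gives a batch stride of 8448 * 8192 elements (log2 ≈ 26.04). The sketch finds the first cache batch index whose base offset reaches 2**32.

```python
import math

MAX_SEQ_LEN, NUM_HEADS, HEAD_DIM = 8448, 32, 128
UINT32_LIMIT = 2**32  # first offset a 32-bit unsigned index_t cannot represent

# Assumed layout (batch, seq, 2 [K/V], heads, head_dim), inferred from the
# printed log2(strides): one token spans 2 * 32 * 128 = 2**13 elements.
seq_stride = 2 * NUM_HEADS * HEAD_DIM
batch_stride = MAX_SEQ_LEN * seq_stride

print(batch_stride, round(math.log2(batch_stride), 2))  # 69206016 26.04

# Smallest batch index b with b * batch_stride >= 2**32 (ceiling division).
first_bad = -(-UINT32_LIMIT // batch_stride)
print(first_bad)  # 63

for max_batch_size in (32, 64):
    # Does the highest slot index (max_batch_size - 1) reach the wrap point?
    print(max_batch_size, max_batch_size - 1 >= first_bad)
# 32 False
# 64 True
```

Under this assumption every slot's base offset fits with max_batch_size=32, while slot 63 does not with 64. The check only looks at base offsets, so token positions late in a sequence can overflow for slightly smaller batch indices as well.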
tridao commented 10 months ago

Thanks @maaquib and @davidthomas426 for this thorough investigation. The way forward is to use 64 bit indexing. I've just pushed a commit to switch to int64_t indexing. There might be a few other places in the code where we make the assumption that indexing is in 32 bits. I'll need to go through and check.

davidthomas426 commented 10 months ago

Thanks!

I looked at the commit, and I think at least `kernel_traits.h` and `kernel_traits_sm90.h` need to be updated as well, since they also set `using index_t = uint32_t`. I'm not sure whether there are any subtler spots.
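As a rough first pass at finding remaining spots, the hypothetical helper below scans a source tree for explicit `using index_t = uint32_t` (or `int32_t`) aliases. It will not catch subtler cases, such as offsets computed in plain `int` arithmetic, so it complements a manual review rather than replacing it.

```python
import pathlib
import re

# Matches aliases such as "using index_t = uint32_t;" (also int32_t), whitespace-tolerant.
PATTERN = re.compile(r"\busing\s+index_t\s*=\s*u?int32_t\b")


def scan_text(text):
    """Return (line_number, stripped_line) for each 32-bit index_t alias in text."""
    return [
        (lineno, line.strip())
        for lineno, line in enumerate(text.splitlines(), start=1)
        if PATTERN.search(line)
    ]


def find_32bit_index_t(root, suffixes=(".h", ".cuh", ".cu", ".cpp")):
    """Scan C++/CUDA sources under root; return (path, line_number, line) hits."""
    hits = []
    for path in sorted(pathlib.Path(root).rglob("*")):
        if path.is_file() and path.suffix in suffixes:
            text = path.read_text(encoding="utf-8", errors="replace")
            hits.extend((str(path), lineno, line) for lineno, line in scan_text(text))
    return hits
```

For example, `find_32bit_index_t("path/to/flash-attention")`, where the path is a placeholder for a local checkout.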