maaquib opened this issue 10 months ago
How reasonable is it to just use `uint64_t` for `index_t` instead of `uint32_t`? Would that negatively affect performance, maybe due to reduced occupancy from higher register usage? I ran the above script both ways (with 32-bit and 64-bit `index_t`) and saw the same number of registers per thread in both cases, for this kernel on this particular GPU (an A10G on an AWS g5 instance).

@tridao Any thoughts?
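For completeness, one way to look at the wall-clock side of that question would be to time the same call against a build with 32-bit `index_t` and one with 64-bit `index_t`. A minimal sketch (the shapes are placeholders, and it assumes the flash-attn `flash_attn_with_kvcache` Python API):

```python
# Minimal timing sketch: run once per build (32-bit vs. 64-bit index_t) and
# compare. Shapes below are placeholders, not the ones from the repro script.
import torch
from torch.utils import benchmark
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seq_len = 32, 32, 128, 8192
q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(batch, max_seq_len, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)
cache_seqlens = torch.full((batch,), 4096, dtype=torch.int32, device="cuda")

timer = benchmark.Timer(
    stmt="flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True)",
    globals={"flash_attn_with_kvcache": flash_attn_with_kvcache, "q": q,
             "k_cache": k_cache, "v_cache": v_cache, "cache_seqlens": cache_seqlens},
)
print(timer.timeit(100))  # registers per thread can be checked separately, e.g. with Nsight Compute
```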
Alternative kv cache layouts with a layer dimension "inside" the batch dimension (i.e., a particular sequence's kv cache token data for all layers is stored together, before the next sequence's data) are likely to run into this, since the batch stride will be multiplied by the number of layers.
I'll also point out that it's possible to hit this issue right now, even with a standard single-layer cache layout, even on a currently-supported 80 GiB GPU like the A100. Here's an example:
Consider a model like llama-2-7b but with half the layers: 16 layers, num_heads=32, headdim=128. Assume it supports sequence lengths up to 16k.
On an A100 with 80 GiB of GPU RAM, I should be able to fit:
We still have about 10 GiB to spare for activations. But we'll run into this same issue.
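A rough back-of-the-envelope version of that sizing (the batch size of 16 and the ~7 GB of fp16 weights below are my assumptions, chosen to line up with the "about 10 GiB to spare" figure):

```python
# Sizing sketch for the example above: 16 layers, num_heads=32, headdim=128,
# fp16, sequences up to 16k. Batch size 16 and ~7 GB of weights are assumptions.
num_layers, num_heads, headdim, dtype_bytes = 16, 32, 128, 2
seq_len = 16 * 1024
batch = 16

kv_per_token = 2 * num_heads * headdim * dtype_bytes   # K + V, one layer, one token
kv_per_seq = num_layers * seq_len * kv_per_token
print(kv_per_seq / 2**30, batch * kv_per_seq / 2**30)  # 4.0 GiB/sequence, 64.0 GiB total
# 64 GiB of KV cache + ~7 GB of fp16 weights leaves roughly 10 GiB of the 80 GiB free.

# With a "layer inside batch" layout, the K cache's batch stride (in elements) is:
batch_stride = num_layers * seq_len * num_heads * headdim
print(batch_stride == 2**30)                           # True
# So 32-bit element offsets overflow from batch index 4 onward
# (byte offsets, at 2 bytes per element, already from batch index 2).
```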
I ran a modified version of the above repro script with these parameters, and got the following output (notice that the last line of output is the wrong result with batch size 64):
------------
#### CASE 1: max_seq_len=8448, max_batch_size=32
------------
PREFILL 1:
log2(strides): [26, 13, 7, 0]
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0',
dtype=torch.float16)
------------
PREFILL 2:
log2(strides): [26, 13, 7, 0]
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2., 2., 2., 2.]], device='cuda:0',
dtype=torch.float16)
------------
#### CASE 2: max_seq_len=8448, max_batch_size=64
------------
PREFILL 1:
log2(strides): [26, 13, 7, 0]
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0',
dtype=torch.float16)
------------
PREFILL 2:
log2(strides): [26, 13, 7, 0]
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2., 2., 2., 2.],
[0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0',
dtype=torch.float16)
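One plausible reading of those numbers, assuming the kernel accumulates byte offsets in 32-bit `index_t` arithmetic (an assumption on my part, not verified against the kernel source): the per-token stride of 2^13 elements times max_seq_len=8448 gives a per-slot stride of roughly 2^27 bytes in fp16, so the last slot's offset still fits in 32 bits with 32 slots but not with 64:

```python
# From the output above: per-token stride 2**13 elements, max_seq_len=8448,
# fp16 (2 bytes per element), so each cache slot spans roughly 2**27 bytes.
batch_stride_bytes = 8448 * 2**13 * 2

for max_batch_size in (32, 64):
    last_slot_offset = (max_batch_size - 1) * batch_stride_bytes
    print(max_batch_size, last_slot_offset, last_slot_offset < 2**32)
# 32 -> 4290772992, still (barely) below 2**32
# 64 -> 8719958016, well past 2**32, so a 32-bit offset wraps around
```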
Thanks @maaquib and @davidthomas426 for this thorough investigation. The way forward is to use 64-bit indexing. I've just pushed a commit to switch to `int64_t` indexing. There might be a few other places in the code where we assume indexing fits in 32 bits; I'll need to go through and check.
Thanks!
I looked at the commit, and I think at least `kernel_traits.h` and `kernel_traits_sm90.h` need to be updated as well, as they also set `using index_t = uint32_t`. Not sure if there are any more subtle spots.
The following case simulates two back-to-back prefills: the first with `batch_size=1, seq_len=2` and the second with `batch_size=4, seq_len=3`. Facing this issue on `A10G` and `A100`.

*Note: `kv_cache.max_seq_len=4096` is where the batch strides become `2**30`.*

Case 1 - kv_cache with `max_seq_len=1024`:
Everything works as expected, i.e. during the first prefill `kv_cache.batch_idx=0` is updated with the correct `kv`s, and during the second prefill `kv_cache.batch_idx=[1, 2, 3, 4]` are updated with the correct `kv`s. At the end of the two iterations, the kv_cache is in the expected state with correct values.

Case 2 - kv_cache with `max_seq_len=4096`:
Everything works as expected for the first prefill, i.e. `kv_cache.batch_idx=0` is updated with the correct `kv`s. But during the second prefill, `kv_cache.batch_idx=[0]` is incorrectly updated. The expectation for prefill 2 is that only `kv_cache.batch_idx=[1, 2, 3, 4]` should have been updated.

It is also reproducible with both `num_splits=0` and `num_splits=1`.
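A minimal sketch of the two-prefill pattern described above (not the original repro script from this issue; it assumes the `flash_attn_with_kvcache` API with `cache_seqlens`/`cache_batch_idx`, and the head count / head dim are placeholders that would need to be scaled up until the batch stride overflows 32-bit offsets):

```python
# Sketch of the two back-to-back prefills into a pre-allocated kv cache.
import torch
from flash_attn import flash_attn_with_kvcache

nheads, headdim = 32, 128                      # placeholders, not the original model
max_batch_size, max_seq_len = 5, 4096

k_cache = torch.zeros(max_batch_size, max_seq_len, nheads, headdim,
                      dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)

def prefill(slots, seq_len, fill):
    b = len(slots)
    q = torch.randn(b, seq_len, nheads, headdim, dtype=torch.float16, device="cuda")
    k = torch.full((b, seq_len, nheads, headdim), fill, dtype=torch.float16, device="cuda")
    v = k.clone()
    flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k, v=v,
        cache_seqlens=torch.zeros(b, dtype=torch.int32, device="cuda"),
        cache_batch_idx=torch.tensor(slots, dtype=torch.int32, device="cuda"),
        causal=True,
        num_splits=0,  # the issue reproduces with both 0 and 1, per the note above
    )

prefill([0], seq_len=2, fill=1.0)           # prefill 1: writes cache slot 0
prefill([1, 2, 3, 4], seq_len=3, fill=2.0)  # prefill 2: should only write slots 1-4

print(k_cache[:, 0, 0, :8])                 # slot 0 should still hold the 1.0s
```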
Fix

Setting `using index_t = uint64_t` instead of `uint32_t` fixes the issue. (Thanks to @davidthomas426 for investigating this and proposing this fix.)

Logs