flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

SingleDecodeWithKVCache meets illegal memory access when setting input tensors to cuda:1 #452

Open jason-huang03 opened 2 months ago

jason-huang03 commented 2 months ago

This is taken from the example given in the repo:

import torch
import flashinfer

device_id = 1

kv_len = 2048
num_kv_heads = 32
head_dim = 128

k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(device_id) 
v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(device_id) 

# decode attention

num_qo_heads = 32
q = torch.randn(num_qo_heads, head_dim).half().to(device_id)

o = flashinfer.single_decode_with_kv_cache(q, k, v) # decode attention without RoPE on-the-fly
o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ROPE_LLAMA") # decode with LLaMA style RoPE on-the-fly

# append attention
append_qo_len = 128
q = torch.randn(append_qo_len, num_qo_heads, head_dim).half().to(device_id) # append attention, the last 128 tokens in the KV-Cache are the new tokens
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) # append attention without RoPE on-the-fly, apply causal mask
o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, pos_encoding_mode="ROPE_LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask

# prefill attention
qo_len = 2048
q = torch.randn(qo_len, num_qo_heads, head_dim).half().to(device_id) # prefill attention
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=False) # prefill attention without RoPE on-the-fly, do not apply causal mask

When device_id=0, everything is fine. However, when device_id=1, the following error is thrown:

    out = _decode.single_decode_with_kv_cache(
RuntimeError: SingleDecodeWithKVCache kernel launch failed, error: an illegal memory access was encountered

I am using an A100 (SM 80). I thought this problem had been solved by the commit related to #349, but I still hit this weird error. I want to deploy a 70B model across multiple GPUs, so being able to run the kernels on different GPUs is really important to me. Can you see why this happens? Thanks a lot!
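
For reference, a minimal check (plain torch calls, nothing FlashInfer-specific) of whether the process-wide current CUDA device matches the device the tensors live on:

import torch

device_id = 1
q = torch.randn(32, 128).half().to(device_id)

# the tensor carries an explicit device index ...
print(q.device, q.device.index)     # cuda:1 1
# ... but the current device of the process stays at 0 unless set explicitly
print(torch.cuda.current_device())  # 0 by default

If a kernel were launched on the current device (cuda:0) while the buffers sit on cuda:1, an illegal memory access would be one plausible outcome; that is only a guess based on the symptom, not a confirmed diagnosis.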

yzh119 commented 2 months ago

Hi @jason-huang03, which version of flashinfer were you using? I believe the issue should have been fixed in v0.0.9.

I can't reproduce it with the latest version of flashinfer (v0.1.5).

jason-huang03 commented 2 months ago

I checked out v0.1.5 and rebuilt with pip install --no-cache-dir --force-reinstall -e . (an editable install). However, the problem persists. The full error message is:

CUDA Error: an illegal memory access was encountered (700) /mnt/huanghaofeng/flashinfer/python/include/flashinfer/attention/decode.cuh: line 658 at function cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
Traceback (most recent call last):
  File "/mnt/huanghaofeng/flashinfer/test.py", line 19, in <module>
    o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ROPE_LLAMA") # decode with LLaMA style RoPE on-the-fly
  File "/mnt/huanghaofeng/flashinfer/python/flashinfer/decode.py", line 194, in single_decode_with_kv_cache
    out = _decode.single_decode_with_kv_cache(
RuntimeError: SingleDecodeWithKVCache kernel launch failed, error: an illegal memory access was encountered

You can see that the failure comes from the cudaFuncSetAttribute call.

I am using CUDA 11.8 and torch 2.2.0 in a containerized development environment. Could this be the problem?
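
For completeness, the toolchain inside the container as torch itself reports it (standard attributes only, values as stated above):

import torch

print(torch.__version__)          # 2.2.0
print(torch.version.cuda)         # 11.8, the CUDA version torch was built against
print(torch.cuda.device_count())  # should report both GPUs inside the container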

jason-huang03 commented 2 months ago

Also, I find that device_id in the function SinglePrefillWithKVCacheDispatched in python/include/flashinfer/attention/prefill.cuh seems to be 0 regardless of the device_id set in the Python code.
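
If the dispatcher really is picking up the current device instead of the tensor's device, forcing the two to agree before the call should make the error go away. This is only a workaround sketch under that assumption, reusing the shapes from the example above:

import torch
import flashinfer

device_id = 1
kv_len, num_kv_heads, num_qo_heads, head_dim = 2048, 32, 32, 128
q = torch.randn(num_qo_heads, head_dim).half().to(device_id)
k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(device_id)
v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(device_id)

# torch.cuda.device(...) sets the current CUDA device on entry and restores
# it on exit, so anything launched inside the block runs against cuda:1
with torch.cuda.device(device_id):
    o = flashinfer.single_decode_with_kv_cache(q, k, v)

If the call succeeds inside the guard, the inputs are fine and the problem is in how the extension chooses its launch device.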

yzh119 commented 2 months ago

@jason-huang03 would you mind checking the device id here?

jason-huang03 commented 2 months ago

Using std::cout, device.index() here prints as empty, but device is correct (e.g. cuda:1). I am now trying CUDA 12.4 and torch 2.4 to see whether that solves the problem.
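
For reference, the Python side distinguishes an unset device index (None) from an explicit one, while the C++ side uses -1 for unset; c10::DeviceIndex is also a narrow integer type (int8_t, as far as I can tell in these releases), so streaming it straight into std::cout renders it as a character, which may be why it looks empty here. A small illustration of the Python analogue:

import torch

print(torch.device("cuda").index)    # None: no explicit index
print(torch.device("cuda:1").index)  # 1
t = torch.randn(2).half().to(1)      # .to(1) is shorthand for cuda:1
print(t.device, t.device.index)      # cuda:1 1

# on the C++ side, casting before printing avoids the character rendering:
#   std::cout << static_cast<int>(device.index());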

jason-huang03 commented 2 months ago

After switching to PyTorch 2.4 and CUDA 12.4, the error disappears. Thanks for your time. It seems the device and device index API behaves differently across these CUDA or PyTorch versions.

yzh119 commented 2 months ago

Thanks for reporting, I'll check the behavior on cu118 platforms.