flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

flashinfer.page.append_paged_kv_cache will cause an invalid memory access if device != 'cuda:0' #349

Closed: Tomorrowdawn closed this issue 2 months ago

Tomorrowdawn commented 3 months ago

Minimal reproduction:

import torch
import flashinfer

DEVICE = 'cuda:1'
shape = (8, 2, 128)
dtype = torch.float16
ragged_keys = torch.randn(shape, dtype=dtype).to(DEVICE)
ragged_values = torch.randn(shape, dtype=dtype).to(DEVICE)
query_indptr = torch.tensor([0, 8], dtype=torch.int32).to(DEVICE)
cache = torch.empty((100, 2, 16, 2, 128), dtype=dtype).to(DEVICE)
kv_indices = torch.tensor([0], dtype=torch.int32).to(DEVICE)
kv_last_page_lens = torch.tensor([8], dtype=torch.int32).to(DEVICE)
kv_indptr = torch.tensor([0, 1], dtype=torch.int32).to(DEVICE)
torch.cuda.synchronize()  # everything is fine up to this point
flashinfer.page.append_paged_kv_cache(
    ragged_keys,
    ragged_values,
    query_indptr,
    cache,
    kv_indices,
    kv_indptr,
    kv_last_page_lens,
)  # a runtime error occurs here

When DEVICE is set to 'cuda:0', everything works as expected. However, setting it to any other device (e.g., 'cuda:1') results in a runtime error:

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

After numerous trials, I am confident that this issue is not hardware-dependent.

A temporary workaround is to set CUDA_VISIBLE_DEVICES=1, which makes the code run correctly (since PyTorch then sees that GPU as cuda:0). However, this prevents a single process from using multiple GPUs. Multi-GPU support matters here, because many large models cannot run inference on a single GPU.

This does seem quite strange; I took a cursory look at the CUDA code, but I am at a loss as to how such an error could occur.

yzh119 commented 3 months ago

The error goes away after I set torch_current_stream to nullptr (the default stream): https://github.com/flashinfer-ai/flashinfer/blob/dc2c76f8577d8695112b61d1fd43ef88569272ef/python/csrc/page.cu#L68

This is because getCurrentCUDAStream() returns the current CUDA stream of the current GPU (device 0 by default), which does not work for tensors that live on GPU 1. We should take the input tensors' device index into account and select the stream accordingly.
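For reference, here is a minimal sketch of that kind of fix on the C++ side (assumed names only: append_paged_kv_cache_binding and append_key are illustrative, not the actual signature in page.cu). The idea is to guard the current device with the input tensors' device and query the current stream for that device index rather than for device 0:

// Sketch under assumed names; not the actual page.cu code.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void append_paged_kv_cache_binding(torch::Tensor append_key /*, ... other tensor args ... */) {
  // Make the inputs' device the current device for the duration of this call.
  const c10::cuda::CUDAGuard device_guard(append_key.device());
  // Query the current stream of that device instead of the default device 0.
  cudaStream_t torch_current_stream =
      c10::cuda::getCurrentCUDAStream(append_key.device().index());
  // ... launch the append kernel on torch_current_stream as before ...
}

With a change along these lines, the kernel would be launched on the stream that matches the tensors' device, so the torch.cuda.set_device workaround below would no longer be needed.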

A temporary workaround (until we release v0.0.8) is to set the current GPU explicitly:

import torch
import flashinfer

DEVICE = 'cuda:1'
torch.cuda.set_device(DEVICE)  # make cuda:1 the current device so its current stream is used
shape = (8, 2, 128)
dtype = torch.float16
ragged_keys = torch.randn(shape, dtype=dtype).to(DEVICE)
ragged_values = torch.randn(shape, dtype=dtype).to(DEVICE)
query_indptr = torch.tensor([0, 8], dtype=torch.int32).to(DEVICE)
cache = torch.empty((100, 2, 16, 2, 128), dtype=dtype).to(DEVICE)
kv_indices = torch.tensor([0], dtype=torch.int32).to(DEVICE)
kv_last_page_lens = torch.tensor([8], dtype=torch.int32).to(DEVICE)
kv_indptr = torch.tensor([0, 1], dtype=torch.int32).to(DEVICE)
torch.cuda.synchronize()
flashinfer.page.append_paged_kv_cache(
    ragged_keys,
    ragged_values,
    query_indptr,
    cache,
    kv_indices,
    kv_indptr,
    kv_last_page_lens,
)  # runs without error