Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

RuntimeError: CUDA error: an illegal memory access was encountered #483

Open zanussbaum opened 10 months ago

zanussbaum commented 10 months ago

I'm getting a `RuntimeError: CUDA error: an illegal memory access was encountered` when using FlashAttention with a GPT-NeoX-esque model. Here is a minimal reproduction:

from transformers import AutoConfig
import torch
from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config
from flash_attn.models.gpt import GPTModel

batch_size = 512
config = AutoConfig.from_pretrained("EleutherAI/pythia-1b")

# Convert the Hugging Face GPT-NeoX config to flash_attn's GPT-2-style config
flash_config = gpt_neox_config_to_gpt2_config(config)
flash_config.use_flash_attn = True
model = GPTModel(flash_config).to('cuda').to(torch.bfloat16)

model.eval()
# Dummy batch of token ids, shape (batch_size, seq_len=2048)
inputs = torch.ones((batch_size, 2048)).to("cuda").long()

with torch.no_grad():
    outputs = model(inputs)
    print(outputs.shape)

This returns correctly, but if I change `batch_size = 768`, then I get the error:

torch.Size([768, 2048, 2048]) # this is just printing the `hidden_states`, `hidden_states2`, and `residual` shapes for each layer
None
None
torch.Size([768, 2048, 2048])
torch.Size([768, 2048, 2048])
torch.Size([768, 2048, 2048])
Traceback (most recent call last):
  File "/home/paperspace/contrastors/test_mem.py", line 33, in <module>
    outputs = model(inputs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn/models/gpt.py", line 513, in forward
    hidden_states, hidden_states2, residual = layer(
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn/modules/block.py", line 420, in forward
    hidden_states2 = self.mlp(hidden_states2)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn/modules/mlp.py", line 43, in forward
    y = self.activation(y)
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

It seems to fail after the 3rd attention layer. I assumed I would get a CUDA OOM error, but wanted to raise this in case it's unexpected, similar to #124.

This is the relevant environment info:

pytorch-triton           2.1.0+e6216047b8
torch                    2.1.0.dev20230824+cu121
flash-attn               2.0.9
tridao commented 10 months ago

Probably because the tensors are larger than 2GB. We use 32-bit ints for indexing, and tensors larger than 2GB might have wrong indexing. I didn't think we'd ever run attention with tensors larger than 2GB (since you'd probably OOM elsewhere anyway).
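
For scale, a quick back-of-the-envelope sketch, assuming the [batch, 2048, 2048] bf16 hidden-state shapes printed above; the exact tensor and offset that overflow inside the kernels may differ, but batch 512 sits right at 2^31 elements while batch 768 is well past the signed 32-bit range:

# Back-of-the-envelope sizes for the [batch, seqlen, hidden] bf16 hidden
# states printed above, compared against the signed 32-bit limit.
# The exact tensor/offset that overflows inside the kernels may differ.
seqlen, hidden, bytes_per_elem = 2048, 2048, 2  # bf16
INT32_MAX = 2**31 - 1

for batch in (512, 768):
    n_elems = batch * seqlen * hidden
    print(
        f"batch={batch}: {n_elems:,} elements "
        f"({n_elems * bytes_per_elem / 2**30:.1f} GiB), "
        f"largest element index fits in int32: {n_elems - 1 <= INT32_MAX}"
    )
# batch=512: 2,147,483,648 elements (4.0 GiB), largest element index fits in int32: True
# batch=768: 3,221,225,472 elements (6.0 GiB), largest element index fits in int32: False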

zanussbaum commented 10 months ago

Thanks Tri, that's what I figured. If I wanted to try and see if we do OOM, is there a way to change the indexing precision from 32-bit to 64-bit?

tridao commented 10 months ago

You can try changing this line from uint32_t to uint64_t but I've never tried. Lmk if that works.

zanussbaum commented 10 months ago

Hm that doesn't seem to fix it 😢

but the code still runs at batch size 512

[2023-08-25 02:57:23,985] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Traceback (most recent call last):
  File "/home/paperspace/contrastors/test_mem.py", line 17, in <module>
    outputs = model(inputs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn-2.0.9-py3.10-linux-x86_64.egg/flash_attn/models/gpt.py", line 513, in forward
    hidden_states, hidden_states2, residual = layer(
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn-2.0.9-py3.10-linux-x86_64.egg/flash_attn/modules/block.py", line 419, in forward
    hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn-2.0.9-py3.10-linux-x86_64.egg/flash_attn/modules/mha.py", line 624, in forward
    context = self.inner_attn(qkv, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn-2.0.9-py3.10-linux-x86_64.egg/flash_attn/modules/mha.py", line 92, in forward
    return flash_attn_qkvpacked_func(
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn-2.0.9-py3.10-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 516, in flash_attn_qkvpacked_func
    return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs)
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn-2.0.9-py3.10-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 165, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
  File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn-2.0.9-py3.10-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 45, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
sophiawisdom commented 10 months ago

If you invoke the script with `compute-sanitizer python3 script.py`, compute-sanitizer might give you more information on what specifically is going wrong.

zlh1992 commented 7 months ago

> Hm that doesn't seem to fix it 😢
>
> but the code still runs at batch size 512

> [same traceback as above]

Hi! Did you solve this error?