zanussbaum opened 10 months ago
Probably because the tensors are larger than 2GB. We use 32-bit integers for indexing, so tensors larger than 2GB might be indexed incorrectly. I didn't think we'd ever run attention with tensors larger than 2GB (since you'd probably OOM elsewhere anyway).
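To make the 2GB explanation concrete, here is a minimal sketch that computes the byte size of a packed QKV tensor of shape `(batch, seqlen, 3, nheads, headdim)` (the layout `flash_attn_qkvpacked_func` takes) and checks whether it still fits in a signed 32-bit integer. The helper names, the fp16/bf16 assumption of 2 bytes per element, and the use of the signed 32-bit maximum as the cutoff are assumptions for illustration; the exact threshold depends on how the kernel computes its offsets. The shapes below are hypothetical, not taken from this issue.

```python
INT32_MAX = 2**31 - 1  # largest value a signed 32-bit integer can hold


def qkv_bytes(batch, seqlen, nheads, headdim, bytes_per_elem=2):
    """Size in bytes of a packed QKV tensor of shape (batch, seqlen, 3, nheads, headdim).

    bytes_per_elem=2 assumes fp16/bf16 inputs.
    """
    return batch * seqlen * 3 * nheads * headdim * bytes_per_elem


def exceeds_int32(batch, seqlen, nheads, headdim, bytes_per_elem=2):
    """True if the tensor's byte size no longer fits in a signed 32-bit integer."""
    return qkv_bytes(batch, seqlen, nheads, headdim, bytes_per_elem) > INT32_MAX


# Hypothetical shapes: seqlen 2048, 16 heads, head dimension 64.
print(qkv_bytes(171, 2048, 16, 64))      # 2151677952
print(exceeds_int32(170, 2048, 16, 64))  # False
print(exceeds_int32(171, 2048, 16, 64))  # True
```

With these made-up dimensions each sample adds about 12.6MB, so the cutoff is crossed between batch sizes 170 and 171; with the real model's dimensions the crossover lands elsewhere.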
Thanks Tri, that's what I figured. If I wanted to try and see whether we do OOM, is there a way to change the indexing precision from 32-bit to 64-bit?
You can try changing this line from `uint32_t` to `uint64_t`, but I've never tried. Lmk if that works.
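One general reason a single-line type change may not be enough (a C/C++ pitfall, not a diagnosis of the FlashAttention source): in `uint64_t off = a * b;` with `a` and `b` both `uint32_t`, the multiplication is still done in 32 bits and only the already-wrapped result is widened. The sketch below emulates both orderings in Python by masking to the relevant bit width; the function names and the index/stride values are hypothetical.

```python
MASK32 = (1 << 32) - 1
MASK64 = (1 << 64) - 1


def offset_widen_after(index, stride):
    """Emulate `uint64_t off = (uint32_t)index * (uint32_t)stride;`.

    The product is computed in 32 bits and wraps before being stored in 64 bits.
    """
    return ((index & MASK32) * (stride & MASK32)) & MASK32


def offset_widen_before(index, stride):
    """Emulate `uint64_t off = (uint64_t)index * stride;`.

    One operand is widened first, so the product is computed in 64 bits.
    """
    return ((index & MASK64) * (stride & MASK64)) & MASK64


# Hypothetical: offset of batch row 700 with a per-sample stride of 6,291,456 elements.
batch_idx, batch_stride = 700, 6_291_456
print(offset_widen_before(batch_idx, batch_stride))  # 4404019200 (exact)
print(offset_widen_after(batch_idx, batch_stride))   # 109051904 (wrapped, wrong address)
```

For small offsets both versions agree, which is consistent with the code working at smaller batch sizes; only once the product passes 2**32 does the 32-bit version silently point somewhere else.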
Hm, that doesn't seem to fix it 😢, but the code still runs at batch size 512.
```
[2023-08-25 02:57:23,985] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Traceback (most recent call last):
File "/home/paperspace/contrastors/test_mem.py", line 17, in <module>
outputs = model(inputs)
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn-2.0.9-py3.10-linux-x86_64.egg/flash_attn/models/gpt.py", line 513, in forward
hidden_states, hidden_states2, residual = layer(
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn-2.0.9-py3.10-linux-x86_64.egg/flash_attn/modules/block.py", line 419, in forward
hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn-2.0.9-py3.10-linux-x86_64.egg/flash_attn/modules/mha.py", line 624, in forward
context = self.inner_attn(qkv, **kwargs)
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn-2.0.9-py3.10-linux-x86_64.egg/flash_attn/modules/mha.py", line 92, in forward
return flash_attn_qkvpacked_func(
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn-2.0.9-py3.10-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 516, in flash_attn_qkvpacked_func
return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs)
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn-2.0.9-py3.10-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 165, in forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
File "/home/paperspace/contrastors/env/lib/python3.10/site-packages/flash_attn-2.0.9-py3.10-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 45, in _flash_attn_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```
If you invoke the script with `compute-sanitizer python3 script.py`, compute-sanitizer might give you more information on what specifically is going wrong.
Hi! Did you solve this error?
I'm getting a `RuntimeError: CUDA error: an illegal memory access was encountered` using FlashAttention with a GPT-NeoX-esque model. I
returns correctly, but if I change `batch_size = 768`, then I get the error. It seems like it OOMs after the 3rd attention layer. I assumed I would get a CUDA OOM but wanted to raise this in case this is unexpected, similar to #124.
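If the failure does come from the 32-bit indexing limit, one possible workaround is to run attention on smaller slices of the batch and concatenate the outputs; attention does not mix elements across the batch dimension, so slicing along it should not change the math. The sketch below only plans the slice boundaries. The helper name `batch_chunks`, the per-sample byte count, and the 2**31-byte limit are assumptions for illustration, and the commented line shows one hypothetical way to combine it with `flash_attn_qkvpacked_func` and `torch.cat`.

```python
def batch_chunks(batch_size, per_sample_bytes, limit_bytes=2**31):
    """Split range(batch_size) into (start, end) slices whose tensors stay below limit_bytes."""
    if per_sample_bytes >= limit_bytes:
        raise ValueError("a single sample already exceeds the limit; chunking the batch cannot help")
    per_chunk = (limit_bytes - 1) // per_sample_bytes  # max samples per slice
    return [(start, min(start + per_chunk, batch_size))
            for start in range(0, batch_size, per_chunk)]


# Hypothetical: 12,582,912 bytes per sample (seqlen 2048, 16 heads, headdim 64, fp16).
print(batch_chunks(768, 12_582_912))
# [(0, 170), (170, 340), (340, 510), (510, 680), (680, 768)]

# Hypothetical use with a packed qkv tensor of shape (batch, seqlen, 3, nheads, headdim):
# out = torch.cat([flash_attn_qkvpacked_func(qkv[s:e]) for s, e in batch_chunks(...)], dim=0)
```

Chunking trades one large kernel launch for several smaller ones; it sidesteps the offset overflow but does nothing about genuine memory pressure elsewhere in the model.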
This is the relevant environment info: