Using pp in HybridParallelPlugin (No-ZeRO) and flash attention together for Llama2 results in OOM

🐛 Describe the bug

I understand that this error comes out of the flash-attention software stack, but since there seems to be no related issue there other than https://github.com/Dao-AILab/flash-attention/issues/590, I am opening an issue here anyway. The problem also happens with flash-attn 2.0.5.

When I try to run examples/language/llama2/pretrain.py with flash attention enabled, adding the padding back to the inputs results in OOM. Without flash attention it works fine. The plugin is configured as follows:
```python
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    # all the other args are the same as in the example
)
```
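For completeness, here is roughly how the rest of the wiring looks in my run. This is a minimal sketch following examples/language/llama2/pretrain.py; the tiny stand-in model and the exact launch/boost calls are my own simplification, not the example's literal code:

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import LlamaConfig, LlamaForCausalLM

# Distributed init via torchrun, as in the example script.
colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    microbatch_size=1,  # illustrative value; the real run keeps the example's args
    # all the other args are the same as in the example
)
booster = Booster(plugin=plugin)

# Stand-in model/optimizer only to keep the sketch self-contained;
# the real run builds Llama2-7B and the dataloader as in pretrain.py.
model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=4))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model, optimizer, *_ = booster.boost(model, optimizer)
```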
Note that if you set pp_size=1, you hit a "cache only has 0 layers" exception (#5410) even before reaching this OOM :) So there seems to be another bug in the Llama2 forward with attention parallelism. Just a sidenote.

Here is the traceback:
File "/data/insujang/colossalai/examples/language/llama2/attn.py", line 174, in attention_forward
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
File "/opt/conda/lib/python3.10/site-packages/flash_attn-2.5.6-py3.10-linux-x86_64.egg/flash_attn/bert_padding.py", line 119, in unpad_input
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/lib/python3.10/site-packages/flash_attn-2.5.6-py3.10-linux-x86_64.egg/flash_attn/bert_padding.py", line 17, in forward
return torch.gather(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 59.67 GiB. GPU 1 has a total capacity of 44.35 GiB of which 34.03 GiB is free. Process 1325526 has 10.31 GiB memory in use. Of the allocated memory 9.76 GiB is allocated by PyTorch, and 40.83 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
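To make the allocation site easier to see, here is my condensed reading of the flash-attn code in the traceback (flash_attn/bert_padding.py's unpad_input and index_first_axis). This is a simplified paraphrase, not the verbatim library source:

```python
# Condensed sketch of flash_attn.bert_padding.unpad_input + index_first_axis,
# only to show where the large allocation happens.
import torch
from einops import rearrange, repeat


def unpad_input_sketch(hidden_states, attention_mask):
    # One index per *nonzero* entry of the mask. With a 2D (batch, seqlen)
    # padding mask this is at most batch * seqlen indices.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()

    # Flatten (batch, seqlen, ...) -> (batch * seqlen, ...)
    flat = rearrange(hidden_states, "b s ... -> (b s) ...")
    d = flat.shape[1:].numel()

    # index_first_axis: gather the selected rows. The output has
    # len(indices) rows, and the int64 index tensor built by repeat() is
    # len(indices) * d * 8 bytes, so everything scales with len(indices).
    rows = torch.gather(
        rearrange(flat, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=d)
    ).reshape(-1, *flat.shape[1:])
    return rows, indices
```

The failing allocation in the traceback is that torch.gather, so the number of nonzero entries in the mask that reaches unpad_input is the quantity that matters.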
The corresponding attention_forward code: https://github.com/hpcaitech/ColossalAI/blob/7e0ec5a85c73fcc5666b9d218e43865141587dde/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py#L174
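To get a feel for the magnitudes involved, here is a back-of-the-envelope sketch under assumed Llama2-7B-ish shapes (the batch size, sequence length, and head dimensions below are illustrative assumptions, not values read out of the failing run):

```python
# Rough size estimate for the gather in index_first_axis under assumed shapes.
batch, seqlen = 2, 4096
num_heads, head_dim = 32, 128   # q assumed to be (batch, seqlen, num_heads, head_dim)
elem_bytes = 2                  # fp16 / bf16

per_row_bytes = num_heads * head_dim * elem_bytes

# Expected case: a 2D (batch, seqlen) padding mask -> at most batch * seqlen rows.
rows_2d = batch * seqlen
print(f"2D mask: ~{rows_2d * per_row_bytes / 2**30:.3f} GiB")  # ~0.06 GiB

# Suspected bad case: a mask with an extra sequence dimension reaching
# unpad_input, e.g. (batch, 1, seqlen, seqlen) -> up to batch * seqlen**2 rows.
rows_4d = batch * seqlen * seqlen
print(f"4D mask: ~{rows_4d * per_row_bytes / 2**30:.0f} GiB")  # hundreds of GiB
```

The 59.67 GiB in the error does not match these assumed numbers exactly, but the point stands: a mask with an extra sequence dimension pushes the gather from the MiB range into the tens-to-hundreds-of-GiB range.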
I think this might be related to the size of attention_mask, but I am not sure. attention_mask is created here: https://github.com/hpcaitech/ColossalAI/blob/fd4444058f9ebd5f99cfc60e2e5bf69a7dd38d73/colossalai/shardformer/modeling/llama.py#L101-L103

I would appreciate it if you could check whether this is reproducible and what the reason is.
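If it helps, this is the kind of check I have in mind: an untested debug helper (names follow the traceback above) that could be called right before the unpad_input call in attention_forward to confirm whether the mask shape explains the allocation:

```python
import torch


def report_unpad_alloc(q: torch.Tensor, key_padding_mask: torch.Tensor) -> None:
    """Estimate the memory the gather in index_first_axis would need.

    Untested sketch; call right before
    unpad_input(hidden_states=q, attention_mask=key_padding_mask).
    """
    n_rows = int(torch.count_nonzero(key_padding_mask))     # rows the gather produces
    per_row = q.shape[2:].numel()                            # elements per (b, s) row of q
    out_gib = n_rows * per_row * q.element_size() / 2**30    # gather output
    idx_gib = n_rows * per_row * 8 / 2**30                   # int64 index tensor from repeat()
    print(
        f"key_padding_mask shape={tuple(key_padding_mask.shape)}, "
        f"q shape={tuple(q.shape)}, "
        f"estimated gather output ~{out_gib:.2f} GiB, index tensor ~{idx_gib:.2f} GiB"
    )
```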
Environment
4x A40 (48 GB)
PyTorch 2.2.1 | CUDA 12.1
ColossalAI branch: feature/update-transformers
transformers 4.36.0
flash-attn 2.5.6