hpcaitech / ColossalAI

Making large AI models cheaper, faster and more accessible
https://www.colossalai.org
Apache License 2.0

[BUG]: OOM during llama2 pretraining with flashattention and PP #5549

Open insujang opened 8 months ago

insujang commented 8 months ago

🐛 Describe the bug

I understand that this error comes out of the flash attention software stack, but there seems to be no related issue except https://github.com/Dao-AILab/flash-attention/issues/590, so I am opening an issue here anyway. The problem also occurs with flash-attn 2.0.5.

Using pipeline parallelism in HybridParallelPlugin (no ZeRO) together with flash attention for Llama2 results in OOM.

When I try to run examples/language/llama2/pretrain.py, the padding handling of the inputs (unpad_input) triggers an OOM. Without flash attention it works fine.

plugin = HybridParallelPlugin(tp_size=2, pp_size=2, ...)  # all the other args are the same as in the example
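Spelled out a bit more, this is roughly the configuration I am running (a sketch; the ZeRO / flash attention / bf16 values reflect my setup, and the remaining arguments are simply the example script's defaults):

from colossalai.booster.plugin import HybridParallelPlugin

# Rough sketch of my setup; remaining arguments are the example-script defaults.
plugin = HybridParallelPlugin(
    tp_size=2,                    # tensor parallelism
    pp_size=2,                    # pipeline parallelism
    zero_stage=0,                 # "No-ZeRO"
    enable_flash_attention=True,
    precision="bf16",
)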

Note that if you set pp_size=1, you hit the "cache only has 0 layers" exception (#5410) even before reaching this OOM :) So there seems to be another bug in the Llama2 forward pass with attention parallelism. Just a side note.

PYTHONPATH=/path/to/colossalai/examples/language/llama2 torchrun --standalone --nproc-per-node 4 pretrain.py -p hybrid_parallel -a -g -x bf16 -o /tmp/llama_checkpoint
  File "/data/insujang/colossalai/examples/language/llama2/attn.py", line 174, in attention_forward
    q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
  File "/opt/conda/lib/python3.10/site-packages/flash_attn-2.5.6-py3.10-linux-x86_64.egg/flash_attn/bert_padding.py", line 119, in unpad_input
    index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/lib/python3.10/site-packages/flash_attn-2.5.6-py3.10-linux-x86_64.egg/flash_attn/bert_padding.py", line 17, in forward
    return torch.gather(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 59.67 GiB. GPU 1 has a total capacity of 44.35 GiB of which 34.03 GiB is free. Process 1325526 has 10.31 GiB memory in use. Of the allocated memory 9.76 GiB is allocated by PyTorch, and 40.83 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

https://github.com/hpcaitech/ColossalAI/blob/7e0ec5a85c73fcc5666b9d218e43865141587dde/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py#L174
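For context, unpad_input in flash-attn is documented to take a 2D [batch, seq_len] padding mask. A minimal sketch of a call with the expected shapes (my own illustration, not code from the patch):

import torch
from flash_attn.bert_padding import unpad_input

# Shapes assume Llama2-7B attention heads already split across tp_size=2.
batch, seq_len, num_heads, head_dim = 1, 4096, 16, 128
q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")

# unpad_input expects a [batch, seq_len] mask (1 = keep, 0 = padding),
# rather than the 4D [1, 1, 4096, 4096] mask shown below.
key_padding_mask = torch.ones(batch, seq_len, dtype=torch.bool, device="cuda")

# flash-attn 2.5.6 returns four values here, as in the traceback above.
q_unpad, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
# q_unpad: (batch * seq_len, num_heads, head_dim) when nothing is padded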

I think this might be related to the size of attention_mask, but I am not sure:

# from flash_attn/bert_padding.py
def unpad_input(hidden_states, attention_mask):
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # attention_mask.shape = torch.Size([1, 1, 4096, 4096])
    # indices.shape=torch.Size([15642705])
    ...
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), # Error here
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )

# index_first_axis calls IndexFirstAxis.forward()
class IndexFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ...
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) # Error here
        ).reshape(-1, *other_shape)

The attention_mask here is created at: https://github.com/hpcaitech/ColossalAI/blob/fd4444058f9ebd5f99cfc60e2e5bf69a7dd38d73/colossalai/shardformer/modeling/llama.py#L101-L103
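A rough back-of-the-envelope check (my own arithmetic, assuming the 7B configuration with tp_size=2, so q on each rank has 16 heads x 128 head_dim = 2048 features in bf16) reproduces the 59.67 GiB figure from the error message:

# Assumptions: Llama2-7B, tp_size=2 -> 16 heads * 128 head_dim = 2048 features, bf16.
num_indices = 15_642_705                    # indices.shape[0] from the 4D mask above
features_per_rank = 16 * 128                # 2048
bytes_per_elem = 2                          # bf16

alloc_gib = num_indices * features_per_rank * bytes_per_elem / 2**30
print(f"{alloc_gib:.2f} GiB")               # ~59.67 GiB, matching the OOM message

# With a proper 2D [batch, seq_len] padding mask and no padding, num_indices
# would be batch * 4096, i.e. roughly 16 MiB per sequence instead.

So the allocation size is consistent with the gather being driven by the number of nonzero entries in the 4D mask (~15.6M) rather than by batch * seq_len, which is why I suspect the mask shape.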

I would appreciate it if you could check whether this is reproducible and what the cause is.

Environment

4x NVIDIA A40 (48 GB each)
PyTorch 2.2.1 | CUDA 12.1
ColossalAI branch: feature/update-transformers
transformers 4.36.0
flash-attn 2.5.6

insujang commented 8 months ago

@wangbluo Could you please help me solve this issue? Thanks

wangbluo commented 8 months ago

> @wangbluo Could you please help me solve this issue? Thanks

Hi, could you please share the model size you are using?

insujang commented 8 months ago

I used the 7B configuration.