THUDM / LongBench

[ACL 2024] LongBench: A Bilingual, Multitask Benchmark for Long Context Understanding
MIT License

A single A100 40G runs out of memory (OOM) with llama2-7b-chat-4k, but can run chatglm2-6b-32k #43

Closed · fishiu closed this issue 11 months ago

fishiu commented 11 months ago

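Below is the output of running python pred.py --model llama2-7b-chat-4k --e on four A100 40G GPUs (the same thing happens with CUDA_VISIBLE_DEVICES=0). I confirmed that max length is 3500, yet it tries to allocate 140 GiB of GPU memory? The output also shows that FlashAttention is enabled for llama. In the same environment, chatglm2-6b-32k has no memory problems at all. Is that because chatglm uses some special technique?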
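For scale, here is a rough estimate of the KV cache each model keeps at 3,500 tokens. This is a hedged sketch: the layer and head counts come from the public model configs (Llama-2-7B: 32 layers, 32 KV heads of dimension 128; ChatGLM2-6B: 28 layers, 2 multi-query KV groups of dimension 128), not from this thread, and fp16 storage is assumed. ChatGLM2 does use multi-query attention, which makes its cache much smaller, but even Llama's cache stays under 2 GiB, so neither number explains a 140 GiB request.

```python
GIB = 1024 ** 3
MIB = 1024 ** 2

def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_value=2):
    # Factor 2 for K and V; one head_dim vector per KV head, per layer, per token
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value

SEQ_LEN = 3500

# Assumed configs (from the public model configs, not from this thread), fp16
llama2_7b = kv_cache_bytes(num_layers=32, num_kv_heads=32, head_dim=128, seq_len=SEQ_LEN)
chatglm2_6b = kv_cache_bytes(num_layers=28, num_kv_heads=2, head_dim=128, seq_len=SEQ_LEN)

print(f"llama2-7b KV cache:   {llama2_7b / GIB:.2f} GiB")   # about 1.71 GiB
print(f"chatglm2-6b KV cache: {chatglm2_6b / MIB:.1f} MiB")  # about 95.7 MiB
```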

Also, when I changed max length to 1500, I got the error RuntimeError: cu_seqlens_q must have shape (batch_size + 1). Is this expected? I don't understand what this has to do with batch size.

+ python pred.py --model llama2-7b-chat-4k --e
use FlashAttention
Loading checkpoint shards: 100%|██████████| 2/2 [00:48<00:00, 24.45s/it]
Model: llama2-7b-chat-4k, Max Length: 3500

Traceback (most recent call last):
  File "/anonymous/LongBench/pred.py", line 165, in <module>
    preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/LongBench/pred.py", line 78, in get_pred
    output = model.generate(
             ^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/transformers/generation/utils.py", line 1673, in generate
    return self.greedy_search(
           ^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/transformers/generation/utils.py", line 2521, in greedy_search
    outputs = self(
              ^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/accelerate/hooks.py", line 164, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1034, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 922, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/accelerate/hooks.py", line 164, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 672, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/accelerate/hooks.py", line 164, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/LongBench/llama_flash_attn_monkey_patch.py", line 119, in forward
    x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/flash_attn/bert_padding.py", line 118, in unpad_input
    index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anonymous/.env/lib/python3.11/site-packages/flash_attn/bert_padding.py", line 17, in forward
    return torch.gather(
           ^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 140.63 GiB. GPU 0 has a total capacty of 39.56 GiB of which 25.98 GiB is free. Including non-PyTorch memory, this process has 13.57 GiB memory in use. Of the allocated memory 12.99 GiB is allocated by PyTorch, and 91.79 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
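
A rough sanity check on the 140.63 GiB request, as a hedged sketch. It assumes, about llama_flash_attn_monkey_patch.py (whose source is not shown here), that q, k and v are packed into one tensor of shape (batch, seq, 3 · heads · head_dim) before unpad_input is called, and that Llama-2-7B uses 32 heads of dimension 128 in fp16. With a normal 2D padding mask, the gather in unpad_input keeps one row per real token, which comes to tens of MiB. The second case is purely a hypothesis: if a square seq × seq causal mask, whose masked entries are nonzero, reached unpad_input instead, the number of selected rows would grow quadratically and land near 140 GiB. In this thread the problem went away after reinstalling FlashAttention, so treat this only as a way to judge whether an allocation size is plausible.

```python
GIB = 1024 ** 3
MIB = 1024 ** 2

# Assumed Llama-2-7B attention shape and dtype (not taken from this thread)
N_HEADS = 32
HEAD_DIM = 128
BYTES_FP16 = 2

def gathered_bytes(num_rows: int) -> int:
    """Size of the gathered output: one row of 3 * N_HEADS * HEAD_DIM fp16 values per selected index."""
    return num_rows * 3 * N_HEADS * HEAD_DIM * BYTES_FP16

seq_len = 3500

# Expected case: 2D padding mask, batch=1, no padding -> one selected row per token
expected = gathered_bytes(seq_len)

# Hypothetical case: a (seq_len x seq_len) causal mask whose strictly upper
# triangle (the masked positions) holds nonzero values
hypothetical = gathered_bytes(seq_len * (seq_len - 1) // 2)

print(f"2D padding mask:    {expected / MIB:.1f} MiB")      # about 82.0 MiB
print(f"square causal mask: {hypothetical / GIB:.1f} GiB")  # about 140.2 GiB
```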

bys0318 commented 11 months ago

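We have only tested pred.py on a single A100, and none of the models run out of GPU memory there (not even at 32k). Could you try with a single GPU? Also, please make sure the patch replace_llama_attn_with_flash_attn() has actually been executed.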
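One way to confirm that a monkey patch took effect is to check which module the patched method now comes from. This is a minimal sketch: is_patched is a hypothetical helper, and the commented usage assumes the patch replaces LlamaAttention.forward and is imported as the top-level module llama_flash_attn_monkey_patch (the traceback above, which passes through forward in llama_flash_attn_monkey_patch.py, is consistent with that).

```python
import types

def is_patched(cls: type, attr: str, module_name: str) -> bool:
    """Return True if cls.<attr> is a plain function defined in module `module_name`."""
    func = getattr(cls, attr, None)
    return isinstance(func, types.FunctionType) and func.__module__ == module_name

# Hypothetical usage (assumes transformers and LongBench's patch module are importable):
# from transformers.models.llama.modeling_llama import LlamaAttention
# from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
# replace_llama_attn_with_flash_attn()
# print(is_patched(LlamaAttention, "forward", "llama_flash_attn_monkey_patch"))
```

Checking the defining module rather than comparing against a saved reference means the check works even when the original method was never captured before patching.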

fishiu commented 11 months ago

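Thanks for the reply! I later found that llama actually runs once FlashAttention is turned off, so the problem was my FlashAttention installation. I haven't found the exact cause yet, but after reinstalling FlashAttention everything works.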

Also, why does changing max length to 1500 cause the error RuntimeError: cu_seqlens_q must have shape (batch_size + 1)? Is this expected, and what does it have to do with batch size?

bys0318 commented 11 months ago

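Sorry, we have never run into this error. Changing max_length should not introduce a bug. cu_seqlens_q is a parameter used inside FlashAttention, so perhaps something is wrong with your FlashAttention installation? Try reinstalling FlashAttention.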
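On the batch-size question: in FlashAttention's variable-length interface, the tokens of every sequence in a batch are packed into one tensor with padding removed, and cu_seqlens holds the cumulative sequence lengths with a leading 0. It therefore has batch_size + 1 entries, and sequence i occupies packed rows cu_seqlens[i] to cu_seqlens[i+1]. The pure-Python sketch below mirrors what flash_attn's unpad_input derives from a 2D (batch, seq_len) padding mask; the function unpad is a hypothetical illustration, not the library's code. A mask with an unexpected shape yields a cu_seqlens without batch_size + 1 entries, which is one way this check could fail.

```python
from itertools import accumulate

def unpad(mask):
    """Mimic unpad_input for a 2D padding mask (list of rows, 1 = real token, 0 = padding).

    Returns (indices, cu_seqlens, max_seqlen):
      indices    - flat positions of real tokens in the (batch * seq_len) layout
      cu_seqlens - [0, len_0, len_0 + len_1, ...], length batch_size + 1
      max_seqlen - longest real sequence in the batch
    """
    seq_len = len(mask[0])
    seqlens = [sum(row) for row in mask]
    indices = [b * seq_len + t
               for b, row in enumerate(mask)
               for t, keep in enumerate(row) if keep]
    cu_seqlens = [0] + list(accumulate(seqlens))
    return indices, cu_seqlens, max(seqlens)

# Batch of 2 sequences, padded to length 4
mask = [[1, 1, 1, 0],
        [1, 1, 0, 0]]
indices, cu_seqlens, max_seqlen = unpad(mask)
print(indices)      # [0, 1, 2, 4, 5]
print(cu_seqlens)   # [0, 3, 5]
print(max_seqlen)   # 3
```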

fishiu commented 11 months ago

That was it: after re-running, everything works. Thanks!