Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

flash-attention v2 with activation checkpointing (no_reentrant) raises RuntimeError #341

Open wjfwzzc opened 1 year ago

wjfwzzc commented 1 year ago

Using flash_attn_varlen_qkvpacked_func together with CheckpointImpl.NO_REENTRANT raises the RuntimeError below:

Traceback (most recent call last):

> File "/opt/tiger/antelope/train.py", line 718, in <module>
    main()
    └ <function main at 0x7f385c2679d0>

  File "/opt/tiger/antelope/train.py", line 703, in main
    train(
    └ <function train at 0x7f385c267790>

  File "/opt/tiger/antelope/train.py", line 503, in train
    grad_scaler.scale(loss).backward()
    │           │     └ tensor(3.9735, device='cuda:4', grad_fn=<DivBackward0>)
    │           └ <function ShardedGradScaler.scale at 0x7f386a1dcca0>
    └ <torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler object at 0x7f38356e2e80>

  File "/usr/local/lib/python3.9/dist-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
    │     │        └ <function backward at 0x7f3877d575e0>
    │     └ <module 'torch.autograd' from '/usr/local/lib/python3.9/dist-packages/torch/autograd/__init__.py'>
    └ <module 'torch' from '/usr/local/lib/python3.9/dist-packages/torch/__init__.py'>

  File "/usr/local/lib/python3.9/dist-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    │        │                 └ <method 'run_backward' of 'torch._C._EngineBase' objects>
    │        └ <torch._C._EngineBase object at 0x7f3878059350>
    └ <class 'torch.autograd.variable.Variable'>

  File "/usr/local/lib/python3.9/dist-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
           │       │      └ (tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           │       │                   0.0000e+00,  0.0000e+00],
           │       │                 [ 0.0000e+...
           │       └ <torch.autograd.function.FlashAttnVarlenQKVPackedFuncBackward object at 0x7f2bd62bf220>
           └ <function FlashAttnVarlenQKVPackedFunc.backward at 0x7f38500424c0>

  File "/usr/local/lib/python3.9/dist-packages/flash_attn/flash_attn_interface.py", line 146, in backward
    q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
                                                       │   └ <attribute 'saved_tensors' of 'torch._C._FunctionBase' objects>
                                                       └ <torch.autograd.function.FlashAttnVarlenQKVPackedFuncBackward object at 0x7f2bd62bf220>

RuntimeError: !grad_accumulator_.expired() INTERNAL ASSERT FAILED at "../torch/csrc/autograd/saved_variable.cpp":227, please report a bug to PyTorch. No grad accumulator for a saved leaf

I'm not familiar with autograd in PyTorch, but the error looks similar to https://github.com/pytorch/pytorch/issues/103726 and https://github.com/pytorch/pytorch/issues/90481.

PyTorch version: v2.0.1
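For reference, the failing pattern looks roughly like this (a hypothetical minimal sketch with made-up shapes, not the actual training code; it assumes a CUDA device and flash-attn v2 installed):

```python
# Hypothetical minimal sketch; shapes and sequence lengths are made up.
import torch
from torch.utils.checkpoint import checkpoint
from flash_attn import flash_attn_varlen_qkvpacked_func

def attn_block(qkv, cu_seqlens, max_seqlen):
    # qkv: (total_tokens, 3, nheads, headdim) in fp16 on CUDA
    return flash_attn_varlen_qkvpacked_func(
        qkv, cu_seqlens, max_seqlen, dropout_p=0.0, causal=True
    )

total, nheads, headdim = 512, 8, 64
qkv = torch.randn(total, 3, nheads, headdim, device="cuda",
                  dtype=torch.float16, requires_grad=True)
cu_seqlens = torch.tensor([0, 256, 512], device="cuda", dtype=torch.int32)

# Non-reentrant checkpointing (the same code path CheckpointImpl.NO_REENTRANT uses)
out = checkpoint(attn_block, qkv, cu_seqlens, 256, use_reentrant=False)
out.sum().backward()  # raises the RuntimeError shown above
```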

qiuyang163 commented 1 year ago

Same question.

Legend94rz commented 1 year ago

@wjfwzzc Excuse me, I'm curious about your traceback formatting, though it's unrelated to this issue. How do you get Python to print a stack trace like that?

wjfwzzc commented 1 year ago

> @wjfwzzc Excuse me, I'm curious about your traceback formatting, though it's unrelated to this issue. How do you get Python to print a stack trace like that?

https://github.com/Delgan/loguru
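Roughly, loguru's `diagnose` option is what annotates each frame with variable values. A minimal sketch:

```python
import sys
from loguru import logger

# backtrace=True extends the trace beyond the point where the exception is caught;
# diagnose=True prints the values of the variables in each frame.
logger.remove()
logger.add(sys.stderr, backtrace=True, diagnose=True)

@logger.catch  # log any uncaught exception with the annotated traceback
def main():
    raise RuntimeError("demo")

if __name__ == "__main__":
    main()
```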

qiuyang163 commented 1 year ago

@Legend94rz I train Stable Diffusion and this problem appeared. I think if you train any LLM or Stable Diffusion model with gradient_checkpointing enabled, you will see the same problem.

wjfwzzc commented 1 year ago

Ping @tridao, would you fix this issue or provide a workaround? FSDP + activation checkpointing is a fairly common setting for large transformer training.
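For context, the setting in question is roughly the standard FSDP + non-reentrant checkpointing wrapper (a sketch; `TransformerBlock` is a placeholder for the real block class, and torch.distributed is assumed to be initialized, e.g. via torchrun):

```python
import functools
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

class TransformerBlock(nn.Module):  # placeholder for the model's real block class
    ...

# Assumes torch.distributed.init_process_group() has already been called.
model = FSDP(nn.Sequential(*[TransformerBlock() for _ in range(4)]))

# Wrap each block with non-reentrant activation checkpointing.
non_reentrant_wrapper = functools.partial(
    checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
)
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=lambda m: isinstance(m, TransformerBlock),
)
```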

tridao commented 1 year ago

I'm not familiar with FSDP; can you post a short script to replicate?

Is the issue just activation checkpointing? Or is FSDP relevant?

152334H commented 10 months ago

Does anyone know if this has been fixed since then?

I think the issue applies to a raw torch.utils.checkpoint.checkpoint(..., use_reentrant=False) as well.
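For example, something along these lines (a hypothetical sketch with made-up shapes, assuming a CUDA device and flash-attn installed):

```python
# Hypothetical sketch of the raw-checkpoint variant; shapes are made up.
import torch
from torch.utils.checkpoint import checkpoint
from flash_attn import flash_attn_qkvpacked_func

# qkv: (batch, seqlen, 3, nheads, headdim) in fp16 on CUDA
qkv = torch.randn(2, 128, 3, 8, 64, device="cuda",
                  dtype=torch.float16, requires_grad=True)

def block(x):
    return flash_attn_qkvpacked_func(x, dropout_p=0.0, causal=True)

out = checkpoint(block, qkv, use_reentrant=False)
out.sum().backward()  # same "!grad_accumulator_.expired()" RuntimeError
```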

wangjiongw commented 2 months ago

> Does anyone know if this has been fixed since then?
>
> I think the issue applies to a raw torch.utils.checkpoint.checkpoint(..., use_reentrant=False) as well.

I think the problem is still there. I hit the same error with the same function as you.