Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Function 'FlashAttnQKVPackedFuncBackward' returned nan values in its 0th output #41

Open Karbo123 opened 2 years ago

Karbo123 commented 2 years ago

I found that my training loss became NaN after training for many epochs, but the model params were all still finite (i.e. torch.isfinite) after the NaN loss occurred (I checked this by loading the saved checkpoint file from disk). I tried to resume the model from that checkpoint file (i.e. the checkpoint saved after the NaN happened); at the very beginning the training seemed to be okay, but after several epochs the NaN loss occurred again.

I tried setting torch.autograd.set_detect_anomaly(True) and CUDA_LAUNCH_BLOCKING=1 to find out what was happening, and the result showed that FlashAttnQKVPackedFuncBackward returned NaN values in its output.
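For reference, a minimal sketch of this debugging setup (the script name train.py is a placeholder; this is not the original training script):

# Launch with synchronous kernel launches so errors are attributed to the right op:
#   CUDA_LAUNCH_BLOCKING=1 python train.py
import torch

# Make autograd record forward-call tracebacks and report which Function produced non-finite gradients.
torch.autograd.set_detect_anomaly(True)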

It is very strange, because the model params are all finite and the NaN only happens after several epochs, so the training data should be okay. But if the backward pass produces NaN values, why don't the model params contain any NaN values? BTW, I didn't use any gradient clipping.

I also checked all the model params, and they seem to be fine. The maximum absolute value among all the params (except the non-learnable frequency weights) is 1.415, which is not too large.

Do you have any suggestions on this? Is FlashAttnQKVPackedFunc numerically unstable? Thank you very much! Looking forward to your reply.

tridao commented 2 years ago

Thanks for the report. The function should be numerically stable.

Which commit of FlashAttention are you using? On which GPU? What are the dimensions of the attention?

In order for me to reproduce the issue, could you save the arguments that caused the NaN and send them to me? That would be very helpful. For example, you can add these lines right before the return statement of the backward function to save the tensors to the file nan_repro.pt:

        if dqkv.isnan().any():
            state_dict = {'dout': dout, 'qkv': qkv, 'out': out, 'softmax_lse': softmax_lse,
                          'cu_seqlens': cu_seqlens, 'max_seqlen': ctx.max_seqlen,
                          'dropout_p': ctx.dropout_p, 'softmax_scale': ctx.softmax_scale,
                          'causal': ctx.causal, 'rng_state': rng_state}
            torch.save(state_dict, 'nan_repro.pt')
            breakpoint()
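Once such a dump exists, a quick way to inspect it might look like the following (a sketch that assumes the file name and dictionary keys from the snippet above):

import torch

# Load the tensors saved by the snippet above and check which inputs already contain NaNs.
state = torch.load("nan_repro.pt")
for name in ("dout", "qkv", "out", "softmax_lse"):
    t = state[name]
    print(name, tuple(t.shape), "has_nan:", bool(t.isnan().any()))
print("cu_seqlens:", state["cu_seqlens"])
print("max_seqlen:", state["max_seqlen"], "dropout_p:", state["dropout_p"], "causal:", state["causal"])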

Thanks for your help!

vadimcn commented 2 years ago

I was having the same problem... Long story short, it looks like dq, dk and dv need to be zeroed out, since they are used as accumulators. However, flash_attn_interface currently allocates them via torch.empty_like. Setting them to 0 before calling flash_attn_cuda.bwd seems to have resolved the issue.
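For concreteness, a sketch of that workaround (the tensor names and shapes are illustrative stand-ins for what flash_attn_interface allocates; this is a local patch idea, not an upstream change):

import torch

# Illustrative shapes only: (total_tokens, nheads, headdim)
q = torch.randn(10, 8, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Workaround: zero-initialize the gradient buffers instead of leaving them uninitialized,
# in case the backward kernel treats them as accumulators or leaves some rows untouched.
dq = torch.zeros_like(q)  # was: torch.empty_like(q)
dk = torch.zeros_like(k)  # was: torch.empty_like(k)
dv = torch.zeros_like(v)  # was: torch.empty_like(v)
# ...these buffers would then be passed to flash_attn_cuda.bwd exactly as before.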

tridao commented 2 years ago

I'm very curious about this. I think all of the values in dq, dk, dv should be overwritten during the execution of the backward pass.

The only problematic scenario I can imagine is when q, k, v are longer than what cu_seqlens indicates. For example, if q has shape (10, nheads, headdim), where 10 is supposed to be the total batch * seqlen, then dq is allocated as torch.empty_like(q). If e.g. cu_seqlens = [0, 5, 8] (which says the batch has 2 sequences, the 1st stored at indices 0 -> 4 and the 2nd at indices 5 -> 7), then during execution only values at indices 0 -> 7 are written, and values at indices 8 -> 9 are never overwritten.
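A small sketch of that scenario (not library code): with 10 rows but cu_seqlens ending at 8, a buffer allocated with torch.empty_like only has rows 0 -> 7 written, and rows 8 -> 9 keep whatever uninitialized values (possibly NaN/Inf) happened to be in that memory:

import torch

total, nheads, headdim = 10, 8, 64  # q is longer than the sequences described by cu_seqlens
q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")  # 2 sequences: rows 0-4 and 5-7

dq = torch.empty_like(q)         # uninitialized, as in the interface
n_written = int(cu_seqlens[-1])  # the kernel only writes rows [0, n_written)
dq[:n_written] = 0.0             # stand-in for the values the backward kernel actually writes

# Rows 8-9 were never written, so they may hold arbitrary bits, including NaN/Inf.
print(bool(dq[n_written:].isfinite().all()))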

@vadimcn could you say more about your setting? Does it fall into this case?

vadimcn commented 2 years ago

The only problematic scenario I could imagine is when q, k, v are longer than what cu_seqlens indicate.

Yes, this was indeed the case :facepalm:. Thanks for the hint!

otto-dev commented 2 months ago

@tridao I can reproduce this error after a very short time using the script below. Is this a user-side usage error?

import torch
import torch.optim as optim
import torch.nn.functional as F
from functools import partial
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

embed_dim = 256 # !!! change this to 64 and the error will not be observable !!!
batch_size = 16
num_heads = 8
seq_length = 512
dim_feedforward = 1024
learning_rate = 0.01
device = torch.device("cuda")
torch.set_default_dtype(torch.bfloat16)
torch.autograd.set_detect_anomaly(True)

# Initialize model
model = Block( # TransformerEncoderLayer
        embed_dim,
        mixer_cls=partial(
            MHA,
            num_heads=num_heads,
            use_flash_attn=True,
            rotary_emb_dim=0,
        ),
        mlp_cls=partial(Mlp, hidden_features=dim_feedforward),
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        prenorm=False,
    ).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
inputs = torch.full((batch_size, seq_length, embed_dim), 1000.0, device=device)

# Training loop
for i in range(999999):
    print(f'Iteration {i + 1}')
    optimizer.zero_grad()
    output = model(inputs)
    loss = output.mean()
    loss.backward()
    optimizer.step()

Output:

/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/ops/triton/layer_norm.py:959: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(
/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/ops/triton/layer_norm.py:1018: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, dout, *args):
Iteration 1
Iteration 2
Iteration 3
Iteration 4
Iteration 5
Iteration 6
Iteration 7
Iteration 8
Iteration 9
Iteration 10
Iteration 11
Iteration 12
Iteration 13
Iteration 14
Iteration 15
/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/autograd/graph.py:825: UserWarning: Error detected in FlashAttnQKVPackedFuncBackward. Traceback of forward call that caused the error:
  File "/home/otto/Development/temp/test.py", line 41, in <module>
    output = model(inputs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/modules/block.py", line 195, in forward
    mixer_out = self.mixer(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 670, in forward
    context = self.inner_attn(qkv, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 122, in forward
    return flash_attn_qkvpacked_func(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 729, in flash_attn_qkvpacked_func
    return FlashAttnQKVPackedFunc.apply(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:110.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/home/otto/Development/temp/test.py", line 43, in <module>
    loss.backward()
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'FlashAttnQKVPackedFuncBackward' returned nan values in its 0th output.

Latest published version of flash-attention.

tridao commented 2 months ago

Does the same thing happen if you use the standard implementation of attention? I.e., try use_flash_attn=False.
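In the repro script above, that corresponds to a one-line change in the mixer configuration (sketch):

from functools import partial
from flash_attn.modules.mha import MHA

num_heads = 8  # as in the repro script

# Same mixer as before, but selecting the standard (non-flash) attention path for comparison.
mixer_cls = partial(
    MHA,
    num_heads=num_heads,
    use_flash_attn=False,  # fall back to the reference attention implementation
    rotary_emb_dim=0,
)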

otto-dev commented 2 months ago

use_flash_attn=False

Then it works fine, @tridao.

Oktai15 commented 1 month ago

@otto-dev @tridao Any updates? How can this be fixed?