Open · Karbo123 opened this issue 2 years ago
Thanks for the report. The function should be numerically stable.
Which commit of FlashAttention are you using? On which GPU? What are the dimensions of the attention?
In order for me to reproduce the issue, can you save the arguments that caused the NaN and send them to me? That'd be very helpful.
For example, you can add these lines right before the return statement of the backward function to save the tensors to the file `nan_repro.pt`:

```python
if dqkv.isnan().any():
    state_dict = {'dout': dout, 'qkv': qkv, 'out': out, 'softmax_lse': softmax_lse,
                  'cu_seqlens': cu_seqlens, 'max_seqlen': ctx.max_seqlen,
                  'dropout_p': ctx.dropout_p, 'softmax_scale': ctx.softmax_scale,
                  'causal': ctx.causal, 'rng_state': rng_state}
    torch.save(state_dict, 'nan_repro.pt')  # torch.save takes the object first, then the path
    breakpoint()
```
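For reference, a minimal sketch of how such a dump could be inspected afterwards (the key names assume the dictionary saved above):

```python
import torch

# Load the saved arguments and report which tensors already contain NaNs
state = torch.load('nan_repro.pt')
for name, value in state.items():
    if torch.is_tensor(value) and value.is_floating_point():
        print(name, tuple(value.shape), 'has NaN:', bool(value.isnan().any()))
    else:
        print(name, value)
```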
Thanks for your help!
I was having the same problem... Long story short, it looks like `dq`, `dk` and `dv` need to be zeroed out, since they are used as accumulators? However, `flash_attn_interface` currently allocates them via `torch.empty_like`. Setting them to 0 before the `flash_attn_cuda.bwd` call seems to have resolved the issue.
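The distinction being described, in isolation (an illustrative sketch, not the library's internal code):

```python
import torch

q = torch.randn(8, 4, 64)

dq = torch.empty_like(q)       # uninitialized memory: contents are arbitrary, possibly NaN/Inf bit patterns
dq_safe = torch.zeros_like(q)  # explicitly zeroed before anything accumulates into it

# If the backward kernel only writes a subset of rows, dq keeps garbage in the
# untouched rows, while dq_safe keeps zeros there.
```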
I'm very curious about this. I think all of the values in `dq`, `dk`, `dv` should be overwritten during the execution of the backward pass.
The only problematic scenario I could imagine is when `q`, `k`, `v` are longer than what `cu_seqlens` indicates.
For example, if `q` has shape (10, nheads, headdim), where 10 is supposed to be the total batch * seqlen, then `dq` is allocated as `torch.empty_like(q)`. If e.g. `cu_seqlens = [0, 5, 8]` (which says that the batch has 2 sequences, the 1st stored at indices 0 -> 4 and the 2nd at indices 5 -> 7), then during execution only the values at indices 0 -> 7 are written, and the values at indices 8 -> 9 are never overwritten.
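Concretely, a hypothetical sketch of this scenario (illustrative shapes, not library code):

```python
import torch

# q is padded to 10 rows in total, but cu_seqlens = [0, 5, 8] only covers the first 8.
total, nheads, headdim = 10, 2, 64
q = torch.randn(total, nheads, headdim)
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32)

dq = torch.empty_like(q)     # rows 8 and 9 are never written by the backward kernel
# ... backward kernel fills dq[:cu_seqlens[-1]] only ...
dq[cu_seqlens[-1]:] = 0      # workaround: zero (or slice off) the unused tail
```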
@vadimcn could you say more about your setting? Does it fall into this case?
> The only problematic scenario I could imagine is when `q`, `k`, `v` are longer than what `cu_seqlens` indicates.

Yes, this was indeed the case :facepalm:. Thanks for the hint!
@tridao I can reproduce this error after a very short time using the script below. Is this a user-side usage error?
```python
import torch
import torch.optim as optim
import torch.nn.functional as F
from functools import partial

from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

embed_dim = 256  # !!! change this to 64 and the error will not be observable !!!
batch_size = 16
num_heads = 8
seq_length = 512
dim_feedforward = 1024
learning_rate = 0.01
device = torch.device("cuda")

torch.set_default_dtype(torch.bfloat16)
torch.autograd.set_detect_anomaly(True)

# Initialize model
model = Block(  # TransformerEncoderLayer
    embed_dim,
    mixer_cls=partial(
        MHA,
        num_heads=num_heads,
        use_flash_attn=True,
        rotary_emb_dim=0,
    ),
    mlp_cls=partial(Mlp, hidden_features=dim_feedforward),
    resid_dropout1=0.0,
    resid_dropout2=0.0,
    prenorm=False,
).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
inputs = torch.full((batch_size, seq_length, embed_dim), 1000.0, device=device)

# Training loop
for i in range(999999):
    print(f'Iteration {i + 1}')
    optimizer.zero_grad()
    output = model(inputs)
    loss = output.mean()
    loss.backward()
    optimizer.step()
```
Output:
```
/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/ops/triton/layer_norm.py:959: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(
/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/ops/triton/layer_norm.py:1018: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, dout, *args):
Iteration 1
Iteration 2
Iteration 3
Iteration 4
Iteration 5
Iteration 6
Iteration 7
Iteration 8
Iteration 9
Iteration 10
Iteration 11
Iteration 12
Iteration 13
Iteration 14
Iteration 15
/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/autograd/graph.py:825: UserWarning: Error detected in FlashAttnQKVPackedFuncBackward. Traceback of forward call that caused the error:
  File "/home/otto/Development/temp/test.py", line 41, in <module>
    output = model(inputs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/modules/block.py", line 195, in forward
    mixer_out = self.mixer(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 670, in forward
    context = self.inner_attn(qkv, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 122, in forward
    return flash_attn_qkvpacked_func(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 729, in flash_attn_qkvpacked_func
    return FlashAttnQKVPackedFunc.apply(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:110.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/home/otto/Development/temp/test.py", line 43, in <module>
    loss.backward()
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'FlashAttnQKVPackedFuncBackward' returned nan values in its 0th output.
```
Latest published version of `flash-attention`.
Does the same thing happen if you use the standard implementation of attention? I.e., try `use_flash_attn=False`.
> `use_flash_attn=False`

Then it works fine @tridao
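One way to narrow this down further would be to compare the two code paths directly on the same inputs. A sketch, assuming a flash-attn build that exports `flash_attn_qkvpacked_func` (the function shown in the traceback above) and a supported GPU:

```python
import torch
from flash_attn import flash_attn_qkvpacked_func

torch.manual_seed(0)
# Packed qkv: (batch, seqlen, 3, nheads, headdim)
qkv = torch.randn(2, 512, 3, 8, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out_flash = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=False)

# Reference path via PyTorch's built-in attention (expects (batch, nheads, seqlen, headdim))
q, k, v = qkv.unbind(dim=2)
out_ref = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

print("max forward diff:", (out_flash - out_ref).abs().max().item())
out_flash.sum().backward()
print("NaN in dqkv:", bool(qkv.grad.isnan().any()))
```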
@otto-dev @tridao any updates? How to fix it?
I found my training loss became NaN after training for many epochs, but the model params were all finite (i.e. `torch.isfinite`) after the NaN loss occurred (I checked this by loading the saved checkpoint file from disk). I tried to resume the model from the checkpoint file (i.e. the checkpoint saved after the NaN happened); at the very beginning the training process seemed to be okay, but after several epochs the NaN loss occurred again.

I tried to set `torch.autograd.set_detect_anomaly(True)` and `CUDA_LAUNCH_BLOCKING=1` to find out what happened, and the result showed that `FlashAttnQKVPackedFuncBackward` returned NaN values for its output.

It is very strange: the model params are all finite, and the NaN only happens after several epochs, so the training data should be okay. But if the backward pass produces NaN values, why don't the model params contain any NaN values? BTW, I didn't use any gradient clipping.

I also checked all the model params, and they seem to be fine. The maximum absolute value among all the params (except the non-learnable frequency weights) is 1.415, not too large.

Do you have any suggestions on this? Is `FlashAttnQKVPackedFunc` numerically unstable? Thank you very much! Looking forward to your reply.
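A generic way to localize where the NaN first appears, sketched with standard PyTorch hooks (`model` here is a placeholder for your own module):

```python
import torch

def add_nan_watch(model: torch.nn.Module) -> None:
    """Print the first places where a gradient or an activation turns NaN."""
    for name, p in model.named_parameters():
        # Fires during backward, once per parameter gradient
        p.register_hook(
            lambda grad, name=name: print(f"NaN grad at parameter {name}")
            if grad.isnan().any() else None
        )
    for name, m in model.named_modules():
        # Fires during forward, catches NaNs already present in activations
        m.register_forward_hook(
            lambda mod, inp, out, name=name: print(f"NaN output at module {name}")
            if torch.is_tensor(out) and out.isnan().any() else None
        )

# add_nan_watch(model)  # call once before the training loop, then run until the NaN shows up
```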