seungduk-yanolja opened this issue 9 months ago
Can you save the tensors being passed to flash_attn_cuda.varlen_bwd and send them to me? Otherwise it would be very hard to debug.
And can you print out the value of cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, and the shape of q, k, v?
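For reference, a minimal sketch of one way to log those values at the Python level, assuming the model routes through flash_attn_varlen_func; if the calling code imported the function directly (from flash_attn import flash_attn_varlen_func), the patch would have to target that module's reference instead:

```python
from flash_attn import flash_attn_interface

_orig_varlen_func = flash_attn_interface.flash_attn_varlen_func

def varlen_func_logged(q, k, v, cu_seqlens_q, cu_seqlens_k,
                       max_seqlen_q, max_seqlen_k, *args, **kwargs):
    # Log exactly the values asked for above.
    print("q/k/v shapes:", tuple(q.shape), tuple(k.shape), tuple(v.shape))
    print("cu_seqlens_q:", cu_seqlens_q.tolist())
    print("cu_seqlens_k:", cu_seqlens_k.tolist())
    print("max_seqlen_q:", max_seqlen_q, "max_seqlen_k:", max_seqlen_k)
    return _orig_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k,
                             max_seqlen_q, max_seqlen_k, *args, **kwargs)

flash_attn_interface.flash_attn_varlen_func = varlen_func_logged
```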
I'm hitting a similar crash:
Traceback (most recent call last):
File "//scripts/run_sft.py", line 216, in <module>
main()
File "//scripts/run_sft.py", line 161, in main
train_result = trainer.train()
File "/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 280, in train
output = super().train(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1555, in train
return inner_training_loop(
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1860, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2735, in training_step
self.accelerator.backward(loss)
File "/opt/conda/lib/python3.10/site-packages/accelerate/accelerator.py", line 1958, in backward
self.deepspeed_engine_wrapped.backward(loss, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/deepspeed.py", line 167, in backward
self.engine.backward(loss, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1936, in backward
self.optimizer.backward(loss, retain_graph=retain_graph)
File "/opt/conda/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2093, in backward
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
File "/opt/conda/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
scaled_loss.backward(retain_graph=retain_graph)
File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 274, in apply
return user_fn(self, *args)
File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 157, in backward
torch.autograd.backward(outputs_with_grad, args_with_grad)
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 274, in apply
return user_fn(self, *args)
File "/opt/conda/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 530, in backward
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
RuntimeError: CUDA error: an illegal memory access was encountered
@tridao let me know if I can collect any more info to help; a code snippet or sample to run would be useful. In my case flash-attn is being invoked from DeepSpeed.
Can you save the tensors being passed to flash_attn_cuda.varlen_bwd and send them to me? Otherwise it would be very hard to debug.
And can you print out the value of cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, and the shape of q, k, v?
I tried to save the tensors when it happened but it was challenging and I failed. One thing I can tell is that it happens when I modify the gradients using the hook as explained here: https://huggingface.co/yanolja/KoSOLAR-10.7B-v0.2#technical-deep-dive
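For context, the linked write-up describes adjusting embedding gradients through a backward hook. Below is a minimal, self-contained sketch of that general pattern (zeroing the gradient rows of part of an embedding matrix); the sizes, boundary, and sanity check are made up for illustration and are not the exact code from the link:

```python
import torch
import torch.nn as nn

# Hypothetical sizes: the first 32000 rows are the original vocabulary and keep
# zero gradient; rows added for new tokens continue to be trained.
NUM_ORIGINAL_TOKENS = 32000
embedding = nn.Embedding(32128, 64)  # stand-in for model.get_input_embeddings()

def zero_original_rows(grad):
    grad = grad.clone()
    grad[:NUM_ORIGINAL_TOKENS] = 0
    return grad

embedding.weight.register_hook(zero_original_rows)

# Sanity check: after backward, only the newly added rows carry gradient.
embedding(torch.tensor([5, 32100])).sum().backward()
print(embedding.weight.grad[:NUM_ORIGINAL_TOKENS].abs().sum().item())  # 0.0
print(embedding.weight.grad[NUM_ORIGINAL_TOKENS:].abs().sum().item())  # > 0
```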
It would be hard for me to debug if I can't reproduce it. You could use a try/except to hopefully save the tensors.
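A minimal sketch of that try/except idea, wrapping the low-level call named above; it assumes flash_attn_cuda is reachable as flash_attn.flash_attn_interface.flash_attn_cuda (as the traceback's frames suggest), and the dump path is illustrative:

```python
import torch
from flash_attn import flash_attn_interface

_orig_varlen_bwd = flash_attn_interface.flash_attn_cuda.varlen_bwd

def varlen_bwd_dump_on_error(*args, **kwargs):
    try:
        return _orig_varlen_bwd(*args, **kwargs)
    except RuntimeError:
        # Try to dump every tensor argument so it can be attached to the issue.
        # With an illegal memory access the device-to-host copy may itself fail,
        # hence the "hopefully" above.
        tensors = {f"arg_{i}": a.detach().cpu() for i, a in enumerate(args)
                   if torch.is_tensor(a)}
        torch.save(tensors, "varlen_bwd_failing_args.pt")  # path is illustrative
        raise

flash_attn_interface.flash_attn_cuda.varlen_bwd = varlen_bwd_dump_on_error
```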
Same here. Please use the tensors here to reproduce the issue.
Same here, and the weird part is that if I use the dataset loaded with load_dataset, it runs without issue, but if I use a local dataset loaded with load_from_disk, it triggers this issue.
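To make the comparison concrete, these are the two loading paths being contrasted, assuming the standard datasets API; the dataset name and save path are illustrative:

```python
from datasets import Dataset, load_dataset, load_from_disk

# Path that reportedly works: pulling the dataset through load_dataset.
# The dataset name is illustrative.
# ds_hub = load_dataset("some_org/some_dataset", split="train")

# Path that reportedly triggers the crash: an Arrow dataset saved locally and
# re-loaded with load_from_disk. Built from a toy dict here so the snippet is
# self-contained.
ds = Dataset.from_dict({"text": ["hello", "world"]})
ds.save_to_disk("tmp_dataset")
ds_local = load_from_disk("tmp_dataset")
print(ds_local)
```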
The same issue as https://github.com/Dao-AILab/flash-attention/issues/338 is reproduced here.
I tested various combinations: transformers==4.37.2 and transformers==4.38.0.dev0 with flash-attn versions 2.5.2, 2.5.0, 2.4.3.post1, and 2.4.2. Versions 2.4.3.post1 and 2.4.2 do not have the issue, while 2.5.2 and 2.5.0 do. For the same dataset, it consistently occurred at the 46th global step. No other datasets were tested yet.
Environment
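A quick way to capture the environment details (package versions, CUDA runtime, GPU) for the report; the package list below is just an illustrative selection:

```python
import importlib.metadata as metadata
import torch

for pkg in ("torch", "transformers", "flash-attn", "deepspeed",
            "accelerate", "trl", "datasets"):
    try:
        print(pkg, metadata.version(pkg))
    except metadata.PackageNotFoundError:
        print(pkg, "not installed")

print("CUDA runtime:", torch.version.cuda)
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
```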