donglixp opened this issue 5 months ago
from einops import rearrange
from flash_attn import flash_attn_func

# q, k, v arrive as (batch * heads, seq_len, head_dim); flash_attn_func
# expects (batch, seq_len, heads, head_dim)
q = rearrange(q, '(b h) l d -> b l h d', b=bsz).contiguous()
k = rearrange(k, '(b h) l d -> b l h d', b=bsz).contiguous()
v = rearrange(v, '(b h) l d -> b l h d', b=bsz).contiguous()
print(q.shape)
print(k.shape)
print(v.shape)
attn = flash_attn_func(q, k, v, causal=is_causal)
attn = rearrange(attn, 'b l h d -> (b h) l d')
The error message:
torch.Size([2, 2048, 48, 64])
torch.Size([2, 2048, 48, 64])
torch.Size([2, 2048, 48, 64])
... (the same q/k/v shapes are printed for many attention calls before the failure)
Traceback (most recent call last):
File "/tmp/amlt-code-download/fairseq/train.py", line 14, in <module>
cli_main()
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq_cli/train.py", line 543, in cli_main
distributed_utils.call_main(cfg, main)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/distributed/utils.py", line 365, in call_main
distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/distributed/utils.py", line 339, in distributed_main
main(cfg, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq_cli/train.py", line 191, in main
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
File "/opt/conda/envs/py_3.9/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq_cli/train.py", line 307, in train
log_output = trainer.train_step(samples)
File "/opt/conda/envs/py_3.9/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/trainer.py", line 850, in train_step
raise e
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/trainer.py", line 818, in train_step
loss, sample_size_i, logging_output = self.task.train_step(
File "/tmp/amlt-code-download/fairseq/tasks/gpt.py", line 253, in train_step
optimizer.backward(loss)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/optim/fp16_optimizer.py", line 393, in backward
loss.backward()
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_tensor.py", line 492, in backward
torch.autograd.backward(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply
return user_fn(self, *args)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairscale/nn/checkpoint/checkpoint_activations.py", line 311, in backward
torch.autograd.backward(outputs_with_grad, args_with_grad)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply
return user_fn(self, *args)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 236, in backward
_flash_attn_backward(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 66, in _flash_attn_backward
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
RuntimeError: Expected dout_seq_stride == out_seq_stride to be true, but got false.
The log shows that the program runs fine for a number of steps before hitting the error, rather than failing on the first call.
I tried both the rocm-5.7 and rocm-6.0 Docker images.
The bug depends on the q/k/v shape, given as [batch, seq_len, heads, head_dim]:
[1, 2048, 48, 64]: works well
[2, 2048, 48, 64]: triggers the bug
[4, 2048, 48, 64]: triggers the bug
[1, 2048, 24, 128]: triggers the bug
[2, 2048, 24, 128]: triggers the bug
[2, 2048, 25, 128]: triggers the bug
[2, 2048, 24, 124]: works well
[2, 2048, 48, 62]: works well
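A minimal standalone loop over these shapes should exercise the same forward + backward path. This is a sketch only, not the actual training script: the dtype, device, causal flag, and random inputs are assumptions, and the permute/reshape mimics the rearrange applied after attention in the snippet above.

import torch
from flash_attn import flash_attn_func

# (batch, seq_len, heads, head_dim) cases from the report
shapes = [
    (1, 2048, 48, 64),   # reported to work
    (2, 2048, 48, 64),   # reported to fail
    (1, 2048, 24, 128),  # reported to fail
    (2, 2048, 48, 62),   # reported to work
]

for b, l, h, d in shapes:
    q = torch.randn(b, l, h, d, dtype=torch.float16, device='cuda', requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)
    out = flash_attn_func(q, k, v, causal=True)
    # mimic the training code: the rearrange after attention makes the
    # gradient reaching the kernel a permuted (non-contiguous) tensor
    out.permute(0, 2, 1, 3).reshape(b * h, l, d).sum().backward()
    print((b, l, h, d), 'ok')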
@donglixp Can I have the script you are running?
@howiejayz
VM ROCm: 5.6.0
Docker: rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1
Forward and backward with the following shapes:
[2, 2048, 48, 64]: triggers the bug
[4, 2048, 48, 64]: triggers the bug
[1, 2048, 24, 128]: triggers the bug
[2, 2048, 24, 128]: triggers the bug
[2, 2048, 25, 128]: triggers the bug
q = rearrange(q, '(b h) l d -> b l h d', b=bsz).contiguous()
k = rearrange(k, '(b h) l d -> b l h d', b=bsz).contiguous()
v = rearrange(v, '(b h) l d -> b l h d', b=bsz).contiguous()
attn = flash_attn_func(q, k, v, causal=is_causal)
attn = rearrange(attn, 'b l h d -> (b h) l d')
Although [2, 2048, 48, 62] didn't trigger the above error, I found that the job then ran into loss-divergence issues, while a similar recipe had run successfully before (when the VM ROCm was 5.4 and the Docker image was 5.7).
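One way to sanity-check the numerics for a given shape is to compare flash_attn_func against a float32 reference. A minimal sketch, assuming fp16 inputs in the (batch, seq, heads, dim) layout used above; scaled_dot_product_attention expects (batch, heads, seq, dim), hence the transposes:

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

b, l, h, d = 2, 2048, 48, 62  # the shape that avoided the stride error
q = torch.randn(b, l, h, d, dtype=torch.float16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)

# float32 reference computed in (b, h, l, d) layout, transposed back afterwards
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2).float(),
    k.transpose(1, 2).float(),
    v.transpose(1, 2).float(),
    is_causal=True,
).transpose(1, 2)

print('max abs diff:', (out.float() - ref).abs().max().item())

A large difference on the ROCm build but not on a CUDA build of flash-attn would point at the kernel rather than the training recipe.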
The error seems to be triggered when dout is not contiguous. May I ask how you generate dout when passing it to the backward pass?
@howiejayz Yes, they were contiguous. The contiguous() call is also handled at https://github.com/ROCm/flash-attention/blob/68aac13d3b3296d13062ab3ff40fe58d5e7b3023/flash_attn/flash_attn_interface.py#L65
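For illustration, the rearrange applied after flash_attn_func is one place a non-contiguous dout can arise: the gradient flowing back through a permute keeps the permuted strides, so its sequence stride can differ from the forward output's. A minimal sketch (shapes from the report; gradient values are arbitrary):

import torch
from einops import rearrange

out = torch.randn(2, 2048, 48, 64, requires_grad=True)
y = rearrange(out, 'b l h d -> (b h) l d')  # same rearrange as the training code

# Grab the raw gradient that would be fed back to the attention output.
(dout,) = torch.autograd.grad(y, out, torch.ones_like(y))

print(out.is_contiguous())   # True
print(dout.is_contiguous())  # typically False: the gradient of the permute is
                             # a permuted view, so its strides differ from out's
print(out.stride(), dout.stride())

If this is what happens inside the training run, the interface's dout.contiguous() handling referenced above would be expected to normalize the strides before the kernel's dout_seq_stride == out_seq_stride check.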
@donglixp Thanks. Could you also share the repo and branch you are testing, so I can reproduce your result and see which step goes wrong?
We changed the backend of FlashAttention 2 in the ck_tile branch. I also submitted a PR to add AMD/ROCm support to FlashAttention 2: https://github.com/Dao-AILab/flash-attention/pull/1010. This PR uses composable_kernel as the backend; I hope this may solve your issue.
Will I be able to run other models that use FlashAttention-2 on Instinct GPUs if the PR is not merged yet? By the way, what is your work email? I can't find your name in Teams.
chunyu.lai@amd.com
Problem Description
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/optim/fp16_optimizer.py", line 393, in backward loss.backward() File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_tensor.py", line 492, in backward torch.autograd.backward( File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/init.py", line 251, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply return user_fn(self, args) File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairscale/nn/checkpoint/checkpoint_activations.py", line 311, in backward torch.autograd.backward(outputs_with_grad, args_with_grad) File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/init.py", line 251, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply return user_fn(self, args) File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 236, in backward _flash_attn_backward( File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 66, in _flash_attn_backward dq, dk, dv, softmax_d, = flash_attn_cuda.bwd( RuntimeError: Expected dout_seq_stride == out_seq_stride to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
Operating System
Ubuntu 20.04.6 LTS (Focal Fossa)
CPU
AMD EPYC 7V12 64-Core Processor
GPU
AMD Instinct MI250X
ROCm Version
ROCm 6.0.0, ROCm 5.7.1
ROCm Component
No response
Steps to Reproduce
No response
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response