ROCm / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

[Issue]: RuntimeError: Expected dout_seq_stride == out_seq_stride to be true, but got false. #41

Open donglixp opened 5 months ago

donglixp commented 5 months ago

Problem Description

File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/optim/fp16_optimizer.py", line 393, in backward loss.backward() File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_tensor.py", line 492, in backward torch.autograd.backward( File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/init.py", line 251, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply return user_fn(self, args) File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairscale/nn/checkpoint/checkpoint_activations.py", line 311, in backward torch.autograd.backward(outputs_with_grad, args_with_grad) File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/init.py", line 251, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply return user_fn(self, args) File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 236, in backward _flash_attn_backward( File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 66, in _flash_attn_backward dq, dk, dv, softmax_d, = flash_attn_cuda.bwd( RuntimeError: Expected dout_seq_stride == out_seq_stride to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)

Operating System

20.04.6 LTS (Focal Fossa)

CPU

AMD EPYC 7V12 64-Core Processor

GPU

AMD Instinct MI250X

ROCm Version

ROCm 6.0.0, ROCm 5.7.1

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

donglixp commented 5 months ago
                q = rearrange(q, '(b h) l d -> b l h d', b=bsz).contiguous()
                k = rearrange(k, '(b h) l d -> b l h d', b=bsz).contiguous()
                v = rearrange(v, '(b h) l d -> b l h d', b=bsz).contiguous()
                print(q.shape)
                print(k.shape)
                print(v.shape)
                attn = flash_attn_func(q, k, v, causal=is_causal)
                attn = rearrange(attn, 'b l h d -> (b h) l d')

The log output and error message:

torch.Size([2, 2048, 48, 64])
torch.Size([2, 2048, 48, 64])
torch.Size([2, 2048, 48, 64])
... (the same torch.Size([2, 2048, 48, 64]) line repeats for every flash-attention call across the successful training steps) ...
torch.Size([2, 2048, 48, 64])
Traceback (most recent call last):
  File "/tmp/amlt-code-download/fairseq/train.py", line 14, in <module>
    cli_main()
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq_cli/train.py", line 543, in cli_main
    distributed_utils.call_main(cfg, main)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/distributed/utils.py", line 365, in call_main
    distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/distributed/utils.py", line 339, in distributed_main
    main(cfg, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq_cli/train.py", line 191, in main
    valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
  File "/opt/conda/envs/py_3.9/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq_cli/train.py", line 307, in train
    log_output = trainer.train_step(samples)
  File "/opt/conda/envs/py_3.9/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/trainer.py", line 850, in train_step
    raise e
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/trainer.py", line 818, in train_step
    loss, sample_size_i, logging_output = self.task.train_step(
  File "/tmp/amlt-code-download/fairseq/tasks/gpt.py", line 253, in train_step
    optimizer.backward(loss)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairseq/optim/fp16_optimizer.py", line 393, in backward
    loss.backward()
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fairscale/nn/checkpoint/checkpoint_activations.py", line 311, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 236, in backward
    _flash_attn_backward(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 66, in _flash_attn_backward
    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
RuntimeError: Expected dout_seq_stride == out_seq_stride to be true, but got false.

donglixp commented 5 months ago

The log shows that the program runs fine for a number of steps before triggering the bug, rather than hitting the error on the first call.

donglixp commented 5 months ago

I tried both the rocm-5.7 and rocm-6.0 Docker images.

donglixp commented 5 months ago

The bug is related to the qkv shape (a standalone sketch to reproduce it follows the list below):

[1, 2048, 48, 64]: works well

[2, 2048, 48, 64]: triggers the bug

[4, 2048, 48, 64]: triggers the bug

[1, 2048, 24, 128]: triggers the bug

[2, 2048, 24, 128]: triggers the bug

[2, 2048, 25, 128]: triggers the bug

[2, 2048, 24, 124]: works well

[2, 2048, 48, 62]: works well
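
The sketch below is my reconstruction from the snippets in this thread, not a verified reproducer; it assumes the flash_attn and einops packages used above. It runs the forward pass with the first failing shape and then backpropagates through a rearranged view of the output, so the gradient reaching flash-attention (dout) need not share the output's strides:

    # Hypothetical standalone reproducer based on the shapes and call pattern above.
    import torch
    from einops import rearrange
    from flash_attn import flash_attn_func

    bsz, seqlen, nheads, headdim = 2, 2048, 48, 64   # first shape reported to trigger the bug
    q = torch.randn(bsz, seqlen, nheads, headdim, device='cuda',
                    dtype=torch.float16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    out = flash_attn_func(q, k, v, causal=True)       # forward pass succeeds
    attn = rearrange(out, 'b l h d -> (b h) l d')     # downstream reshape, as in the model code
    attn.float().sum().backward()                     # backward is where the stride check fires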

howiejayz commented 4 months ago

@donglixp Can I have the script you are running?

donglixp commented 4 months ago

@howiejayz
VM ROCm: 5.6.0; Docker image: rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1

Forward and backward with the following shapes:

[2, 2048, 48, 64]: triggers the bug

[4, 2048, 48, 64]: triggers the bug

[1, 2048, 24, 128]: triggers the bug

[2, 2048, 24, 128]: triggers the bug

[2, 2048, 25, 128]: triggers the bug

donglixp commented 4 months ago
                q = rearrange(q, '(b h) l d -> b l h d', b=bsz).contiguous()
                k = rearrange(k, '(b h) l d -> b l h d', b=bsz).contiguous()
                v = rearrange(v, '(b h) l d -> b l h d', b=bsz).contiguous()
                attn = flash_attn_func(q, k, v, causal=is_causal)
                attn = rearrange(attn, 'b l h d -> (b h) l d')

donglixp commented 4 months ago

Although using [2, 2048, 48, 62] didn't trigger the above error, the job then ran into loss divergence, while a similar recipe had run successfully before (when the VM ROCm was 5.4 and the Docker image was 5.7).

howiejayz commented 4 months ago

The error seems to be triggered when dout is not contiguous. May I ask how dout is generated before it is passed to the backward pass?
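
One way to check (a hedged diagnostic sketch, not taken from either codebase; q, k, v and is_causal are assumed to come from the snippet quoted earlier) is to register a hook on the flash-attention output so that the gradient's strides can be compared against the output's and, as a possible workaround, forced contiguous:

    # Hypothetical diagnostic: inspect dout vs. out strides; q, k, v, is_causal are
    # assumed to come from the snippet quoted earlier in this thread.
    from einops import rearrange
    from flash_attn import flash_attn_func

    out = flash_attn_func(q, k, v, causal=is_causal)

    def check_dout(grad):
        # grad here is the dout that flash-attention's backward kernel will receive
        print('out  strides:', out.stride(), 'contiguous:', out.is_contiguous())
        print('dout strides:', grad.stride(), 'contiguous:', grad.is_contiguous())
        return grad.contiguous()   # possible workaround: hand the kernel a contiguous dout

    out.register_hook(check_dout)
    attn = rearrange(out, 'b l h d -> (b h) l d')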

donglixp commented 4 months ago

@howiejayz Yes, they were contiguous. The contiguous() call is also handled at https://github.com/ROCm/flash-attention/blob/68aac13d3b3296d13062ab3ff40fe58d5e7b3023/flash_attn/flash_attn_interface.py#L65

howiejayz commented 4 months ago

@donglixp Thanks. Could I also have the repo and branch you are testing, so I can reproduce your result and see which step goes wrong?

rocking5566 commented 1 day ago

We changed the backend of FlashAttention 2 in the ck_tile branch. I have also submitted a PR to add AMD / ROCm support to FlashAttention 2: https://github.com/Dao-AILab/flash-attention/pull/1010. That PR uses composable_kernel as the backend. I hope this solves your issue.
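
Once that build is installed, a quick smoke test (a hedged sketch reusing the hypothetical reproducer from earlier in this thread, not a verified procedure) would be to print the installed version and rerun one of the previously failing shapes end to end:

    # Hypothetical smoke test after installing the ck_tile / composable_kernel build.
    import torch
    import flash_attn
    from einops import rearrange
    from flash_attn import flash_attn_func

    print('flash-attn version:', flash_attn.__version__)

    q = torch.randn(2, 2048, 48, 64, device='cuda', dtype=torch.float16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    out = flash_attn_func(q, k, v, causal=True)
    rearrange(out, 'b l h d -> (b h) l d').float().sum().backward()
    print('forward + backward completed without the stride error')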

zixian-wang-amd commented 1 day ago

We changed the backend of FlashAttention 2 in the ck_tile branch. I have also submitted a PR to add AMD / ROCm support to FlashAttention 2: Dao-AILab#1010. That PR uses composable_kernel as the backend. I hope this solves your issue.

Will I be able to run other models that use Flash-Attention-2 on Instinct GPUs if the PR is not merged yet? By the way, what is your work email? I can't find your name in Teams.

rocking5566 commented 1 day ago

We changed the backend of FlashAttention 2 in the ck_tile branch. I have also submitted a PR to add AMD / ROCm support to FlashAttention 2: Dao-AILab#1010. That PR uses composable_kernel as the backend. I hope this solves your issue.

Will I be able to run other models that use Flash-Attention-2 on Instinct GPUs if the PR is not merged yet? By the way, what is your work email? I can't find your name in Teams.

chunyu.lai@amd.com