ROCm / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

[Issue]: Memory access fault in tests/test_flash_attn_ck.py::test_flash_attn_qkvpacked[0.17-97-80-False-False-False-False-dtype1] #94

Closed IMbackK closed 1 week ago

IMbackK commented 2 weeks ago

Problem Description

I added gfx908 to https://github.com/ROCm/flash-attention/blob/c1d146cbd5becd9e33634b1310c2d27a49c7e862/setup.py#L128C5-L128C18. It built and works fine for all my use cases. However, when running the unit tests via "python -m pytest -v -s tests/test_flash_attn_ck.py", almost all tests pass, but executing tests/test_flash_attn_ck.py::test_flash_attn_qkvpacked[0.17-97-80-False-False-False-False-dtype1] causes a crash with:

Memory access fault by GPU node-2 (Agent handle: 0x5ca6ee039290) on address 0xf5134000. Reason: Page not present or supervisor privilege.

This pytest.log shows the crash. It seems to crash in the kernel at 0x7073444204d0, which corresponds to:

void (anonymous namespace)::softmax_warp_forward<float, float, float, 7, false, false>(float*, float const*, int, int, int, bool const*, int, bool) in section .data.rel.ro of /usr/lib/libtorch_hip.so

Exercising this kernel via PyTorch directly causes no issue. I can also provide a core dump privately if desired.

From here it seems the issue is not in flash_attn but rather in MIOpen, LLVM, or PyTorch. Lacking a way to determine exactly where the issue is, I filed it here.

Operating System

Ubuntu 24.04

CPU

AMD EPYC 7552

GPU

GFX908

ROCm Version

ROCm 6.2.3

ROCm Component

No response

Steps to Reproduce

Build flash_attn for gfx908, then run the unit tests via python -m pytest -v -s tests/test_flash_attn_ck.py
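To narrow the reproduction to the one crashing case, the failing test can be selected by its pytest node ID instead of running the whole file. The following is a minimal sketch, assuming it is run from the root of a flash-attention checkout; the helper name single_test_command is hypothetical and only assembles and prints the command, it does not run it.

```python
import shlex
import sys
from typing import List

# Test file and failing test ID copied from the report above.
TEST_FILE = "tests/test_flash_attn_ck.py"
FAILING_ID = "test_flash_attn_qkvpacked[0.17-97-80-False-False-False-False-dtype1]"


def single_test_command(test_file: str, test_id: str) -> List[str]:
    """Build a pytest invocation for one test, addressed by its node ID.

    Hypothetical helper for illustration; the flags mirror the ones
    used in the report (-v for verbose, -s to disable output capture).
    """
    node_id = f"{test_file}::{test_id}"
    return [sys.executable, "-m", "pytest", "-v", "-s", node_id]


# Print a shell-safe version of the command (the brackets in the
# node ID are quoted so the shell does not treat them as a glob).
print(shlex.join(single_test_command(TEST_FILE, FAILING_ID)))
```

Quoting the node ID matters when the command is pasted into a shell, since the square brackets would otherwise be subject to glob expansion.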

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

ppanchad-amd commented 2 weeks ago

Hi @IMbackK. An internal ticket has been created to investigate your issue. Thanks!

IMbackK commented 2 weeks ago

Further investigation has revealed that the issue is at: https://github.com/ROCm/flash-attention/blob/c1d146cbd5becd9e33634b1310c2d27a49c7e862/tests/test_flash_attn_ck.py#L93

softmax_warp_forward crashes because dropout_mask, which is passed to attention_qkvpacked_ref, points to invalid memory. dropout_mask is created from the output of flash_attn_qkvpacked_func, which itself points to invalid memory. Thus it now looks like the problem is in flash_attn itself or in CK.

The issue occurs only with the bf16 datatype. When those tests are skipped, the unit tests complete without other issues.
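One way to skip the bf16 cases without editing the test file is pytest's -k option, which deselects tests whose names match a keyword expression. The sketch below is an assumption-laden illustration: it assumes that the dtype1 suffix in the parametrized test ID denotes bf16 (inferred from the failing test ID and the bf16-only observation, not confirmed in this thread), and the helper name bf16_deselect_command is hypothetical. It only builds and prints the command.

```python
import shlex
import sys
from typing import List

# ASSUMPTION: "dtype1" in the parametrized test ID marks the bf16 cases.
# This is inferred from the failing test ID, not confirmed by the thread.
BF16_KEYWORD = "dtype1"


def bf16_deselect_command(test_file: str, keyword: str = BF16_KEYWORD) -> List[str]:
    """Build a pytest invocation that deselects tests matching `keyword`.

    Hypothetical helper for illustration; -k takes a keyword expression,
    and "not <keyword>" keeps only tests whose names do not match it.
    """
    return [sys.executable, "-m", "pytest", "-v", "-s", test_file, "-k", f"not {keyword}"]


# Print a shell-safe version of the command for copy and paste.
print(shlex.join(bf16_deselect_command("tests/test_flash_attn_ck.py")))
```

If the assumption about dtype1 holds, the remaining (non-bf16) tests should run as described above; if the parametrization order differs, the keyword needs to be adjusted accordingly.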

jamesxu2 commented 1 week ago

Hi @IMbackK,

I think this issue has already been reported in upstream FlashAttention. Unfortunately, you'll have to stick with FP16 for now if you plan to use the (unsupported) MI100.