ROCm / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

MI100 Support #24

Open LoggerHead22 opened 1 year ago

LoggerHead22 commented 1 year ago

Hi, the documentation says that this implementation is compatible only with the MI200 and MI300 GPUs. But what about the MI100 GPU?

The code contains conditions that formally match the MI100, which has the gfx908 architecture.

// gfx90x (major 9, minor 0) matches both gfx908 (MI100) and gfx90a (MI200)
bool is_gfx90x = dprops->major == 9 && dprops->minor == 0;
// gfx94x (major 9, minor 4) matches gfx940/941/942 (MI300)
bool is_gfx94x = dprops->major == 9 && dprops->minor == 4;
TORCH_CHECK(is_gfx90x || is_gfx94x, "FlashAttention only supports AMD MI200 GPUs or newer.");

Will this code be compatible with MI100 in practice? If not, are there any plans to add such support? Or what are the reasons that keep you from adding support for the MI100?
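
For illustration (a snippet not from the original report), on a ROCm build of PyTorch you can inspect exactly what this check sees; an MI100 reports gfx908 with major 9 and minor 0, the same (major, minor) pair as MI200's gfx90a:

import torch

# On ROCm builds, get_device_properties exposes the GCN arch name alongside
# the (major, minor) pair the check above compares against.
props = torch.cuda.get_device_properties(0)
print(props.gcnArchName)         # "gfx908" on an MI100
print(props.major, props.minor)  # 9 0 -- indistinguishable from gfx90a here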

dejay-vu commented 12 months ago

Hi @LoggerHead22, this code appears to have a logic fault, thanks for noting it.

We haven't tested FA on MI100, since we did most of our testing on MI250 and MI300; that's why we are limiting the supported archs. I am not sure whether it will work correctly on MI100, but you can try adding gfx908 to the valid archs. I expect the build process will be fine.
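
For reference, a rough sketch of the kind of change meant here; the variable names are assumed, not taken from the fork's actual setup.py:

# Hypothetical sketch of "adding gfx908 to the valid archs" at build time;
# names are illustrative and may differ from the real setup.py.
import os

allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942",
                 "gfx908"]  # gfx908 (MI100) added to the allow-list

archs = os.environ.get("GPU_ARCHS", "native").split(";")
for arch in archs:
    assert arch in allowed_archs, f"Requested arch {arch} is not supported"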

LoggerHead22 commented 12 months ago

Thanks for the clarification @howiejayz. Your advice really helped: the code compiles for MI100 and runs.

However, I encountered an error during the build, caused by the logic of the patch.

hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath].hipified_path
      AttributeError: 'dict' object has no attribute 'hipified_path'

This seems logical: a dict is being created here, and then the code tries to read its hipified_path attribute.

Replacing the dict with an object of the HipifyResult class in the patch fixed it for me.
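
A minimal sketch of that replacement, assuming torch's HipifyResult carries a hipified_path field (constructor details are assumed and vary across torch versions):

# Hypothetical sketch of the fix: store a HipifyResult object instead of a
# plain dict so the later .hipified_path attribute access works.
from torch.utils.hipify.hipify_python import HipifyResult

hipify_result = HipifyResult(current_state="DONE",
                             hipified_path=hipified_header_filepath)
HIPIFY_FINAL_RESULT[header_filepath] = hipify_result  # instead of a plain dict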

ccbadd commented 10 months ago

Has this patch been merged into the main branch, or do we need to apply it in order to test?

ehartford commented 10 months ago

I need mi100 support

ehartford commented 10 months ago

> Hi @LoggerHead22, this code appears to have a logic fault, thanks for noting it.
>
> We haven't tested FA on MI100, since we did most of our testing on MI250 and MI300; that's why we are limiting the supported archs. I am not sure whether it will work correctly on MI100, but you can try adding gfx908 to the valid archs. I expect the build process will be fine.

If you need hardware for testing MI100, I volunteer my server for this purpose. I have 8x MI100 with Infinity Fabric.

ehartford@gmail.com

ehartford commented 10 months ago

Hi @sabreshao @howiejayz can you please give me a path forward?

I have a bunch of MI100s and I would like them to be hot. Without flash attention, I am blocked.

Maybe you could show me where in the code I would add it, or give me some advice?

dejay-vu commented 10 months ago

> Hi @LoggerHead22, this code appears to have a logic fault, thanks for noting it.
>
> We haven't tested FA on MI100, since we did most of our testing on MI250 and MI300; that's why we are limiting the supported archs. I am not sure whether it will work correctly on MI100, but you can try adding gfx908 to the valid archs. I expect the build process will be fine.

Hi @ehartford! Currently I have no time to test FA on MI100, but could you try building and running based on this comment?

TNT3530 commented 10 months ago

I was able to compile flash-attention for the MI100 using the Docker image. Simply adding gfx908 to the target arch array (or, in my case, removing everything but native and gfx908) makes it run fine. (Note: this also applies to the vLLM ROCm Docker image, which was my use case.)

Attempts to compile outside of Docker seem to fail on ROCm 6.0 due to this issue, though I was unable to downgrade back to 5.7 to test on my machine.

luizanao commented 10 months ago

I managed to build for MI100 (gfx908) as well, but the env var alone didn't work, @TNT3530. This is because the setup is protected against unknown architectures, and gfx908 is not listed. I will open a PR to add it, since gfx908 definitely works.

luizanao commented 10 months ago

Here's my PR, you folks might benefit from it: https://github.com/ROCmSoftwarePlatform/flash-attention/pull/38

ehartford commented 9 months ago

How do I install flash attention for MI100? How does the procedure differ from the README.md?

luizanao commented 9 months ago

@ehartford passing the card arch to the build should be enough:

export GPU_ARCHS="gfx908"
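
As a quick sanity check of a gfx908 build (an illustrative snippet, not from the thread), a single forward pass through flash_attn_func should run:

import torch
from flash_attn import flash_attn_func

# Minimal fp16 forward pass on the first GPU: (batch, seqlen, nheads, headdim)
q = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([1, 128, 8, 64])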

gittb commented 4 months ago

Also curious if support for MI100 was finalized.

ehartford commented 4 months ago

This is awesome! Can't wait to try it!

jamestwhedbee commented 2 months ago

Just realized MI100 support was removed.

jamestwhedbee commented 2 months ago

@jayz0123 was that intentional?

IMbackK commented 3 weeks ago

I can confirm that when this is patched away again to allow MI100 to build the package, the latest main builds and works fine on gfx908, at least for the dimensions I tried. So this restriction seems pretty silly, and it's quite puzzling why MI100 was removed from the array again given it still works fine.

ehartford commented 3 weeks ago

Then someone doesn't want it to work on MI100.

ehartford commented 3 weeks ago

> I can confirm that when this is patched away again to allow MI100 to build the package, the latest main builds and works fine on gfx908, at least for the dimensions I tried. So this restriction seems pretty silly, and it's quite puzzling why MI100 was removed from the array again given it still works fine.

Could you please make a PR that enables MI100 so I can test it?

ruboot commented 6 days ago

> I can confirm that when this is patched away again to allow MI100 to build the package, the latest main builds and works fine on gfx908, at least for the dimensions I tried. So this restriction seems pretty silly, and it's quite puzzling why MI100 was removed from the array again given it still works fine.
>
> Could you please make a PR that enables MI100 so I can test it?

pytest test_flash_attn_ck.py

/usr/local/lib/python3.10/dist-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"
  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))

============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-8.3.3, pluggy-1.5.0
rootdir: /home/power/shared/code/flash-attention/tests
configfile: pyproject.toml
plugins: asyncio-0.24.0, anyio-4.6.2.post1, typeguard-4.3.0
asyncio: mode=strict, default_loop_scope=None
collected 410996 items

test_flash_attn_ck.py ................................................ [ 0%]
(passing-test progress output trimmed)
.............................Fatal Python error: Aborted

Thread 0x00007f15117fd640 (most recent call first):
Thread 0x00007f1511ffe640 (most recent call first):
Thread 0x00007f15127ff640 (most recent call first):
Thread 0x00007f15187ff640 (most recent call first):
Thread 0x00007f1c7f115000 (most recent call first):
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 91 in _flash_attn_forward
  File "/usr/local/lib/python3.10/dist-packages/torch/_library/custom_ops.py", line 367 in wrapped_fn
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 632 in _fn
  File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 32 in inner
  File "/usr/local/lib/python3.10/dist-packages/torch/_library/custom_ops.py", line 324 in backend_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 721 in redispatch
  File "/usr/local/lib/python3.10/dist-packages/torch/_library/autograd.py", line 40 in forward_no_grad
  File "/usr/local/lib/python3.10/dist-packages/torch/_library/autograd.py", line 113 in autograd_impl
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1116 in __call__
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 458 in forward
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 575 in apply
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 1012 in flash_attn_qkvpacked_func
  File "/home/power/shared/code/flash-attention/tests/test_flash_attn_ck.py", line 93 in test_flash_attn_qkvpacked
  (pytest and pluggy framework frames trimmed)

Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special (total: 24)

Aborted (core dumped)