facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

Which kernel is best on MI200 for training? #1060

npuichigo opened this issue 3 months ago · Status: Open

npuichigo commented 3 months ago

ROCm/triton, ROCm/flash-attention, or the fmha CK implementation?

jianyuh commented 3 months ago

We haven't tested the fmha CK backend for training. For inference, the fmha CK backend is currently the fastest. cc @qianfengz @zjing14
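
For illustration, a minimal sketch of how a specific fmha backend can be pinned instead of relying on xformers' automatic dispatch. This assumes a ROCm build of xformers that exposes the CK ops as `xformers.ops.fmha.ck.FwOp` / `ck.BwOp` (op names and availability vary across versions), and whether the CK backward covers a given training configuration is exactly the open question here:

```python
# Sketch only: force a specific forward/backward op pair for benchmarking,
# assuming a ROCm xformers build that ships xformers.ops.fmha.ck.
import torch
import xformers.ops as xops

B, M, H, K = 2, 1024, 16, 128  # batch, sequence length, heads, head dim
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Default: xformers picks whichever forward/backward ops it considers best.
out = xops.memory_efficient_attention(q, k, v)

# Explicit: pin the CK forward and backward ops to compare them against
# other backends on the training (forward + backward) path.
out_ck = xops.memory_efficient_attention(
    q, k, v, op=(xops.fmha.ck.FwOp, xops.fmha.ck.BwOp)
)
out_ck.sum().backward()
```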

npuichigo commented 3 months ago

@jianyuh Since the Triton kernels have been removed, the only way I can use flash attention on MI200 is to install https://github.com/ROCm/flash-attention and then patch xformers with https://github.com/vllm-project/vllm/blob/v0.2.6/rocm_patch/flashpy_xformers-0.0.23.rocm.patch so that it uses the forward and backward kernels from that package.

This is much like the approach vLLM uses: https://docs.vllm.ai/en/v0.2.7/getting_started/amd-installation.html#option-2-build-from-source

Is there a better way to use flash attention for training on ROCm?
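
(For reference, the other option I see is calling the ROCm flash-attention Python API directly instead of patching xformers; a minimal sketch, assuming the FlashAttention-2 style `flash_attn.flash_attn_func` entry point, which may differ depending on the branch you build:)

```python
# Sketch only: use ROCm/flash-attention's own kernels for the training path
# (forward + backward), bypassing xformers entirely.
import torch
from flash_attn import flash_attn_func  # FlashAttention-2 style entry point

B, S, H, D = 2, 1024, 16, 128  # batch, sequence length, heads, head dim
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
out.sum().backward()  # backward also runs through the flash-attention kernels
```

The downside is losing xformers' attn_bias handling and automatic dispatch, so this mainly helps where plain (optionally causal) attention is enough.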

win10ogod commented 3 months ago

> Is there a better way to use flash attention for training on ROCm?

Try this: https://github.com/mosaicml/llm-foundry/tree/main?tab=readme-ov-file#amd-beta-support
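
A hedged sketch of what that route looks like in practice, assuming the MPT convention of selecting the attention implementation through the config's `attn_config` dict (where "flash" routes through the installed flash_attn package); exact keys can differ across llm-foundry versions:

```python
# Sketch only: load an MPT model with attention routed through flash_attn,
# which on an AMD machine would be the ROCm/flash-attention build.
import torch
import transformers

config = transformers.AutoConfig.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True)
config.attn_config["attn_impl"] = "flash"  # alternatives: "triton", "torch"

model = transformers.AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b",
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
```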

npuichigo commented 3 months ago

@win10ogod llm-foundry uses the flash_attention_for_rocm2 branch of ROCm/flash-attention. What's the difference from the flash_attention_for_rocm branch?

win10ogod commented 3 months ago

> llm-foundry uses the flash_attention_for_rocm2 branch of ROCm/flash-attention. What's the difference from the flash_attention_for_rocm branch?

flash_attention_for_rocm2 corresponds to FlashAttention 2, and flash_attention_for_rocm corresponds to FlashAttention 1.
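
If it helps, a minimal sketch for checking which generation an installed build provides, relying on the upstream convention that the v2 API exposes `flash_attn.flash_attn_func` while the v1 API lives in `flash_attn.flash_attn_interface` with the "unpadded"-style functions (branch layouts may differ, so treat this as a heuristic):

```python
# Sketch only: heuristic check of the installed flash-attention generation.
import flash_attn

print("flash_attn version:", getattr(flash_attn, "__version__", "unknown"))

try:
    from flash_attn import flash_attn_func  # FlashAttention-2 style API
    print("v2-style API found: flash_attn_func")
except ImportError:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func  # v1 style
    print("v1-style API found: flash_attn_unpadded_func")
```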