ROCm / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

[Issue]: RuntimeError: FlashAttention forward only supports head dimension at most 128 #51

Closed: xxtars closed this issue 7 months ago

xxtars commented 8 months ago

Problem Description

When attempting to train Gemma-7B with flash-attention, I encountered a RuntimeError stating FlashAttention forward only supports head dimension at most 128. The ROCm flash-attention build is currently at version 2.0.4, whereas the mainline version has advanced to 2.5.6. I would like to know whether this error is caused by the version difference or by a limitation of the GPU hardware.
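For context, a minimal snippet along these lines reproduces the error on the 2.0.4 ROCm build; the tensor shapes are illustrative, with head_dim = 256 matching Gemma-7B's attention configuration:

```python
import torch
from flash_attn import flash_attn_func

# Gemma-7B uses head_dim = 256, which exceeds the 128 limit of this build.
batch, seqlen, nheads, head_dim = 1, 512, 16, 256
q = torch.randn(batch, seqlen, nheads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Raises: RuntimeError: FlashAttention forward only supports head dimension at most 128
out = flash_attn_func(q, k, v, causal=True)
```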

Operating System

SLES 15-SP4

CPU

AMD EPYC 7A53

GPU

AMD Instinct MI250X

ROCm Version

ROCm 5.6.0

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

farshadghodsian commented 7 months ago

Someone had a similar issue in another thread, and a patch/solution was proposed to fall back to sub-quadratic attention when the head dimension exceeds 128: https://github.com/ROCm/flash-attention/issues/27#issuecomment-1988466291. Someone also wrote a guide for this here: https://github.com/huggingface/diffusers/discussions/7172. Hopefully this helps.
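The workaround described in those threads amounts to something like the following sketch: use flash-attention when the head dimension is supported and fall back to PyTorch's standard scaled_dot_product_attention otherwise. The function name and the hard-coded 128 threshold below are illustrative, not the exact patch from the linked discussion:

```python
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

MAX_FLASH_HEAD_DIM = 128  # limit of the ROCm flash-attention 2.0.4 build


def attention_with_fallback(q, k, v, causal=True):
    """q, k, v: (batch, seqlen, nheads, head_dim), fp16/bf16 tensors on the GPU."""
    if q.shape[-1] <= MAX_FLASH_HEAD_DIM:
        # Head dimension is supported: use the fast flash-attention kernel.
        return flash_attn_func(q, k, v, causal=causal)
    # Fall back to PyTorch SDPA, which expects (batch, nheads, seqlen, head_dim).
    q_, k_, v_ = (x.transpose(1, 2) for x in (q, k, v))
    out = F.scaled_dot_product_attention(q_, k_, v_, is_causal=causal)
    return out.transpose(1, 2)
```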

xxtars commented 7 months ago

Thanks for your help!