ROCm / flash-attention

Fast and memory-efficient exact attention

Support for other modules (rotary, xentropy, layer_norm) #34

Open bbartoldson opened 10 months ago

bbartoldson commented 10 months ago

The non-flash-attention modules in this repository do not seem to be installable on AMD cards. I would be happy to help address this, but I need some guidance.

Progress:

Errors:

  1. import rotary_emb raises ImportError: libc10.so: cannot open shared object file: No such file or directory.
  2. If you import torch first and then import rotary_emb, you instead get ImportError: /.../x86_miniconda3/envs/flash2/lib/python3.11/site-packages/rotary_emb.cpython-311-x86_64-linux-gnu.so: undefined symbol: _Z17apply_rotary_cudaN2at6TensorES0_S0_S0_S0_S0_b (see the sketch after this list).
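A minimal sketch of why the two errors differ, assuming a standard PyTorch C++ extension build: libc10.so ships inside PyTorch's own lib directory, so the dynamic loader can only resolve it after import torch has run; once it does, the real problem surfaces as the unresolved kernel symbol (the mangled name demangles to apply_rotary_cuda(at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, bool)):

```python
# Repro sketch (hypothetical; environment and paths as in the report above).
# libc10.so lives in torch/lib, so it only becomes resolvable after
# `import torch` has loaded the PyTorch runtime into the process.
import torch        # loads libc10.so; error 1 disappears
import rotary_emb   # error 2 appears instead: the extension references
                    # apply_rotary_cuda, which was never built for ROCm
```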

Configuration:

dejay-vu commented 10 months ago

/.../x86_miniconda3/envs/flash2/lib/python3.11/site-packages/rotary_emb.cpython-311-x86_64-linux-gnu.so: undefined symbol: _Z17apply_rotary_cudaN2at6TensorES0_S0_S0_S0_S0_b

It looks like there are some CUDA operators in the rotary lib that do not work on AMD GPUs, so even if you can build it, you cannot use it directly without kernel support.
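For reference, one generic way to surface which kernel symbols the built extension still expects (nothing repo-specific; the path below is a placeholder for the truncated one above) is to load the .so directly:

```python
# Loading the extension with ctypes reports the first unresolved symbol,
# which identifies the missing CUDA operator.
import torch   # import first so libc10.so is already loaded
import ctypes

so_path = "/path/to/site-packages/rotary_emb.cpython-311-x86_64-linux-gnu.so"  # placeholder
try:
    ctypes.CDLL(so_path)
except OSError as err:
    print(err)  # e.g. "undefined symbol: _Z17apply_rotary_cudaN2at6TensorES0_S0_S0_S0_S0_b"
```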

ehartford commented 7 months ago

@howiejayz Hello, I need rotary on AMD. Is this CUDA operator still missing? Can you please specify which ones? Can I get the fix prioritized?
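Until a ROCm kernel lands, the rotation itself can be done with plain tensor ops, which run on ROCm builds of PyTorch out of the box. This is a sketch of the standard RoPE math only, not the extension's actual API (the demangled symbol suggests a six-tensor-plus-flag signature, but that is a guess):

```python
# Pure-PyTorch fallback for the rotary rotation; uses only standard ops,
# so it works on ROCm without the custom kernel. Sketch only.
import torch

def apply_rotary_torch(x1, x2, cos, sin, conj: bool = False):
    """Rotate the pair (x1, x2) by the angle tables (cos, sin).

    conj=True applies the inverse rotation (as used in a backward pass).
    """
    if conj:
        sin = -sin
    return x1 * cos - x2 * sin, x1 * sin + x2 * cos
```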

dejay-vu commented 7 months ago

@ehartford Could you try this PR, which has the rotary module enabled for ROCm?
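A quick smoke test after building that branch on a ROCm machine might look like this (hypothetical check; the PR reference itself is not preserved in this thread):

```python
import torch
print(torch.version.hip)   # a non-None value confirms a ROCm build of PyTorch

import rotary_emb          # should now import cleanly
print(rotary_emb.__file__)
```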

ehartford commented 7 months ago

Thanks, I will give it a try