Short update after our changes from today:
After our decision today to try out https://github.com/Dao-AILab/flash-attention/tree/main instead of our own (Group Query) Attention implementation, we provided a first draft.
Now two things remain to be done:
- Benchmark the new implementation against the previous one. For this, @fromm-m agreed to launch some test setups on Leonardo and check the throughput.
- Add a remark about the installation of `flash-attn`. The issue is that `flash-attn` has some prerequisites to get installed: it needs at least CUDA 11.6, and on top of that one needs to install some build dependencies. The authors mentioned here to use a non-isolated build environment (probably to access the installed CUDA version when compiling the extensions). Unfortunately, this is not trivial to represent without our own `pyproject.toml`. My suggestion would be to either check whether this is somehow achievable with a clever trick (e.g. using this, but I have real doubts that it works), or alternatively add a respective remark in the `README.md` (see the prerequisite-check sketch below the list). I won't be able to have a look at this until Thursday this week.
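For reference, a minimal sketch (an assumption, not part of the repo; the helper name is made up) of what a prerequisite check for the CUDA >= 11.6 requirement mentioned above could look like:

```python
# Sketch only: verify the CUDA toolkit version seen by PyTorch before
# attempting to build/install flash-attn (which requires CUDA >= 11.6).
import torch


def check_flash_attn_prerequisites(min_cuda: tuple = (11, 6)) -> None:
    cuda_version = torch.version.cuda  # e.g. "12.1"; None for CPU-only builds
    if cuda_version is None:
        raise RuntimeError("flash-attn requires a CUDA-enabled PyTorch build.")
    major, minor = (int(x) for x in cuda_version.split(".")[:2])
    if (major, minor) < min_cuda:
        raise RuntimeError(
            f"flash-attn needs CUDA >= {min_cuda[0]}.{min_cuda[1]}, found {cuda_version}."
        )


if __name__ == "__main__":
    check_flash_attn_prerequisites()
    print("CUDA prerequisites for flash-attn look OK.")
```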
Regarding your first point:
I did a benchmark and it is even faster than the previous PyTorch flash attention implementation at the 3B parameter scale.
Regarding your second point: We opened a new Issue #86 to refactor the `README.md`, where we will also describe the installation of FlashAttention.
Re-opened version of https://github.com/Modalities/modalities/pull/41.
Potential solution for handling the combination of GQA and FlashAttention: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
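To illustrate the idea, here is a minimal sketch (not the modalities implementation; shapes and head counts are made up) of GQA on top of `torch.nn.functional.scaled_dot_product_attention`, where the K/V heads are repeated so that they match the number of query heads:

```python
# Sketch: Grouped Query Attention via torch's SDPA, which can dispatch to a
# flash-attention kernel on supported GPUs. Tensor layout: (batch, heads, seq, head_dim).
import torch
import torch.nn.functional as F

batch, seq_len, head_dim = 2, 16, 64
n_q_heads, n_kv_heads = 8, 2          # 4 query heads share each K/V head
group_size = n_q_heads // n_kv_heads

q = torch.randn(batch, n_q_heads, seq_len, head_dim)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim)
v = torch.randn(batch, n_kv_heads, seq_len, head_dim)

# Expand K/V from n_kv_heads to n_q_heads heads.
k_expanded = k.repeat_interleave(group_size, dim=1)
v_expanded = v.repeat_interleave(group_size, dim=1)

out = F.scaled_dot_product_attention(q, k_expanded, v_expanded, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```

Newer PyTorch releases also expose an `enable_gqa` flag on `scaled_dot_product_attention` that avoids the explicit repeat; whether that is usable depends on the PyTorch version we pin.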