Modalities / modalities

A framework for training multimodal foundation models.

feat: group-query-attention implementation #74

Closed · flxst closed this 5 months ago

flxst commented 6 months ago

Re-opened version of https://github.com/Modalities/modalities/pull/41.

Potential solution for handling the combination of GQA and FlashAttention: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
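For reference, a minimal sketch (not the PR's code; shapes and head counts are made up) of how GQA can be combined with `scaled_dot_product_attention` by repeating the key/value heads, so that SDPA can dispatch to its FlashAttention backend:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes; n_q_heads must be a multiple of n_kv_heads.
batch, seq_len, head_dim = 2, 128, 64
n_q_heads, n_kv_heads = 8, 2
n_rep = n_q_heads // n_kv_heads  # query heads sharing one key/value head

q = torch.randn(batch, n_q_heads, seq_len, head_dim)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim)
v = torch.randn(batch, n_kv_heads, seq_len, head_dim)

# Expand the key/value heads so each group of query heads attends to its shared head.
k = k.repeat_interleave(n_rep, dim=1)
v = v.repeat_interleave(n_rep, dim=1)

# On CUDA with fp16/bf16 inputs, SDPA can use its FlashAttention backend.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # (batch, n_q_heads, seq_len, head_dim)
```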

luzian-hahn commented 5 months ago

Short update after our changes from today:

After our decision today to try out https://github.com/Dao-AILab/flash-attention/tree/main instead of our own (Group Query) Attention implementation, we have provided a first draft.
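For context, a minimal sketch (not the draft itself; shapes and head counts are made up) of how the flash-attn package handles GQA natively by accepting fewer key/value heads than query heads:

```python
import torch
from flash_attn import flash_attn_func  # requires the flash-attn package, a CUDA GPU, and fp16/bf16

# Hypothetical sizes; the number of query heads must be divisible by the number of key/value heads.
batch, seq_len, head_dim = 2, 1024, 64
n_q_heads, n_kv_heads = 16, 4

# flash_attn_func expects the (batch, seq_len, n_heads, head_dim) layout.
q = torch.randn(batch, seq_len, n_q_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seq_len, n_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seq_len, n_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)

# No manual repetition of k/v is needed: the kernel maps each group of query heads
# onto its shared key/value head.
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # (batch, seq_len, n_q_heads, head_dim)
```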

Now two things remain to be done:

  • Benchmark the new implementation against the previous one. For this, @fromm-m agreed to launch some test setups on Leonardo and check the throughput.
  • Add a remark about the installation of flash-attn. The issue is that flash-attn currently requires some prerequisites: at least CUDA 11.6 has to be installed, and on top of that some build dependencies are needed. The authors mention here that a non-isolated build environment should be used (probably so that the build can access the installed CUDA version for compilation). Unfortunately, this is not trivial to express within our own pyproject.toml. My suggestion would be either to check whether this is achievable with a clever trick (e.g. using this, though I have real doubts that it works), or alternatively to add a corresponding remark to the README.md; see the sketch below this list.
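A possible remark for the README (a sketch based on the flash-attn documentation, not yet verified in our setup) could look like this:

```bash
# Prerequisites: CUDA >= 11.6, a matching PyTorch installation, and the build tools packaging and ninja.
pip install packaging ninja
# Build without isolation so the compiler can see the installed CUDA/PyTorch.
pip install flash-attn --no-build-isolation
```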

I won't be able to have a look at this until Thursday this week.

fromm-m commented 5 months ago


Regarding your first point:

I did a benchmark, and the new implementation is even faster than the previous PyTorch flash attention implementation at the 3B parameter scale.

(Screenshot: benchmark results, 2024-03-19)

Regarding your second point: We opened a new issue (#86) to refactor the README, where we will also describe the installation of FlashAttention.