huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Community contribution: Adding Flash Attention 2 support for more architectures #26350

Open younesbelkada opened 11 months ago

younesbelkada commented 11 months ago

Feature request

Flash Attention 2 is a library that provides attention operation kernels for faster and more memory efficient inference and training: https://github.com/Dao-AILab/flash-attention


Let's try to add Flash Attention 2 support for more architectures! Currently supported architectures are

It would be great to add the support for more architectures such as

... and many more

Adding this feature requires following the same protocol as in https://github.com/huggingface/transformers/pull/25598 . First, create a new module inside the corresponding modeling file named xxxFlashAttention that inherits from xxxAttention, and override the forward method to use the public methods from flash-attn. Make sure you have access to a GPU that supports Flash Attention 2.
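
The subclass-and-override pattern described above can be sketched in plain Python. This is only a toy illustration, not the code from #25598: `ToyAttention`, `ToyFlashAttention`, `reference_attention` and `streaming_attention` are hypothetical names, there are no projections or multiple heads, and the "fast" path is a pure-Python running-softmax loop standing in for the place where a real port would call the public flash-attn functions.

```python
import math


def reference_attention(q, k, v):
    """Plain softmax(QK^T / sqrt(d)) V for one head; q, k, v are lists of vectors."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out


def streaming_attention(q, k, v):
    """Stand-in for a fused kernel: same result, computed with a running softmax
    so the full row of scores is never stored."""
    d = len(q[0])
    out = []
    for qi in q:
        running_max, denom = float("-inf"), 0.0
        acc = [0.0] * len(v[0])
        for kj, vj in zip(k, v):
            s = sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d)
            new_max = max(running_max, s)
            scale = math.exp(running_max - new_max)  # rescale earlier partial sums
            p = math.exp(s - new_max)
            denom = denom * scale + p
            acc = [a * scale + p * x for a, x in zip(acc, vj)]
            running_max = new_max
        out.append([a / denom for a in acc])
    return out


class ToyAttention:
    """Hypothetical stand-in for an xxxAttention module."""

    def forward(self, q, k, v):
        return reference_attention(q, k, v)


class ToyFlashAttention(ToyAttention):
    """Hypothetical stand-in for xxxFlashAttention: inherits everything and
    overrides forward only."""

    def forward(self, q, k, v):
        # A real port would call the public flash-attn function here.
        return streaming_attention(q, k, v)


q = [[1.0, 0.0], [0.5, 0.5]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
base = ToyAttention().forward(q, k, v)
fast = ToyFlashAttention().forward(q, k, v)
max_diff = max(abs(a - b) for ra, rb in zip(base, fast) for a, b in zip(ra, rb))
print("outputs match:", max_diff < 1e-9)
```

The point of the pattern is that the subclass changes only how attention is computed, so the rest of the model and its weights stay untouched, and the two paths can be compared directly in a test.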

Given the slight challenge of the issue, labelling it as a good second issue!

If you are interested in taking up the challenge, comment below with the architecture name you want to integrate and open a PR!

Once you open a PR, feel free to ping @LysandreJik @ArthurZucker @amyeroberts @younesbelkada @fxmarty @SunMarc @pacman100 for a review

Motivation

Making LLMs more memory efficient and faster !

Your contribution

Reviewing PRs and possibly adding the support for more models

sahilbhosale63 commented 11 months ago

Hi @younesbelkada - I want to work on adding Flash Attention 2 support for GPTBigCode (Starcoder). Can I take this task? Can you please assign this task to me?

flozi00 commented 11 months ago

Will definitely take a look next week. Great to see it merged now 💪

rajveer43 commented 11 months ago

I would like to work on MPT @younesbelkada

susnato commented 11 months ago

I would like to work on OPT.

ZeusFSX commented 11 months ago

Is it possible to add FlashAttention2 to GPT2 models?

younesbelkada commented 11 months ago

@sahilbhosale63 @flozi00 @rajveer43 @susnato thanks very much for your interest! Indeed it would be great if you could help us! Before assigning you to this issue, can you confirm you have access to a GPU that supports Flash Attention 2 (https://github.com/Dao-AILab/flash-attention#installation-and-features), so that you can run the tests? @ZeusFSX, yes, I think that is possible; I'll update the list accordingly.
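
A rough way to self-check hardware before replying is sketched below. It rests on an assumption, not on anything stated in this thread: that the flash-attn README linked above lists Ampere, Ada and Hopper GPUs, which corresponds to a CUDA compute capability of 8.0 or higher. `supports_flash_attention_2` is a hypothetical helper, and the README remains the authority.

```python
def supports_flash_attention_2(compute_capability):
    """Return True if a (major, minor) CUDA compute capability is 8.0 or newer.

    Assumption: the supported list is Ampere / Ada / Hopper, i.e. major >= 8.
    """
    major, _minor = compute_capability
    return major >= 8


# With PyTorch installed, the tuple for the current GPU can be read with
# torch.cuda.get_device_capability(); the values below are hard-coded examples.
examples = {
    "A100": (8, 0),
    "RTX 3090": (8, 6),
    "T4": (7, 5),
    "V100": (7, 0),
}
for name, capability in examples.items():
    print(name, capability, supports_flash_attention_2(capability))
```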

rajveer43 commented 11 months ago

@younesbelkada Yes I have

younesbelkada commented 11 months ago

OK perfect, I will assign you to MPT! Feel free to let me know if you need any help or have any questions. As a starting point, I would recommend having a look at #25598 and seeing if you can replicate the PR for MPT. To run the flash attention tests (once the PR is ready), you can just run:

RUN_SLOW=1 pytest -m flash_attn_test tests/models/mpt/

susnato commented 11 months ago

@younesbelkada yes I have.

younesbelkada commented 11 months ago

Thanks @susnato , perfect then, let me know whenever you start the PR and if you have any question ! Check out my instructions above for more details

sahilbhosale63 commented 11 months ago

@younesbelkada Unfortunately, my GPU is not supported.

rajveer43 commented 11 months ago

OK perfect, I will assign you to MPT! Feel free to let me know if you need any help or have any questions. As a starting point, I would recommend having a look at #25598 and seeing if you can replicate the PR for MPT. To run the flash attention tests (once the PR is ready), you can just run:

RUN_SLOW=1 pytest -m flash_attn_test tests/models/mpt/

Sure I will work on it!

jeromeku commented 11 months ago

@younesbelkada Would like to work on Persimmon. I have access to A4000, A5000, and A6000, which I believe should be compatible with FA2.

younesbelkada commented 11 months ago

Perfect sounds great, thanks for your help, I will assign you to Persimmon !

susnato commented 11 months ago

Since @sahilbhosale63 is not working on GPTBigCode (Starcoder) (as he said here), can I take that, @younesbelkada?

younesbelkada commented 11 months ago

Yes no problem, thanks very much for proposing your help on this ! As a starting point you can have a look at @pacman100 's implementation here: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/personal_copilot/training/starcoder_flash_attn_monkey_patch.py

sorenmc commented 11 months ago

@younesbelkada I would like to implement it for BERT if it hasn't already been done. A lot of the models topping MTEB still rely on this architecture! I have tested that I can run Flash Attention 2 on my NVIDIA GeForce RTX 3060 Ti.

younesbelkada commented 11 months ago

Awesome, thanks a lot for your help, ok I will assign you to BERT then!

DougTrajano commented 11 months ago

Hi everyone, I would like to help implement this with GPT2 if you want.

jeromeku commented 11 months ago

@younesbelkada

I have a working version for Persimmon that passes the flash_attn_v2 tests except for generate_padding_right as the original PersimmonFlashAttention does not have padding_mask as a kw input (as opposed to the Llama and Falcon flash implementations). Is this something that needs to be changed in both Persimmon Flash v1 and v2?
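
For context on why the padding mask matters here: the traceback names elsewhere in this thread (`_upad_input`, `indices_q`, `cu_seq_lens`) suggest that the Llama implementation strips padded tokens before calling the kernel. The sketch below shows that kind of bookkeeping in plain Python, as an assumption about the general idea rather than a copy of the transformers code; `unpad_metadata` is a hypothetical name.

```python
from itertools import accumulate


def unpad_metadata(attention_mask):
    """attention_mask: list of rows, 1 = real token, 0 = padding.

    Returns (indices, cu_seqlens, max_seqlen):
      indices    - positions of real tokens in the flattened batch
      cu_seqlens - cumulative sequence lengths, starting at 0
      max_seqlen - longest unpadded sequence in the batch
    """
    seqlens = [sum(row) for row in attention_mask]
    flat = [tok for row in attention_mask for tok in row]
    indices = [i for i, tok in enumerate(flat) if tok == 1]
    cu_seqlens = [0] + list(accumulate(seqlens))
    return indices, cu_seqlens, max(seqlens)


# Right padding: sequence 0 has 3 real tokens, sequence 1 has 2.
right_padded = [[1, 1, 1, 0],
                [1, 1, 0, 0]]
# Left padding: same lengths, padding at the start of each row.
left_padded = [[0, 1, 1, 1],
               [0, 0, 1, 1]]

print(unpad_metadata(right_padded))  # ([0, 1, 2, 4, 5], [0, 3, 5], 3)
print(unpad_metadata(left_padded))   # ([1, 2, 3, 6, 7], [0, 3, 5], 3)
```

Both layouts give the same sequence lengths but different token positions, which is why a test such as generate_padding_right cannot pass if the attention module never receives the mask.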

Also, any plans on incorporating additional optimizations, e.g., Flash Attention repo has fused layers for dense, rotary, and layer norm for faster training; and Triton kernels, more generally? Happy to investigate more!

Also, would like to help with Mistral-7b (just released). They use xformers memory efficient attention in their released implementation but also mention Tri Dao's FA in the blogpost.

younesbelkada commented 11 months ago

Hi @DougTrajano Awesome! Can you confirm you have access to hardware that is supported by FA-2?


@jeromeku awesome thanks! Can you move forward for Persimmon by opening a PR so that I can have a look?

Also, any plans on incorporating additional optimizations, e.g., Flash Attention repo has fused layers for dense, rotary, and layer norm for faster training; and Triton kernels, more generally? Happy to investigate more!

If that is something that can nicely fit into the API without any breaking behaviour that would be great !

Also, would like to help with Mistral-7b (just released). They use xformers memory efficient attention in their released implementation but also mention Tri Dao's FA in the blogpost.

I think Mistral's attention has been released in the latest version of FA-2 --> Would you be happy to open a PoC PR so that I can play with it and see what we can do?

Again thanks a lot!

younesbelkada commented 11 months ago

Hi @jeromeku I had to check internally for Mistral, given the very recent release and the urgency, we'll take this over (https://github.com/huggingface/transformers/pull/26464); if you have started a PR, I'm very happy to start from it or to add you as a co-author to the PR ! We might also refactor things a bit to support Local attention introduced by Mistral, so that needs further investigation, I'll keep you posted

rajveer43 commented 11 months ago

@younesbelkada what is the expected deadline for completing MPT? I have other issues to tackle, so I would like to plan accordingly.

susnato commented 11 months ago

Hi @younesbelkada, I am taking this up for GPT-Neo.

younesbelkada commented 11 months ago

Awesome @susnato ! Thanks ! @rajveer43 thanks for taking up MPT, will check it out!

DougTrajano commented 11 months ago

Hi @DougTrajano Awesome! Can you confirm you have access to hardware that is supported by FA-2?


Yes, I'll work on AWS SageMaker.

marcasty commented 11 months ago

Would love to take on GPT2!

younesbelkada commented 11 months ago

Thanks for confirming @DougTrajano! @marcasty thanks a lot for your interest. @DougTrajano has taken up GPT2; would you be happy to take another model? 🙏 Can you also confirm you have access to hardware that supports FA-2?

susnato commented 11 months ago

Hi @younesbelkada, I am taking this up for DistilBERT.

marcasty commented 11 months ago

@younesbelkada what about T5? I have access to compatible hardware

nidhishs commented 11 months ago

Hey @younesbelkada, are other param-sizes of Llama architecture supported, such as 1.1B and 34B? I tried using flash-attention-2 with PY007/TinyLlama-1.1B-intermediate-step-240k-503b and codellama/CodeLlama-34b-hf, however I run into shape issues.

For example, I get the following error for codellama-34b.

File "/home/ray/anaconda3/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 489, in forward
    attn_output = self._flash_attention_forward(
  File "/home/ray/anaconda3/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 526, in _flash_attention_forward
    query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
  File "/home/ray/anaconda3/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 562, in _upad_input
    query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
RuntimeError: shape '[1024, 8, 128]' is invalid for input of size 8388608

canberk17 commented 11 months ago

Hi @younesbelkada adding support for GPT2

DougTrajano commented 11 months ago

Hi @younesbelkada adding support for GPT2

well, I'll stop working on it. good work!

canberk17 commented 11 months ago

Hi @younesbelkada adding support for GPT2

well, I'll stop working on it. good work!

We can collaborate on it if you want. The PR is in draft, and I am still working on it.

DougTrajano commented 11 months ago

Hi @younesbelkada adding support for GPT2

well, I'll stop working on it. good work!

We can collaborate on it if you want. The PR is in draft, and I am still working on it.

That's fine buddy. I'll help with another issue. I just added one comment in the PR since I saw the changes. Thank you!

younesbelkada commented 11 months ago

Hi @marcasty - sure yes ! that sounds great! @nidhishs would you mind opening a new ticket for the issue with a reproducible snippet? 🙏 I think the issue might be fixed in https://github.com/huggingface/transformers/pull/26490 - can you try again on main?

nidhishs commented 11 months ago

Hey @younesbelkada, are other param-sizes of Llama architecture supported, such as 1.1B and 34B? I tried using flash-attention-2 with PY007/TinyLlama-1.1B-intermediate-step-240k-503b and codellama/CodeLlama-34b-hf, however I run into shape issues.

For example, I get the following error for codellama-34b.

File "/home/ray/anaconda3/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 489, in forward
    attn_output = self._flash_attention_forward(
  File "/home/ray/anaconda3/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 526, in _flash_attention_forward
    query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
  File "/home/ray/anaconda3/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 562, in _upad_input
    query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
RuntimeError: shape '[1024, 8, 128]' is invalid for input of size 8388608

Hi @younesbelkada, just wanted to bring your attention to this.
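
A quick arithmetic check on the numbers in the RuntimeError, in case it helps narrow things down. The only inputs are the shape and size printed in the message. Reading the three dimensions as (batch_size * kv_seq_len, num_heads, head_dim) comes from the line of `_upad_input` shown in the traceback, and the interpretation of the result is a guess.

```python
# Numbers copied from the RuntimeError above.
requested_shape = (1024, 8, 128)  # (batch_size * kv_seq_len, num_heads, head_dim)
input_numel = 8388608

# Number of elements the requested shape can hold.
requested_numel = 1
for dim in requested_shape:
    requested_numel *= dim

tokens, heads_used, head_dim = requested_shape
# Number of heads the tensor would need for the reshape to be valid,
# assuming the token count and head_dim are correct.
heads_implied = input_numel // (tokens * head_dim)

print("elements requested:", requested_numel)             # 1048576
print("heads implied by the input:", heads_implied)       # 64
print("implied / used heads:", heads_implied // heads_used)  # 8
```

So the query tensor holds 64 heads' worth of data while the reshape uses 8. One possible explanation, stated here as an assumption, is that the query is being reshaped with the number of key/value heads rather than the number of query heads, which would only show up on checkpoints where the two differ.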

younesbelkada commented 11 months ago

Hi @nidhishs please see my message above (sorry I have initially tagged the wrong person)

@nidhishs would you mind opening a new ticket for the issue with a reproducible snippet? 🙏 I think the issue might be fixed in https://github.com/huggingface/transformers/pull/26490 - can you try again on main?

MrTimmy89 commented 11 months ago

@younesbelkada I would like to implement it for BERT if it hasn't already been done. A lot of the models topping MTEB still rely on this architecture! I have tested that I can run Flash Attention 2 on my NVIDIA GeForce RTX 3060 Ti.

Hi @sorenmc! I would like to ask whether any help is needed with the BERT FA2 integration. I am not a very skilled programmer, but I'd like to try, because I am facing an issue at work where this feature might speed up some studies.

sorenmc commented 11 months ago

Hi @MrTimmy89, I haven't had much free time over the last week, so I have only been working on it a little every night. I almost have a working version and am planning to get the PR up tonight! For now it will only support standard attention, $$Softmax\left(\frac{QK^T}{\sqrt{d}}\right)V$$ as opposed to Hugging Face's standard BERT implementation, which gives you the option of e.g. adding relative positional encoding as $$Softmax\left(\frac{QK^T + E_{position}}{\sqrt{d}}\right)V$$ since I would have to write custom fused kernels for that to work.
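
To make the difference between the two formulas concrete, here is a small pure-Python sketch for a single head. It is not the BERT or flash-attn code; `attention` and the example `position_bias` values are made up for illustration. With no bias it computes the first formula, and with a bias matrix it computes the second.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v, position_bias=None):
    """Softmax((QK^T + E) / sqrt(d)) V; E is treated as zero when position_bias is None."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            s = sum(a * b for a, b in zip(qi, kj))
            if position_bias is not None:
                s += position_bias[i][j]  # the extra E_position term
            scores.append(s / math.sqrt(d))
        w = softmax(scores)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0]]
zero_bias = [[0.0, 0.0], [0.0, 0.0]]
# Hypothetical bias that penalises attending to the other position.
distance_bias = [[0.0, -2.0], [-2.0, 0.0]]

plain = attention(q, k, v)
with_zero = attention(q, k, v, zero_bias)
with_bias = attention(q, k, v, distance_bias)
print(plain == with_zero)                 # same scores, same output
print(with_bias[0][0] > plain[0][0])      # bias shifts weight toward position 0
```

The bias is added to the scores before the softmax, so a kernel that computes the scores and the softmax internally needs to accept that extra term as an input, which is presumably the custom-kernel work mentioned above.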

jeromeku commented 11 months ago

@younesbelkada

Are we also considering multimodal (i.e., IDEFICS) and VL models (i.e., Nougat) for this issue?

Also, would be interested in learning about and contributing to any larger efforts within HF on integrating optimized components across all libraries (transformers, optimum, accelerate, etc.) -- i.e., designing a more modular system / API for integrating emerging kernels / ops from xformers, triton, and any other reasonably mature external libraries.

rajveer43 commented 11 months ago

OK perfect, I will assign you to MPT! Feel free to let me know if you need any help or have any questions. As a starting point, I would recommend having a look at #25598 and seeing if you can replicate the PR for MPT. To run the flash attention tests (once the PR is ready), you can just run:

RUN_SLOW=1 pytest -m flash_attn_test tests/models/mpt/

@younesbelkada I only need to add support for MPT, so why replicate the PR?

rajveer43 commented 11 months ago

Feature request

Flash Attention 2 is a library that provides attention operation kernels for faster and more memory efficient inference and training: https://github.com/Dao-AILab/flash-attention

I was also thinking of adding FA 2 to Nougat.

Rubikalubi commented 11 months ago

I would also like to give a model a shot. Since all of the models from the list are taken, are there any recommendations for additional models for which this would be beneficial to implement? @younesbelkada Maybe ALBERT?

rajveer43 commented 11 months ago

I would also like to give a model a shot. Since all of the models from the list are taken, are there any recommendations for additional models for which this would be beneficial to implement? @younesbelkada Maybe ALBERT?

Would you like to work with me on adding FA 2 for MPT, @Rubikalubi? Here is the PR: #26471

Rubikalubi commented 11 months ago

Would you like to work with me on adding FA 2 for MPT, @Rubikalubi? Here is the PR: #26471

Yeah sure @rajveer43, just let me know how I can help.

younesbelkada commented 11 months ago

Hi, sure yes, you can take MPT ! Let us know when the PR is ready for review

rajveer43 commented 10 months ago

@younesbelkada I am not able to give time to MPT; you can assign it to someone else.

susnato commented 10 months ago

Hey, @rajveer43 I am taking this up for MPT then.

Rubikalubi commented 10 months ago

@susnato I already started working on MPT last week (see comment above), but haven't had the time to finish it yet. If you want, we can work on it together.