Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License
14.09k stars 1.31k forks

Applying FA3 in qwen2 model fine-tuning is slower than FA2 #1126

Open 982118809 opened 3 months ago

982118809 commented 3 months ago

Hello, I applied FA3 when fine-tuning the qwen2 model on an H800 machine. In my test it was slower than FA2 under the same conditions.

I used FlashAttnFunc.forward from hopper/flash_attn_interface.py to replace Qwen2Attention.forward. I added the following to flash_attn_interface.py:

from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2Attention,
    Qwen2Model,
    rotate_half,
)

def replace_qwen2_attn_with_flash_attn3():
    Qwen2Attention.forward = FlashAttnFunc.forward

Then I removed attn_implementation="flash_attention_2" from the fine-tuning code and imported the modified code.

Setup: 1x H800-80G, 32 CPUs, 256 memory, qwen2-7b, 45k data, 6.5k training length

With FA3, the speed is about 34 s/it.

With FA2, the speed is about 24 s/it.

I did not observe much difference in memory usage.

May I ask if I did something wrong? Thank you.
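For scale, the two timings above can be compared directly. The short sketch below computes the end-to-end ratio and, under a purely hypothetical assumption about what fraction of an FA2 step is spent in attention (`attn_fraction`, which is not measured anywhere in this thread), how much slower the attention kernel alone would have to be to explain the whole gap. The function names are illustrative only.

```python
def step_ratio(fa3_s_per_it: float, fa2_s_per_it: float) -> float:
    """End-to-end slowdown of the FA3 run relative to the FA2 run."""
    return fa3_s_per_it / fa2_s_per_it


def implied_kernel_slowdown(fa3_s_per_it: float,
                            fa2_s_per_it: float,
                            attn_fraction: float) -> float:
    """Slowdown the attention kernel alone would need to explain the gap.

    attn_fraction is an ASSUMED share of the FA2 step time spent in
    attention; everything outside attention is assumed unchanged.
    """
    if not 0.0 < attn_fraction <= 1.0:
        raise ValueError("attn_fraction must be in (0, 1]")
    attn_time = fa2_s_per_it * attn_fraction   # attention time under FA2
    extra = fa3_s_per_it - fa2_s_per_it        # extra seconds per step
    return (attn_time + extra) / attn_time


if __name__ == "__main__":
    # Timings reported in this issue: 34 s/it (FA3) vs 24 s/it (FA2).
    print(round(step_ratio(34.0, 24.0), 2))
    # 0.3 is a made-up fraction used only to show the calculation.
    print(round(implied_kernel_slowdown(34.0, 24.0, 0.3), 2))
```

34 s/it against 24 s/it is about 1.42x end to end. If attention were, say, 30% of the FA2 step, the kernel by itself would need to be roughly 2.4x slower to account for the extra 10 s/it, so it is worth measuring where the time actually goes before attributing all of it to the kernel.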

tridao commented 3 months ago

Are you using the latest commit? There's a recent update to enable causal for the backward. Can you profile to get the time for the attention kernel?
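As a rough starting point for isolating the attention call, here is a minimal wall-clock timing sketch with warmup runs and an optional device-sync hook. It is not a kernel-level profiler and not part of flash-attention; `time_call` and its `sync` parameter are hypothetical names. On a GPU, the assumption is that you would pass something like `torch.cuda.synchronize` as `sync` so that queued kernels finish before the clock is read.

```python
import statistics
import time
from typing import Callable, Optional


def time_call(fn: Callable[[], object],
              warmup: int = 3,
              iters: int = 10,
              sync: Optional[Callable[[], None]] = None) -> dict:
    """Return wall-clock statistics (milliseconds) for repeated calls to fn.

    sync is a hypothetical hook, called before and after each timed call.
    Without it, asynchronous GPU work would be under-counted.
    """
    def _sync() -> None:
        if sync is not None:
            sync()

    # Warmup runs absorb one-time costs (compilation, caching, allocation).
    for _ in range(warmup):
        fn()
    _sync()

    samples_ms = []
    for _ in range(iters):
        _sync()
        start = time.perf_counter()
        fn()
        _sync()
        samples_ms.append((time.perf_counter() - start) * 1e3)

    return {
        "iters": iters,
        "mean_ms": statistics.mean(samples_ms),
        "median_ms": statistics.median(samples_ms),
        "min_ms": min(samples_ms),
    }


if __name__ == "__main__":
    # Stand-in workload; replace with a closure around the attention call.
    print(time_call(lambda: sum(i * i for i in range(10_000))))
```

Timing the FA2 and FA3 attention calls separately on identical inputs (same shapes, dtype, and causal setting) would show whether the attention call itself accounts for the difference in s/it or whether the time is going somewhere else in the training step.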

982118809 commented 3 months ago

When I first set up the environment, I also ran into issue #1091. After fixing it, I ran these tests once the setup succeeded, around August 1st. When was the commit you mentioned submitted? Was I using the latest commit?

tridao commented 3 months ago

This commit: https://github.com/Dao-AILab/flash-attention/commit/bafe253042fb251a28f351ad0a2657da26263f31

982118809 commented 3 months ago

OK, I'll use this commit to test it again.

BlackBearBiscuit commented 2 months ago

> OK, I'll use this commit to test it again.

How is the performance now? When I pretrain deepseek-v2 on an H100-80G, I see the same thing (FA3 is slower than FA2).

tridao commented 2 months ago

Can you profile to get the time for the attention kernel?

982118809 commented 2 months ago

> > OK, I'll use this commit to test it again.
>
> How is the performance now? When I pretrain deepseek-v2 on an H100-80G, I see the same thing (FA3 is slower than FA2).

Sorry, I've been busy with other things recently. We may wait until FA3 is officially released before using it.

albaNnaksqr commented 2 months ago

Same issue when fine-tuning both llama3 and qwen2 models. FA3 takes more time and slightly more GPU memory (not sure) than FA2. I replaced the function flash_attn_varlen_func in transformers/modeling_flash_attention_utils.py, switching it from FA2 to FA3. Maybe that is not the right way :(
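One way to check whether a swap like this actually takes effect is to count calls to the replacement. The sketch below uses only stand-ins: `fake_utils`, `fake_fa3`, and `fa3_adapter` are hypothetical and are not the real transformers or flash-attention objects. The assumption that the replacement returns an extra value alongside the output is made purely for illustration. It also shows a general Python pitfall: a reference captured before the patch keeps pointing at the old function.

```python
import functools
import types
from contextlib import contextmanager


@contextmanager
def patched(module, name, replacement):
    """Temporarily replace module.<name> and count calls to the replacement."""
    original = getattr(module, name)

    @functools.wraps(replacement)
    def counting(*args, **kwargs):
        counting.calls += 1
        return replacement(*args, **kwargs)

    counting.calls = 0
    setattr(module, name, counting)
    try:
        yield counting
    finally:
        setattr(module, name, original)  # always restore the original


# Hypothetical stand-in for a library module exposing an attention function.
fake_utils = types.SimpleNamespace(attn_func=lambda q, k, v: "fa2-output")


def fake_fa3(q, k, v):
    # ASSUMPTION for illustration: the replacement returns (output, extra).
    return "fa3-output", "extra"


def fa3_adapter(q, k, v):
    # Adapt the replacement to the return shape the caller expects.
    out, _extra = fake_fa3(q, k, v)
    return out


def model_step():
    # Looks the name up on the module at call time, so a patch is visible.
    return fake_utils.attn_func("q", "k", "v")


if __name__ == "__main__":
    early_bound = fake_utils.attn_func  # captured before patching
    with patched(fake_utils, "attn_func", fa3_adapter) as counter:
        print(model_step())                # goes through the replacement
        print(early_bound("q", "k", "v"))  # still the original function
        print(counter.calls)               # only model_step() was counted
```

If the counter stays at zero during a real training step, the code path is still calling the original function, and any timing comparison is not measuring the replacement at all.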