unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

Flash Attention 2 doesn't work with Gemma 2 models #1014

Closed rohhro closed 1 week ago

rohhro commented 1 month ago

I have one conda env where the terminal shows "FA2=True" and another where it shows "FA2=False" when I run the finetuning script. The VRAM usage when tuning the same Gemma 2 model (2b or 9b) is identical in both, even though the script sets attn_implementation = "flash_attention_2".

danielhanchen commented 1 month ago

You'll have to install flash-attn in both environments - also apologies on the delay
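
A quick way to confirm whether flash-attn is visible to the interpreter in a given environment is sketched below. It only checks that the `flash_attn` module can be found; it does not prove that the model actually uses it. `has_module` is a hypothetical helper name, not part of unsloth or flash-attn.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if the top-level module `name` can be found in this environment.

    Hypothetical helper for illustration; it only checks importability and
    says nothing about which attention kernel a model ends up using.
    """
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    # "flash_attn" is the import name of the flash-attn package.
    print("flash_attn importable:", has_module("flash_attn"))
```

Running this once per conda env shows which of them can import the package at all.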

rohhro commented 1 month ago

> You'll have to install flash-attn in both environments - also apologies on the delay

Thanks! No worries!

I have 2 separate conda environments: one has FA2, the other doesn't. That's intentional, because I want to compare VRAM usage with and without FA2.

I tested in both envs using the same Gemma 2 2B training script, and the VRAM usage is the same.

That's why I filed this issue saying FA2 is not working when finetuning Gemma 2.
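
For reference, a comparison like this can be made reproducible by recording PyTorch's peak allocation around a training step in each env. This is a minimal sketch, not the script used above: `measure_peak_vram_gib` and `bytes_to_gib` are hypothetical helper names, and `step_fn` stands for whatever callable runs one training step. Note that `torch.cuda.max_memory_allocated()` reports memory allocated for tensors by PyTorch, which is usually lower than the figure nvidia-smi shows.

```python
def bytes_to_gib(n_bytes: int) -> float:
    """Convert a byte count to GiB."""
    return n_bytes / (1024 ** 3)


def measure_peak_vram_gib(step_fn) -> float:
    """Run `step_fn()` once and return the peak GPU memory PyTorch allocated, in GiB.

    Hypothetical helper for illustration. `step_fn` is assumed to be a
    zero-argument callable that runs one forward/backward pass.
    """
    import torch  # imported lazily so the conversion helper works without torch

    if not torch.cuda.is_available():
        raise RuntimeError("A CUDA device is required to measure VRAM usage.")

    torch.cuda.reset_peak_memory_stats()  # clear the previous peak
    step_fn()
    torch.cuda.synchronize()  # wait for queued kernels before reading stats
    return bytes_to_gib(torch.cuda.max_memory_allocated())
```

Calling it with the same `step_fn`, batch size, and sequence length in both envs gives two numbers that can be compared directly.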

shimmyshimmer commented 1 week ago

@rohhro we made a fix a few months ago. Let us know if you're still encountering the issue!