Dao-AILab / flash-attention

NaN Gradient Issue with Gemma2 Model Training #1065

Open kiddj opened 1 month ago

kiddj commented 1 month ago

I'm encountering an issue where gradients become NaN while training the Gemma2 model with transformers and flash-attn. Soft-capping is enabled during training.

Environment:

transformers @ git+https://github.com/huggingface/transformers.git@ac946aac257cadfa8264fa4a284cd0ea1061c5b5
flash-attn==2.6.1
torch==2.3.1
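
For reference, a minimal repro sketch (not my actual training setup; the shapes and the softcap value of 50.0, which mirrors Gemma2's `attn_logit_softcapping`, are illustrative), assuming flash-attn 2.6.1 where `flash_attn_func` accepts a `softcap` argument:

```python
# Minimal sketch, not the actual training script: run flash_attn_func with
# softcap enabled and check the gradients produced by backward for NaNs.
import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim) -- illustrative shapes only
q = torch.randn(1, 512, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(1, 512, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(1, 512, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# softcap=50.0 mirrors Gemma2's attention-logit soft-capping
out = flash_attn_func(q, k, v, causal=True, softcap=50.0)
out.sum().backward()

print(torch.isnan(q.grad).any(), torch.isnan(k.grad).any(), torch.isnan(v.grad).any())
```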

tridao commented 1 month ago

Softcapping is not yet supported in the backward pass.
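
Until that lands, one possible workaround (a sketch, not an official recommendation from this thread) is to train Gemma2 with the eager attention implementation in transformers, which applies the soft-cap in plain PyTorch instead of going through flash-attn:

```python
# Workaround sketch (assumption, not from this thread): load Gemma2 with the
# eager attention path so training does not hit the flash-attn softcap backward limitation.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",          # checkpoint name used here only for illustration
    attn_implementation="eager",  # eager path applies the soft-cap itself
    torch_dtype=torch.bfloat16,
)
```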

tanliboy commented 1 month ago

@tridao Thanks for the information! What's the targeted timeline for supporting soft-capping in the backward pass?