huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

fine tuning the updated Phi-2 with flash-attn-2 produces very high loss > 2 #28488

Closed: abacaj closed this issue 5 months ago

abacaj commented 10 months ago

System Info

The updated Phi-2 code produces a high loss. I have tried fp16, bf16, DeepSpeed, and FSDP; the result is the same: the loss starts at 2 and keeps going higher. Setting use_flash_attention_2=False, or using the old Phi-2 modeling file, fixes this (see the loading sketch after the version list below).

torch==2.1.2
flash-attn==2.4.2
transformers==4.37.0.dev0
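
A minimal sketch of the two loading paths described above, assuming the "microsoft/phi-2" checkpoint and a CUDA device. The attn_implementation argument is the transformers 4.36+ equivalent of use_flash_attention_2; the exact loading code used in the report is not shown in the issue.

```python
import torch
from transformers import AutoModelForCausalLM

# Configuration that reproduces the high loss: flash-attn-2 with the updated
# in-library Phi-2 modeling code.
model_fa2 = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # equivalent to use_flash_attention_2=True
)

# Workaround: fall back to the eager attention implementation
# (equivalent to use_flash_attention_2=False).
model_eager = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
```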

Who can help?

No response

Reproduction

Fine-tune the updated Phi-2 model, with flash-attn-2 enabled, using the transformers Trainer (a sketch follows).
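
A minimal fine-tuning sketch of that setup, assuming the "microsoft/phi-2" checkpoint, bf16 training, and a small stand-in dataset ("wikitext"); the actual training script, dataset, and hyperparameters from the report are not included in the issue.

```python
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Load the updated in-library Phi-2 with flash-attn-2, the setting that
# triggers the reported loss behaviour.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# Small stand-in corpus, tokenized for causal LM training.
raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=512)

train_ds = raw.map(tokenize, batched=True, remove_columns=raw.column_names)

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="phi2-ft",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        bf16=True,
        num_train_epochs=1,
        logging_steps=10,
    ),
    train_dataset=train_ds,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# Watch the logged loss: per the report it starts at 2 and climbs with
# flash-attn-2, but decreases with attn_implementation="eager".
trainer.train()
```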

Expected behavior

Loss goes down during fine-tuning.