triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

pretraining loss blows up with Triton Flash attention #1957

Open: yzhang123 opened this issue 11 months ago

yzhang123 commented 11 months ago

When pretraining GPT with Triton flash attention, the loss blows up (from ~2 to 7) halfway into training and never comes back down. If I resume from a healthy checkpoint without flash attention, the loss is stable. I was able to reproduce this error.

I was using ALiBi.

DachengLi1 commented 10 months ago

@yzhang123 I ran into the same problem. Do you have any recommendations?

yzhang123 commented 10 months ago

@DachengLi1, it might be a seed issue: I reran from scratch and it was fine. Other fixes are to switch from bf16 to fp16, or to use a full attention bias with bf16, which seems more stable.
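The thread doesn't pin down why fp16 behaved better. One plausible factor, offered as an assumption rather than a diagnosis, is the precision/range trade-off between the two formats. bf16 keeps float32's 8-bit exponent but has only 7 mantissa bits. fp16 has 10 mantissa bits but a maximum finite value of 65504. The sketch below emulates both roundings in pure Python. The helpers `to_bf16` and `to_fp16` are hypothetical. `to_bf16` uses round-to-nearest-even on the upper 16 bits of a float32 and ignores NaN inputs. `to_fp16` relies on the standard `struct` half-precision format `'e'`.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a float to bfloat16 (round-to-nearest-even), returned as a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    # Add 0x7FFF plus the lowest kept bit so ties round to even, then drop the low 16 bits.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_fp16(x: float) -> float:
    """Round a float to IEEE half precision; raises OverflowError beyond fp16's range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


x = 1 + 2 ** -9  # needs 9 mantissa bits: exact in fp16, not in bf16
print(to_fp16(x), to_bf16(x))  # 1.001953125 1.0

# Large values: bf16 still represents them (coarsely), fp16 overflows.
print(to_bf16(70000.0))  # 70144.0
try:
    to_fp16(70000.0)
except OverflowError:
    print("70000.0 is out of fp16 range")
```

This is also why fp16 training is usually paired with loss scaling, while bf16 is often used without it. Under this assumption, switching to fp16 buys extra mantissa precision in softmax and accumulation paths at the cost of range.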

DachengLi1 commented 10 months ago

@yzhang123 Thanks a lot! Got it!

foreverpiano commented 2 months ago

@yzhang123 Have you solved the problem? Could you share the Triton code, please?