young-geng / EasyLM

Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, finetuning, evaluating, and serving LLMs in JAX/Flax.
Apache License 2.0

Fix unstable RMSNorm #50

Closed: ZYHowell closed this pull request 1 year ago

ZYHowell commented 1 year ago

As mentioned in Flax's docs, with fp16 training the normalization should still be computed in fp32 for stability. Below is what I observed when evaluating LLaMA with EasyLM. All experiments use the same sample:

- dtype=jnp.float32 or dtype=jnp.bfloat16: the loss is 0.12
- dtype=jnp.float16, without this PR's fix: the loss is 2
- dtype=jnp.float16, with this PR's fix: the loss is 0.12
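For readers following along, here is a minimal sketch (not EasyLM's exact module) of the pattern the fix refers to: an RMSNorm that upcasts its input to float32 for the mean-square reduction and only casts back to the low-precision dtype at the end. The class name, field names, and default `eps` below are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class RMSNorm(nn.Module):
    """RMSNorm sketch: normalize in float32, cast back to the model dtype."""
    dim: int
    eps: float = 1e-6
    dtype: jnp.dtype = jnp.float16  # dtype used by the surrounding layers

    @nn.compact
    def __call__(self, x):
        # Keep the learned scale in float32.
        weight = self.param('kernel', nn.initializers.ones, (self.dim,), jnp.float32)

        # Upcast activations before the mean-square reduction, regardless of
        # the precision used elsewhere in the model.
        x_f32 = x.astype(jnp.float32)
        variance = jnp.mean(jnp.square(x_f32), axis=-1, keepdims=True)
        normed = x_f32 * jax.lax.rsqrt(variance + self.eps)

        # Apply the scale in float32, then cast down to the model dtype.
        return (normed * weight).astype(self.dtype)
```

With this structure, the matmuls around the norm can stay in float16 while the reduction inside the norm runs in float32, which is what avoids the loss blow-up reported above.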

young-geng commented 1 year ago

Thanks a lot for the fix! Do you know if this also affects bfloat16?

ZYHowell commented 1 year ago

I tested a few samples on bf16 with the fix; the loss looks good (still around 0.12 for the same sample as above).