As mentioned in Flax's docs, with fp16 training, normalization should still be computed in fp32 for numerical stability.
Below is what I observed when evaluating LLaMA with EasyLM. All experiments use the same sample:
- `dtype=jnp.float32` or `dtype=jnp.bfloat16`: the loss is 0.12
- `dtype=jnp.float16`, without this PR's fix: the loss is 2
- `dtype=jnp.float16`, with this PR's fix: the loss is 0.12
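For context, here is a minimal sketch of the pattern the Flax docs recommend: upcast the activations to fp32, compute the normalization statistics there, and cast back to the working dtype only at the end. The `RMSNorm` module below is illustrative, not EasyLM's exact implementation.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class RMSNorm(nn.Module):
    """RMSNorm that keeps its statistics in float32 even when activations are fp16."""
    dim: int
    eps: float = 1e-6
    dtype: jnp.dtype = jnp.float16  # working dtype of the activations

    @nn.compact
    def __call__(self, x):
        # Keep the scale parameter in fp32.
        weight = self.param("weight", nn.initializers.ones, (self.dim,), jnp.float32)
        # Upcast before computing the variance; doing this in fp16 loses precision
        # and can overflow, which is what blows up the loss.
        x32 = x.astype(jnp.float32)
        variance = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
        normed = x32 * jax.lax.rsqrt(variance + self.eps)
        # Only cast back to the working dtype after normalization is done.
        return (weight * normed).astype(self.dtype)


# Example: activations stay in fp16, but the statistics are computed in fp32.
x = jnp.ones((2, 8, 512), dtype=jnp.float16)
norm = RMSNorm(dim=512)
params = norm.init(jax.random.PRNGKey(0), x)
y = norm.apply(params, x)  # y.dtype == float16
```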