AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

DEFAULT_MASK_VALUE causes gradient explosion and nan loss on deep models #614

Open | logicchains opened this issue 5 months ago

logicchains commented 5 months ago

I was training a Llama model on GPU with a custom embedding. It trained fine with 12 layers, dim 1024, and sequence length 256, but the loss became nan after the first step when num_layers was set to more than 17. I debugged the gradients and found that their magnitude grew by roughly 100x after each layer, until they hit the float32 maximum at around the 18th layer and became inf, leading to nan loss.

The gradient explosion seemed to be coming from local_exps = jnp.exp(attn_weights - local_max) in attentions.py.
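For anyone who wants to poke at this outside of MaxText, here is a minimal standalone sketch (not the MaxText code; the shapes, the toy causal mask, and the loss are made up for illustration) of the same masked-softmax pattern: apply the mask value, subtract the local max, exponentiate, and then use jax.grad to inspect the gradients under both choices of mask value. A single-layer toy like this won't necessarily reproduce the 100x-per-layer growth of a deep stack, but it shows where to hook the gradient check:

```python
import jax
import jax.numpy as jnp

def masked_softmax_loss(scores, mask, mask_value):
    # Masked positions receive mask_value; unmasked positions keep their raw scores.
    attn_weights = jnp.where(mask, scores, mask_value)
    local_max = jnp.max(attn_weights, axis=-1, keepdims=True)
    local_exps = jnp.exp(attn_weights - local_max)  # the line cited above
    probs = local_exps / jnp.sum(local_exps, axis=-1, keepdims=True)
    return jnp.sum(probs * scores)  # scalar output so jax.grad applies

scores = jax.random.normal(jax.random.PRNGKey(0), (4, 4), dtype=jnp.float32)
mask = jnp.tril(jnp.ones((4, 4))).astype(bool)  # toy causal mask

for name, mv in [("-0.7 * float32_max", -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)),
                 ("-inf", -jnp.inf)]:
    grads = jax.grad(masked_softmax_loss)(scores, mask, mv)
    print(name,
          "max |grad|:", float(jnp.max(jnp.abs(grads))),
          "all finite:", bool(jnp.all(jnp.isfinite(grads))))
```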

Changing DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) to DEFAULT_MASK_VALUE = -jnp.inf fixed the issue, and the gradients' magnitude stopped growing after each layer.
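For reference, the change would look roughly like this (the file path and surrounding context are assumptions on my part; the two assignments themselves are exactly what I tested):

```python
import jax.numpy as jnp

# attentions.py -- sketch of the change, exact location not verified

# current definition: a very large but finite negative logit for masked positions
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)

# change that fixed the divergence in my runs: a true -inf, so masked entries
# come out exactly zero after jnp.exp(attn_weights - local_max)
DEFAULT_MASK_VALUE = -jnp.inf
```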

Presumably the issue wasn't noticed during TPU training, as that uses a separate codepath.

rwitten commented 5 months ago

@logicchains thanks for the tips on GPU convergence! We will experiment with this as we set up convergent regimes for GPUs.

@anfals please be aware of this as you do convergence testing on GPU

shralex commented 1 week ago

@anfals is this something you're still working on, or already fixed? Thanks!