Open · logicchains opened this issue 6 months ago
@logicchains thanks for the tips on GPU convergence! We will experiment with this as we set up convergent regimes for GPUs.
@anfals please be aware of this as you do convergence testing on GPU
@anfals is this something you're still working on, or already fixed? Thanks!
I was training a llama model on GPU, with a custom embedding. It worked fine with 12 layers, dim 1024, seq length 256, but the loss would become nan after the first step when num_layers was set to more than 17. I debugged the gradients and found that their magnitude grew by around 100x after each layer, until they hit float32 max at around the 18th layer and became inf, leading to nan loss.
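In case it's useful to anyone reproducing this, here's a minimal sketch of the kind of per-parameter gradient-norm check I mean; `loss_fn`, `params`, and `batch` are placeholders for whatever the training loop already provides, not MaxText internals:

```python
import jax
import jax.numpy as jnp

def debug_grad_norms(loss_fn, params, batch):
    """Print the gradient norm of every parameter leaf, keyed by its path in the pytree."""
    grads = jax.grad(loss_fn)(params, batch)
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(grads)
    for path, leaf in leaves_with_paths:
        norm = jnp.linalg.norm(leaf.astype(jnp.float32))
        print(f"{jax.tree_util.keystr(path)}: {float(norm):.3e}")
```

With something like this, the growth shows up directly in the printed norms: each decoder layer's gradient norm was roughly two orders of magnitude larger than the previous one.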
The gradient explosion seemed to be coming from

```python
local_exps = jnp.exp(attn_weights - local_max)
```

in attentions.py. Changing

```python
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
```

to

```python
DEFAULT_MASK_VALUE = -jnp.inf
```

fixed the issue, and the gradients' magnitude stopped increasing after each layer. Presumably the issue wasn't noticed during TPU training, as that uses a separate code path.
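For what it's worth, here's a small illustration of one way the large-but-finite mask value behaves differently from `-inf`. I'm not claiming this is the exact path through attentions.py; it's a plain row softmax rather than the real kernel, but it shows that on a fully masked row (e.g. padding) the max subtraction turns every masked logit into exactly 0, so the exponentials come out as 1 and the row attends uniformly to positions that should contribute nothing:

```python
import jax.numpy as jnp

FINITE_MASK = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)  # current DEFAULT_MASK_VALUE

def row_softmax(attn_weights):
    # Same max-subtracted exponentiation as the snippet above.
    local_max = jnp.max(attn_weights, axis=-1, keepdims=True)
    local_exps = jnp.exp(attn_weights - local_max)
    return local_exps / jnp.sum(local_exps, axis=-1, keepdims=True)

# A row where every position is masked out.
fully_masked = jnp.full((1, 4), FINITE_MASK, dtype=jnp.float32)

# attn_weights - local_max is exactly 0 for every masked position, so exp()
# gives 1 and the "masked" row still carries uniform attention weight.
print(row_softmax(fully_masked))  # [[0.25 0.25 0.25 0.25]]
```

I'm not sure this exact corner case is what was blowing up my gradients, but it's one place where the finite sentinel silently lets masked positions contribute to the output, whereas a true `-inf` cannot produce a finite nonzero weight there.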