allenai / OLMo

Modeling, training, eval, and inference code for OLMo
https://allenai.org/olmo
Apache License 2.0

why is the total_grad_norm increasing across training? #596

Open ryanyxw opened 1 month ago

ryanyxw commented 1 month ago

❓ The question

This is a purely conceptual/intuition question, and I think it can only be asked with the proper context of OLMo (which is why I didn't go to Stack Overflow). I'd be very grateful if someone could answer it.

I noticed while going through the W&B training logs of OLMo 1B and OLMo 7B that the optim/total_grad_norm metric seems to increase consistently as training continues.

However, the perplexity (and thus the loss) seems to be converging to a local/global minimum. If the weights are converging to a local minimum, shouldn't the gradient norm also be decreasing, since the loss landscape flattens out?

I'm a bit confused as to why this is the case. Thanks!

[Screenshots: W&B training curves referenced above]
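In case it helps anyone reproduce the plots, here is a minimal sketch of pulling the metric with the W&B public API (the `entity/project/run_id` path is a placeholder, not the actual OLMo run path):

```python
# Sketch: fetch optim/total_grad_norm from a W&B run via the public API.
# "entity/project/run_id" is a placeholder -- substitute the real OLMo run path.
import wandb
import matplotlib.pyplot as plt

api = wandb.Api()
run = api.run("entity/project/run_id")

# history() returns a (sampled) pandas DataFrame of the logged metrics.
history = run.history(keys=["optim/total_grad_norm"])

plt.plot(history["_step"], history["optim/total_grad_norm"])
plt.xlabel("step")
plt.ylabel("optim/total_grad_norm")
plt.show()
```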

epwalsh commented 1 month ago

Hey @ryanyxw, this is an interesting phenomenon that seems to be tied to the (effective) learning rate. @viking-sudo-rm is an expert here, but I believe there are theoretical reasons to expect that the grad norm will eventually blow up unless the LR keeps decreasing enough (e.g. with a schedule proportional to 1 / sqrt(step)).

But for whatever reason these grad norm curves have looked different in our latest 7B runs: there's an initial period where the grad norm grows to a peak, then it decreases and seems to settle.
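For concreteness, an inverse-square-root schedule of that sort could look roughly like the sketch below (illustrative hyperparameters only, not the schedule from the OLMo configs):

```python
# Illustrative inverse-square-root LR schedule: linear warmup, then LR ~ 1/sqrt(step).
# warmup_steps and peak_lr are made-up values, not taken from an OLMo config.
import math
import torch

warmup_steps = 2000
peak_lr = 3e-4

model = torch.nn.Linear(10, 10)  # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr)

def inv_sqrt_factor(step: int) -> float:
    # Multiplier applied to peak_lr: ramp up linearly, then decay as 1/sqrt(step).
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    return math.sqrt(warmup_steps / (step + 1))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=inv_sqrt_factor)
# In the training loop: optimizer.step(); scheduler.step()
```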

SeunghyunSEO commented 1 month ago

> Hey @ryanyxw, this is an interesting phenomenon that seems to be tied to the (effective) learning rate. @viking-sudo-rm is an expert here, but I believe there are theoretical reasons to expect that the grad norm will eventually blow up unless the LR keeps decreasing enough (e.g. with a schedule proportional to 1 / sqrt(step)).
>
> But for whatever reason these grad norm curves have looked different in our latest 7B runs: there's an initial period where the grad norm grows to a peak, then it decreases and seems to settle.

Oh, that sounds interesting. @epwalsh, is this because your latest model is parameterized well and the old model didn't saturate (a decreasing grad norm means the NN's parameters are still in a region of sharp curvature on the loss surface?)? I saw that one of your researchers has recently been trying to implement Mu Parameterization.

epwalsh commented 1 month ago

> a decreasing grad norm means the NN's parameters are still in a region of sharp curvature on the loss surface?

Do you mean "increasing grad norm..."? Maybe, but I'm not sure how to test that theory.

SeunghyunSEO commented 1 month ago

> > a decreasing grad norm means the NN's parameters are still in a region of sharp curvature on the loss surface?
>
> Do you mean "increasing grad norm..."? Maybe, but I'm not sure how to test that theory.

Oh, I'm sorry, "increasing" is right.

viking-sudo-rm commented 1 month ago

> However, the perplexity (and thus the loss) seems to be converging to a local/global minimum. If the weights are converging to a local minimum, shouldn't the gradient norm also be decreasing, since the loss landscape flattens out?

Counterintuitively, this does not need to be the case. If the norm of the weights is increasing over time, the grad norm can increase even while the loss decreases (especially if the loss is flattening out). Intuitively, this is because the grad norm is roughly proportional to (or at least depends on) the parameter norm.

Explaining Growing Grad Norm in More Detail

In more mathematical detail, many neural network architectures are $k$-homogeneous w.r.t. their weights $\theta$, meaning that $f(c \theta) = c^k f(\theta)$ for any scale $c > 0$ and some value $k$.

An important implication of $k$-homogeneity is that the gradient is $(k-1)$-homogeneous (derivation here):

$$\nabla f(c \theta) = c^{k-1} \nabla f(\theta)$$
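For completeness, one way to see this: differentiate both sides of $f(c \theta) = c^k f(\theta)$ with respect to $\theta$, holding $c$ fixed. The chain rule on the left-hand side gives

$$ c \, \nabla f(c \theta) = c^k \, \nabla f(\theta) \quad \Longrightarrow \quad \nabla f(c \theta) = c^{k-1} \, \nabla f(\theta) $$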

This means that the gradient norm depends on the parameter norm. Writing $\theta = \lVert \theta \rVert \cdot (\theta / \lVert \theta \rVert)$ and applying the identity above with $c = \lVert \theta \rVert$:

$$ \nabla f(\theta) = \lVert \theta \rVert^{k-1} \cdot \nabla f(\theta / \lVert \theta \rVert) $$

$$\therefore \lVert \nabla f(\theta) \rVert = \lVert \theta \rVert^{k-1} \cdot \lVert \nabla f(\theta / \lVert \theta \rVert) \rVert $$

Crucially, this says that making the parameters larger while keeping their direction fixed will increase the gradient norm (when $k > 1$).

In the case of transformers, which are approximately 2-homogeneous, we get that along a fixed direction in parameter space, the gradient norm is roughly proportional to the parameter norm:

$$\lVert \nabla f(\theta) \rVert \approx \lVert \theta \rVert \cdot \lVert \nabla f(\theta / \lVert \theta \rVert) \rVert $$

This means that if the direction the network is moving in, $\theta / \lVert \theta \rVert$, has roughly converged but the parameter norm $\lVert \theta \rVert$ is still increasing, then we should expect the gradient norm to increase proportionally.
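If you want a quick numerical sanity check of the homogeneity argument, here is a minimal sketch using a toy bias-free two-layer ReLU network, which is exactly 2-homogeneous in its weights (a toy stand-in, of course, not the OLMo model):

```python
# Check that scaling the weights by c scales the grad norm by ~c^(k-1) with k = 2.
import torch
import torch.nn as nn

torch.manual_seed(0)

# Bias-free 2-layer ReLU net: f(c * theta) = c^2 * f(theta) for c > 0.
net = nn.Sequential(
    nn.Linear(16, 32, bias=False),
    nn.ReLU(),
    nn.Linear(32, 1, bias=False),
)
x = torch.randn(8, 16)

def grad_norm_at_scale(model: nn.Module, scale: float) -> float:
    """Scale every weight by `scale`, backprop a scalar output, return the total grad norm."""
    model.zero_grad()
    with torch.no_grad():
        for p in model.parameters():
            p.mul_(scale)
    model(x).sum().backward()
    g = torch.cat([p.grad.flatten() for p in model.parameters()]).norm().item()
    with torch.no_grad():
        for p in model.parameters():
            p.div_(scale)  # undo the scaling so the next call starts fresh
    return g

base = grad_norm_at_scale(net, 1.0)
scaled = grad_norm_at_scale(net, 3.0)

# The gradient is (k-1)-homogeneous, so with k = 2 the ratio should be close to 3.
print(scaled / base)
```

Scaling the weights by 3 while keeping their direction fixed should print a ratio very close to 3, matching $\lVert \nabla f(\theta) \rVert \propto \lVert \theta \rVert^{k-1}$ with $k = 2$.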