Open ryanyxw opened 1 month ago
Hey @ryanyxw this is an interesting phenomenon that seems to be tied to the (effective) learning rate. @viking-sudo-rm is an expert here, but I believe there are theoretical reasons to expect that the grad norm will eventually blow up unless the LR keeps decreasing quickly enough (e.g. with a schedule proportional to `1 / sqrt(step)`).
But for whatever reason these grad norm curves have looked different in our latest 7B runs. There's an initial period where the grad norm grows to a peak, then it decreases and seems to settle.
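For reference, the kind of continually decaying schedule mentioned above can be sketched like this. This is a hypothetical inverse-square-root schedule; the `base_lr` and `warmup` values are made up for illustration, not OLMo's actual config:

```python
import math

def inverse_sqrt_lr(step, base_lr=3e-4, warmup=2000):
    """Hold the LR constant during warmup, then decay it
    proportionally to 1 / sqrt(step) (continuous at step == warmup).
    base_lr and warmup are illustrative values, not a real config."""
    if step < warmup:
        return base_lr
    return base_lr * math.sqrt(warmup / step)
```

With this schedule the LR halves every time the step count quadruples, which is the kind of ongoing decay the argument above says is needed to keep the grad norm from eventually blowing up.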
Oh, that sounds interesting. @epwalsh is this because your latest model is well parameterized while the old model didn't saturate (decreasing grad norm means the NN's parameters are still in a region of sharp curvature on the loss surface?)? I saw that one of your researchers has been trying to implement Mu Parameterization recently.
> decreasing grad norm means the NN's parameters are still in a region of sharp curvature on the loss surface?

Do you mean "increasing grad norm..."? Maybe, but I'm not sure how to test that theory.
Oh, I'm sorry, "increasing" is right.
> However, the perplexity (and thus loss) seems to be converging to a local/global minimum. If the weights are converging to a local minimum, the gradient norm should also be decreasing, right? Since the loss landscape flattens out?
Counterintuitively, this does not need to be the case. If the weights are increasing over time, the grad-norm can increase even while the loss decreases (especially if the loss is flattening out). Intuitively, this is because the grad norm is roughly proportional to (or at least depends on) the parameter norm.
In more mathematical detail, many neural network architectures are $k$-homogeneous w.r.t. their weights $\theta$, meaning that $f(c \theta) = c^k f(\theta)$ for all $c > 0$ and some fixed $k$.
An important implication of $k$-homogeneity is that the gradient is $(k-1)$-homogeneous (derivation here):
$$\nabla f(c \theta) = c^{k-1} \nabla f(\theta)$$
This means that the gradient norm depends on the parameter norm: substituting $c = \lVert \theta \rVert$ and evaluating at the unit vector $\theta / \lVert \theta \rVert$ gives
$$ \nabla f(\theta) = \lVert \theta \rVert^{k-1} \cdot \nabla f(\theta / \lVert \theta \rVert) $$
$$\therefore \lVert \nabla f(\theta) \rVert = \lVert \theta \rVert^{k-1} \cdot \lVert \nabla f(\theta / \lVert \theta \rVert) \rVert $$
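As a quick numerical sanity check of the identity above, here is a sketch with a toy 3-homogeneous function (a made-up polynomial, not a neural network):

```python
import numpy as np

# Toy 3-homogeneous function: f(c * theta) = c^3 * f(theta), so k = 3.
# This is a hypothetical example chosen because its gradient is easy to
# write down by hand.
k = 3

def f(theta):
    return theta[0] ** 2 * theta[1]

def grad_f(theta):
    # Hand-derived gradient of f: [2 * t0 * t1, t0^2].
    return np.array([2 * theta[0] * theta[1], theta[0] ** 2])

theta = np.array([1.5, -0.7])
norm = np.linalg.norm(theta)

# ||grad f(theta)|| should equal ||theta||^(k-1) * ||grad f(theta / ||theta||)||.
lhs = np.linalg.norm(grad_f(theta))
rhs = norm ** (k - 1) * np.linalg.norm(grad_f(theta / norm))
print(lhs, rhs)  # the two values agree
```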
Crucially, this says that making parameters larger while keeping their direction the same will increase the gradient norm.
In the case of transformers, which are approximately 2-homogeneous, we get that along a fixed direction in parameter space, the gradient norm is roughly proportional to the parameter norm:
$$\lVert \nabla f(\theta) \rVert \approx \lVert \theta \rVert \cdot \lVert \nabla f(\theta / \lVert \theta \rVert) \rVert $$
This means that if the direction the network is moving in, $\theta / \lVert \theta \rVert$, is roughly converged but the parameter norm $\lVert \theta \rVert$ is still increasing, then we should expect the gradient norm to increase proportionally.
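To make that concrete, here is a small sketch with a toy two-layer linear network, which is 2-homogeneous in its weights like the transformer approximation above. The architecture and numbers are made up for illustration:

```python
import numpy as np

# A two-layer linear "network" with scalar output is 2-homogeneous in its
# weights: f(c * w1, c * w2) = c^2 * f(w1, w2). Hypothetical toy model.
rng = np.random.default_rng(0)
x = rng.normal(size=4)
w1 = rng.normal(size=(3, 4))
w2 = rng.normal(size=3)

def f(w1, w2):
    return w2 @ (w1 @ x)

def grad(w1, w2):
    # Hand-derived gradients: df/dw1 = outer(w2, x), df/dw2 = w1 @ x.
    g1 = np.outer(w2, x)
    g2 = w1 @ x
    return np.concatenate([g1.ravel(), g2])

base = np.linalg.norm(grad(w1, w2))
for c in [1.0, 2.0, 4.0]:
    # Scaling every weight by c (same direction, larger norm) scales the
    # gradient norm by c^(k-1) = c for a 2-homogeneous network.
    print(c, np.linalg.norm(grad(c * w1, c * w2)) / base)  # ratio is c
```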
❓ The question
This is a purely conceptual/intuition question, and I think it can only be asked with proper context about OLMo (which is why I didn't go to StackOverflow). I'd be very grateful if someone could answer this.
I noticed while going through the W&B training logs of OLMo 1B and OLMo 7B that `optim/total_grad_norm` seems to be consistently increasing as training continues.
However, the perplexity (and thus loss) seems to be converging to a local/global minimum. If the weights are converging to a local minimum, the gradient norm should also be decreasing, right? Since the loss landscape flattens out?
I'm a bit confused as to why this is the case. Thanks!