norpadon opened this issue 1 month ago
Hi! Thanks for the question. As our starting point, we tried to replicate μP's training dynamics, so we took Table 8 from Tensor Programs V as our reference.
Looking at the rightmost column, we wish to change the factors so that the init variance is 1 while retaining exactly the same dynamics. So, we multiply the init variance by fan_in. To compensate, we insert a forward-pass-only multiplier of 1/sqrt(fan_in). This makes the at-init behaviour identical. To keep the updates the same, for both Adam and SGD we multiply the LR by sqrt(fan_in), which gives us the rules that you see in the code.
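Here is a minimal pure-Python sketch of this argument for one hidden weight vector, restricted to Adam. It assumes Adam's first bias-corrected step (where the update is essentially -lr * sign(grad) once eps is negligible, hence the tiny `eps=1e-12`), ordinary autograd through the forward multiplier, and a hypothetical μP Adam LR of 1/fan_in with a base LR of 1. `adam_first_step` is an illustrative helper, not code from the repo.

```python
import math
import random

def adam_first_step(grad, lr, beta1=0.9, beta2=0.999, eps=1e-12):
    """First bias-corrected Adam update for a list of gradients (eps kept tiny on purpose)."""
    updates = []
    for g in grad:
        m_hat = ((1 - beta1) * g) / (1 - beta1)      # bias-corrected first moment
        v_hat = ((1 - beta2) * g * g) / (1 - beta2)  # bias-corrected second moment
        updates.append(-lr * m_hat / (math.sqrt(v_hat) + eps))
    return updates

random.seed(0)
fan_in = 256
mup_lr = 1.0 / fan_in            # hypothetical muP Adam LR for a hidden weight (base LR = 1)
mult = 1.0 / math.sqrt(fan_in)   # forward-pass-only multiplier

# muP: weights with variance 1/fan_in, used directly in the forward pass.
w_mup = [random.gauss(0.0, 1.0 / math.sqrt(fan_in)) for _ in range(fan_in)]
# u-muP: the same weights rescaled to unit variance; the multiplier undoes the rescale.
w_u = [w * math.sqrt(fan_in) for w in w_mup]
assert all(mult * wu == wm for wu, wm in zip(w_u, w_mup))  # identical at init

# Gradient w.r.t. the effective weight is the same in both models at init.
g_eff = [random.gauss(0.0, 1.0) for _ in range(fan_in)]
# Chain rule through the multiplier: dL/dw_u = mult * dL/dw_eff.
g_u = [mult * g for g in g_eff]

u_lr = mup_lr * math.sqrt(fan_in)  # "multiply the LR by sqrt(fan_in)"
delta_eff_mup = adam_first_step(g_eff, mup_lr)
delta_eff_u = [mult * d for d in adam_first_step(g_u, u_lr)]  # update seen by the effective weight

max_diff = max(abs(a - b) for a, b in zip(delta_eff_mup, delta_eff_u))
print(f"max difference in effective-weight update: {max_diff:.2e}")
```

The weight gradient picks up the 1/sqrt(fan_in) factor through the chain rule, but Adam divides that scale out, so only the forward-pass multiplier needs compensating in the LR. This sketch says nothing about SGD, whose update does depend on the gradient scale.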
However, I am now concerned that our logic has a flaw: it ignores the effect of the final output projection, which (in μP) should reduce the gradient by 1/fan_in but doesn't in a u-μP model, because we can use separate scaling factors in the forward and backward passes. If that is true, the LR scaling factor should be identical for Adam and SGD. I will confirm this with my colleague later this week.
In our early experiments, we tested SGD, but quickly moved on to Adam-only, so it is possible that this is a bug.
Thanks again for mentioning this!
Also, can this be used with second-order optimizers like Shampoo or SOAP?
Yes, I believe this should work with Shampoo/SOAP. Since these are, I believe, invariant to rescaling the gradient, they should use the same rules as Adam. It would be good to see if u-μP works in this context!
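The property this answer relies on, invariance to rescaling the gradient, is easy to check for Adam in a few lines. The sketch below assumes a scalar parameter and a tiny eps so that eps does not mask the effect; `adam_updates` is a hypothetical helper. It only exercises Adam and does not test Shampoo or SOAP.

```python
import math

def adam_updates(grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-12):
    """Run Adam on a single scalar parameter and return the update taken at each step."""
    m = v = 0.0
    updates = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g        # first-moment EMA
        v = beta2 * v + (1 - beta2) * g * g    # second-moment EMA
        m_hat = m / (1 - beta1 ** t)           # bias corrections
        v_hat = v / (1 - beta2 ** t)
        updates.append(-lr * m_hat / (math.sqrt(v_hat) + eps))
    return updates

grads = [0.3, -1.2, 0.7, 2.5, -0.1]
base = adam_updates(grads)

for scale in (1e-3, 10.0, 1e4):
    scaled = adam_updates([scale * g for g in grads])
    # Every gradient is multiplied by `scale`, yet the Adam updates are unchanged.
    assert all(math.isclose(a, b, rel_tol=1e-6) for a, b in zip(base, scaled))

# Plain SGD, by contrast, passes the gradient scale straight into the update.
lr = 1e-3
sgd_ratio = (-lr * 10.0 * grads[0]) / (-lr * grads[0])
assert math.isclose(sgd_ratio, 10.0)
```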
Thanks for your comment, it is starting to make sense now.
Regarding "To keep the updates the same, for both Adam and SGD we multiply the LR by sqrt(fan_in)": did you actually mean "divide the LR by sqrt(fan_in)"? Because if you multiply by sqrt(fan_in), then it should be fan_in**1.5 for SGD and fan_in**0.5 for Adam, no?
Sorry, I meant multiplying the LR scaling rule from the right column of the table by sqrt(fan_in): 1 * sqrt(fan_in) = sqrt(fan_in) for SGD and (1/fan_in) * sqrt(fan_in) = 1/sqrt(fan_in) for Adam.
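Written out as code, this is the arithmetic behind the rule as stated here (before the SGD fix made later in the thread), assuming μP's hidden-weight LR column reads 1 for SGD and 1/fan_in for Adam. `MUP_HIDDEN_LR` and `u_mup_hidden_lr_scale` are hypothetical names used only for illustration.

```python
import math

# Assumed muP LR multipliers for a hidden weight (right column of the table).
MUP_HIDDEN_LR = {
    "sgd": lambda fan_in: 1.0,
    "adam": lambda fan_in: 1.0 / fan_in,
}

def u_mup_hidden_lr_scale(optimizer: str, fan_in: int) -> float:
    """muP LR rule multiplied by sqrt(fan_in), compensating the 1/sqrt(fan_in) forward multiplier."""
    return MUP_HIDDEN_LR[optimizer](fan_in) * math.sqrt(fan_in)

for fan_in in (256, 1024, 4096):
    # sgd -> sqrt(fan_in), adam -> 1/sqrt(fan_in)
    print(fan_in, u_mup_hidden_lr_scale("sgd", fan_in), u_mup_hidden_lr_scale("adam", fan_in))
```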
Oh, thanks, I was looking at the wrong column.
But you also use the same "weight" parameter type for embedding matrices, although according to this table they should have a different LR scale.
Yes, indeed. This is the main place where we've deviated from μP's scaling laws since we found empirically that they did not provide good transfer (u-μP, Section 4.4 & Figure 3).
Aha! Thanks a lot for the explanations!
I am currently doing scaling law experiments for reasoning tasks and I am having trouble finding the right HPs for larger transformers. So I am hoping u-MuP will help me recover my mental health after hundreds of wasted runs :)
Will check whether Shampoo works with this and report back.
I think it would be really helpful if you added a single table to the paper describing all changes to the LR, init and alpha/beta scaling factors for all types of layers. Currently it is a bit tricky to follow: I had to keep five tabs open simultaneously to make sure I am implementing everything correctly.
Brill - I do hope it helps!
I agree, this is pretty hard to follow. I like this suggestion - ultimately u-μP should (I think) be easier than μP, but it is currently hard to check every bit. Perhaps one trick is to separately verify the unit-scaling rules from the optimiser scaling rules. The unit scaling rules say parameters, activations & gradients should be ~unit variance at init, and that's quite easy to check. Then the optimiser LR rules should be applied on top of that (these can be checked using μP's coordinate checking, although we stopped using this after the initial development).
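A minimal sketch of the first half of that check for a single linear layer: sample unit-variance parameters and inputs, apply the 1/sqrt(fan_in) forward multiplier, and confirm the output standard deviation is close to 1. It is pure Python, with hypothetical helpers (`randn`, `matmul`, `std`) standing in for framework ops, and it covers parameters and activations only; gradients could be checked the same way after a backward pass.

```python
import math
import random
import statistics

random.seed(0)
batch, fan_in, fan_out = 32, 128, 128

def randn(rows, cols):
    """Nested-list matrix of independent N(0, 1) samples."""
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

def matmul(a, b):
    """Plain nested-list matrix product a @ b."""
    b_cols = list(zip(*b))
    return [[sum(p * q for p, q in zip(row, col)) for col in b_cols] for row in a]

def std(m):
    """Population standard deviation over every entry of a nested-list matrix."""
    return statistics.pstdev([v for row in m for v in row])

w = randn(fan_in, fan_out)  # unit-variance parameters (u-muP-style init)
x = randn(batch, fan_in)    # unit-variance input activations

raw = matmul(x, w)
out = [[v / math.sqrt(fan_in) for v in row] for row in raw]  # forward multiplier 1/sqrt(fan_in)

print(f"param std:             {std(w):.2f}")    # close to 1
print(f"output std, no mult:   {std(raw):.2f}")  # grows like sqrt(fan_in)
print(f"output std, with mult: {std(out):.2f}")  # close to 1
```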
Hi @norpadon. I've had a look into this and you're right, there was an issue with the SGD implementation which I've now fixed in this PR.
The root of the problem, as suggested by Doug, is the fact that in unit-scaled models we use a 1/width multiplier in the forward pass of the output layer, but in the backward pass we hack backprop to change that scale to 1/sqrt(width). We weren't accounting for the effect of this on the learning rates, which need to be dropped by sqrt(width) to keep everything the same.
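A toy, pure-Python illustration of that mechanism: a single multiply whose forward and backward scales are set independently, as unit scaling allows. `scaled_op` is a hypothetical helper with a hand-written backward, not the library's implementation, and the sketch only shows the gradient-scale mismatch and its effect on a plain SGD step.

```python
import math

def scaled_op(x, fwd_scale, bwd_scale):
    """Scale x by fwd_scale going forward, but scale incoming gradients by bwd_scale."""
    y = fwd_scale * x

    def backward(grad_y):
        # True backprop would multiply by fwd_scale; here the scale is chosen independently.
        return bwd_scale * grad_y

    return y, backward

width = 1024
x, grad_y = 0.5, 1.0

# Ordinary backprop through a 1/width output multiplier.
y_true, true_backward = scaled_op(x, 1.0 / width, 1.0 / width)
# u-muP-style output layer: same forward scale, backward scale swapped to 1/sqrt(width).
y_hacked, hacked_backward = scaled_op(x, 1.0 / width, 1.0 / math.sqrt(width))

assert y_true == y_hacked  # the forward pass is unaffected
ratio = hacked_backward(grad_y) / true_backward(grad_y)
print(ratio)  # 32.0, i.e. sqrt(width)

# An SGD step is linear in the gradient, so dropping the LR by sqrt(width) restores it.
lr = 0.1
sgd_step_true = -lr * true_backward(grad_y)
sgd_step_hacked = -(lr / math.sqrt(width)) * hacked_backward(grad_y)
assert math.isclose(sgd_step_true, sgd_step_hacked)
```

In a real framework the same idea would typically be expressed with a custom-gradient mechanism, for example `torch.autograd.Function` in PyTorch or `jax.custom_vjp` in JAX.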
I've fixed this, and just as you expected the rules for Adam and SGD then ended up being the same (though only when this backward pass trick is being used). I've refactored the code to make this explicit: https://github.com/graphcore-research/unit-scaling/blob/60f8f6d2540cf2edeab1dc9a7099740dbace1169/unit_scaling/optim.py#L46-L56
Apologies for this error! We were really only focussing on Adam for our most recent paper, hence why we missed this. Do let us know if there's anything else we can do to make this all easier for people to use. This codebase should act as a reference implementation that makes all these rules explicit, but in this case it hasn't quite lived up to that. Thanks for pointing this out!
Hi! Thanks a lot for looking into this.
I think your codebase is quite nice, I found it to be very straightforward to understand, and it serves as a good reference.
As I already said, it would be really helpful if you added a single comprehensive table to the paper: currently you have to look in three different places to understand how to implement each layer correctly, and the embedding scaling rule has its own separate section, which is rather tricky to follow if the reader is not very familiar with the prior work.
Our aim was that this table would explain the init/mult/lr scaling rules:
and then this table would give op implementations:
Clearly we're missing the sgd lr rules - what else about these two tables is insufficient? (a genuine question, rather than me trying to rebut you!)
Well, there are no embedding and unembedding scaling rules mentioned in those tables, for example.
Also, it says that $\alpha = 1$ for the cross-entropy loss, but in the code it is actually $\alpha= \frac{1}{\text{batch-size}}$ (by the way, can't it cause overflows when training in fp16?)
I am trying to reimplement u-MuP in JAX and using this repo as a reference. I cannot figure out why Adam and SGD use different lr scaling factors: $\frac{1}{\sqrt{\text{fan-in}}}$ for Adam and $\sqrt{\text{fan-in}}$ for SGD.
Is this a bug, or is there some hidden insight I am not getting? I couldn't find an explanation in the paper.