graphcore-research / unit-scaling

A library for unit scaling in PyTorch
https://graphcore-research.github.io/unit-scaling/
Apache License 2.0

Adam and SGD update learning rate scaling. #73

Open norpadon opened 4 days ago

norpadon commented 4 days ago

I am trying to reimplement u-MuP in JAX and using this repo as a reference. I cannot figure out why Adam and SGD use different lr scaling factors: $\frac{1}{\sqrt{\text{fan-in}}}$ for Adam and $\sqrt{\text{fan-in}}$ for SGD.

Is this a bug? Is there some hidden insight I am not getting? I couldn't find an explanation in the paper.

Also, can this be used with second-order optimizers like Shampoo or SOAP?

DouglasOrr commented 4 days ago

Hi! Thanks for the question. We have tried, as our starting point, to replicate μP's training dynamics. So, taking table 8 from Tensor Programs V,

[Image: Table 8 from Tensor Programs V, showing the μP initialisation and learning-rate scaling rules]

Looking at the rightmost column, we wish to change the factors such that the init var is 1, while retaining exactly the same dynamics. So, we multiply init var by fan_in. To compensate, we insert a forward-pass-only multiplier of 1/sqrt(fan_in). This makes the at-init behaviour identical. To keep the updates the same, for both Adam and SGD we multiply the LR by sqrt(fan_in), which gives us the rules that you see in the code.
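For concreteness, here is that bookkeeping as a small Python sketch for a hidden ("weight") matrix. It is illustrative only (the helper name and return layout are mine, not the library's API); it just walks through the three steps above starting from the μP Table 8 rules:

```python
import math

def umup_rules_for_hidden_weight(fan_in: int) -> dict:
    """Hypothetical helper (not part of unit-scaling): derive u-muP-style
    factors for a hidden "weight" matrix from the muP Table 8 rules."""
    # muP (Table 8, hidden weights): init var 1/fan_in; LR scale 1 (SGD), 1/fan_in (Adam).
    mup_init_var = 1 / fan_in
    mup_lr_scale = {"sgd": 1.0, "adam": 1 / fan_in}

    # 1. Multiply the init var by fan_in so parameters start at unit variance.
    init_var = mup_init_var * fan_in                   # == 1
    # 2. Compensate with a forward-pass-only multiplier of 1/sqrt(fan_in),
    #    keeping the layer's output at init identical to muP.
    forward_mult = 1 / math.sqrt(fan_in)
    # 3. Multiply the LR scale by sqrt(fan_in) so each update moves the output
    #    by the same amount as in muP.
    lr_scale = {opt: s * math.sqrt(fan_in) for opt, s in mup_lr_scale.items()}
    # -> SGD: sqrt(fan_in), Adam: 1/sqrt(fan_in), as in the code you found.
    return {"init_var": init_var, "forward_mult": forward_mult, "lr_scale": lr_scale}
```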

However, I am now concerned that there is a flaw in our logic: it ignores the effect of the final output projection, which should (in μP) reduce the gradient by 1/fan_in, but doesn't in a u-μP model, due to our ability to use separate scaling factors in the forward and backward passes. If true, this would mean that the LR scaling factor should be made identical between Adam and SGD. I will confirm this with my colleague later this week.
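As a sketch of why (my notation, not from the paper): if the output projection computes $y = \alpha W x$, then

$$\frac{\partial L}{\partial x} = \alpha\, W^{\top} \frac{\partial L}{\partial y},$$

so with μP's $\alpha = \frac{1}{\text{fan-in}}$ every gradient flowing back into the body picks up that $\frac{1}{\text{fan-in}}$ factor, whereas u-μP uses separate forward and backward constants chosen for unit variance, so the body gradients don't carry it.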

In our early experiments, we tested SGD, but quickly moved on to Adam-only, so it is possible that this is a bug.

Thanks again for mentioning this!


> Also, can this be used with second-order optimizers like Shampoo or SOAP?

Yes, I believe this should work with Shampoo/SOAP. Since these are, I believe, invariant to rescaling the gradient, they should use the same rules as Adam. It would be good to see if u-μP works in this context!
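To spell out the invariance claim (my working, ignoring the ε terms and assuming the rescaling factor $c$ is constant across steps): Shampoo accumulates $L \mathrel{+}= G G^{\top}$ and $R \mathrel{+}= G^{\top} G$ and updates with $L^{-1/4} G R^{-1/4}$, so

$$G \to cG \;\Rightarrow\; L \to c^{2}L,\; R \to c^{2}R \;\Rightarrow\; (c^{2}L)^{-1/4}\,(cG)\,(c^{2}R)^{-1/4} = L^{-1/4} G R^{-1/4},$$

i.e. the update is unchanged, just as for Adam.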

norpadon commented 4 days ago

Thanks for your comment, it is starting to make sense now.

> To keep the updates the same, for both Adam and SGD we multiply the LR by sqrt(fan_in)

Did you actually mean "divide the LR by sqrt(fan_in)"? Because if you multiply by sqrt(fan_in) then it should be fan_in**1.5 for SGD and fan_in**0.5 for Adam, no?

DouglasOrr commented 4 days ago

> > To keep the updates the same, for both Adam and SGD we multiply the LR by sqrt(fan_in)
>
> Did you actually mean "divide the LR by sqrt(fan_in)"? Because if you multiply by sqrt(fan_in) then it should be fan_in**1.5 for SGD and fan_in**0.5 for Adam, no?

Sorry, I meant multiply the LR scaling rule from the right column of the table by sqrt(fan_in): 1 * sqrt(fan_in) for SGD and (1/fan_in) * sqrt(fan_in) for Adam.
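That is,

$$\text{SGD: } 1 \times \sqrt{\text{fan-in}} = \sqrt{\text{fan-in}}, \qquad \text{Adam: } \frac{1}{\text{fan-in}} \times \sqrt{\text{fan-in}} = \frac{1}{\sqrt{\text{fan-in}}},$$

matching the factors in the code.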

norpadon commented 4 days ago

Oh, thanks, I was looking at the wrong column.

But you use the same "weight" parameter type for embedding matrices too, although according to this table it should have a different LR scale.

DouglasOrr commented 4 days ago

> But you use the same "weight" parameter type for embedding matrices too, although according to this table it should have a different LR scale.

Yes, indeed. This is the main place where we've deviated from μP's scaling rules, since we found empirically that they did not provide good transfer (u-μP, Section 4.4 & Figure 3).

norpadon commented 4 days ago

Aha! Thanks a lot for the explanations!

I am currently doing scaling law experiments for reasoning tasks and I am having trouble finding the right HPs for larger transformers. So I am hoping u-MuP will help me recover my mental health after hundreds of wasted runs :)

Will check whether Shampoo works with this and report back.

I think it would be really helpful if you added a single table to the paper describing all the changes to the lr, init and alpha/beta scaling factors for all layer types. Currently it is a bit tricky to follow. I had to keep five tabs open simultaneously to make sure I was implementing everything correctly.

DouglasOrr commented 4 days ago

> I am currently doing scaling law experiments for reasoning tasks and I am having trouble finding the right HPs for larger transformers. So I am hoping u-MuP will help me recover my mental health after hundreds of wasted runs :)

Brill - I do hope it helps!

> I think it would be really helpful if you added a single table to the paper describing all the changes to the lr, init and alpha/beta scaling factors for all layer types. Currently it is a bit tricky to follow. I had to keep five tabs open simultaneously to make sure I was implementing everything correctly.

I agree, this is pretty hard to follow. I like this suggestion - ultimately u-μP should (I think) be easier than μP, but it is currently hard to check every bit. Perhaps one trick is to separately verify the unit-scaling rules from the optimiser scaling rules. The unit scaling rules say parameters, activations & gradients should be ~unit variance at init, and that's quite easy to check. Then the optimiser LR rules should be applied on top of that (these can be checked using μP's coordinate checking, although we stopped using this after the initial development).
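For example, the first check might look something like this (generic PyTorch with a naive stand-in layer, not this library's code; the single shared 1/sqrt(fan_in) scale is a simplification):

```python
import torch

torch.manual_seed(0)

# Naive stand-in for a unit-scaled linear layer: unit-variance init, with the
# 1/sqrt(fan_in) factor moved into the forward pass (a simplification of the
# library, which uses separate forward/backward scales).
batch, fan_in, fan_out = 256, 1024, 1024
weight = torch.randn(fan_out, fan_in, requires_grad=True)  # unit-variance init

x = torch.randn(batch, fan_in, requires_grad=True)  # unit-variance input
y = (x @ weight.t()) * fan_in ** -0.5                # forward-only 1/sqrt(fan_in)
y.backward(torch.randn_like(y))                      # unit-variance incoming grad

print("param std:     ", weight.std().item())       # ~1 by construction
print("act std:       ", y.std().item())            # ~1
print("act-grad std:  ", x.grad.std().item())       # ~1 (here fan_out == fan_in)
print("param-grad std:", weight.grad.std().item())  # ~sqrt(batch/fan_in) with this
# shared scale; the library's separate backward scales make this ~1 as well.
```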