PiotrNawrot / nanoT5

Fast & Simple repository for pre-training and fine-tuning T5-style models
Apache License 2.0

RMS scaling issues #15

Closed: SmerkyG closed this issue 1 year ago

SmerkyG commented 1 year ago

Could you shed some light on how the RMS scaling code used in AdamWScale is supposed to work? It seems like it has a bug, but maybe there's some magic I'm missing...

# /Adapt Step from Adafactor
step_size = step_size * max(1e-3, self._rms(p.data))
# /Adapt Step from Adafactor

This seems very different from what Adafactor does: it first calculates the expected parameter update and then scales it down like update = step_size * update / max(1, rms(update) / clip_threshold).
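For reference, here is a minimal plain-Python sketch of Adafactor-style update clipping, treating a tensor as a flat list. The raw update is divided by max(1, RMS(update) / d), so it is only shrunk when its RMS exceeds the threshold d. The helper names `rms` and `clipped_update` are hypothetical and not taken from nanoT5 or any library.

```python
import math


def rms(values):
    """Root-mean-square of a flat list of numbers."""
    return math.sqrt(sum(v * v for v in values) / len(values))


def clipped_update(update, lr, clip_threshold=1.0):
    """Adafactor-style update clipping: divide the raw update by
    max(1, RMS(update) / clip_threshold), then apply the step size."""
    denom = max(1.0, rms(update) / clip_threshold)
    return [lr * u / denom for u in update]


# An update with RMS 2.0 is scaled down to RMS 1.0 before the step size is applied.
print(clipped_update([2.0, -2.0, 2.0, -2.0], lr=1.0))  # [1.0, -1.0, 1.0, -1.0]
# An update with RMS 0.5 passes through unchanged (apart from lr).
print(clipped_update([0.5, -0.5], lr=0.1))
```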

The nanoT5 code appears to scale each parameter's update by _rms(the parameter's own value) rather than by _rms(the parameter's desired update), and then applies a magic constant (1e-3) that is maybe meant to somehow refer back to the default learning rate?
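A minimal sketch of what the quoted nanoT5 line computes, in plain Python with a tensor as a flat list and bias correction left out. `rms` and `scaled_step_size` are hypothetical helpers, not nanoT5's code. Since Adam's normalized update exp_avg / denom has entries of at most order one, the multiplier max(1e-3, rms(p)) roughly sets each tensor's per-step change as a fraction of its own RMS, with 1e-3 acting as a floor for very small tensors.

```python
import math


def rms(values):
    """Root-mean-square of a flat list of numbers."""
    return math.sqrt(sum(v * v for v in values) / len(values))


def scaled_step_size(lr, param_values, floor=1e-3):
    """step_size * max(1e-3, rms(p)), as in the snippet quoted above
    (bias correction omitted for clarity)."""
    return lr * max(floor, rms(param_values))


lr = 2e-2  # hypothetical base learning rate
tiny_param = [1e-4, -1e-4] * 4      # RMS 1e-4, below the floor
normal_param = [0.05, -0.05] * 4    # RMS 0.05
print(scaled_step_size(lr, tiny_param))    # uses the 1e-3 floor
print(scaled_step_size(lr, normal_param))  # proportional to the parameter's RMS
```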

PiotrNawrot commented 1 year ago

I think that we're talking about two different things here.

By RMS-LR scaling I mean this, which is in all public implementations of Adafactor.

What you mean is this.

When I was trying to make Adam work for T5 pre-training, I tried to identify all the differences between Adam and Adafactor and checked them one by one, independently; the one that happened to work was the LR scaling mentioned above.

Hope this clarifies things. I'm also looking forward to hearing your thoughts about it; I can't explain it either : )

SmerkyG commented 1 year ago

Ah, thank you for the clarification. I was indeed missing that part of the Adafactor paper (Section 8), and the learning-rate adjustment was buried in a separate function in the Adafactor code I looked at, so I missed that as well.
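For comparison, a sketch of the relative step size from Section 8 of the Adafactor paper, assuming the paper's default constants: alpha_t = max(eps2, RMS(X_{t-1})) * rho_t with rho_t = min(1e-2, 1/sqrt(t)) and eps2 = 1e-3. That eps2 looks like the source of the 1e-3 constant in the nanoT5 snippet. The function name is hypothetical and this is not a copy of any library's implementation.

```python
import math


def relative_step_size(step, param_rms, eps2=1e-3, max_rel_step=1e-2):
    """Adafactor-style relative step size (Section 8 of the paper):
    alpha_t = max(eps2, RMS(X_{t-1})) * rho_t, rho_t = min(1e-2, 1/sqrt(t))."""
    rho = min(max_rel_step, 1.0 / math.sqrt(step))
    return max(eps2, param_rms) * rho


# Hypothetical parameter RMS values: an embedding-like tensor, a typical one, a tiny one.
for param_rms in (1.0, 0.05, 1e-4):
    print(param_rms, [relative_step_size(t, param_rms) for t in (1, 10_000, 40_000)])
```

In the nanoT5 line, the rho_t factor is replaced by the optimizer's own base learning rate; only the max(1e-3, rms) factor is kept.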

One other note you might find interesting: RWKV uses a different technique to achieve what the Adafactor paper gives as the reason for this Section 8 RMS-LR scaling, namely helping the embeddings find good values. The RWKV author (BlinkDL) initializes the embedding vectors with very small values and then adds a separate LayerNorm right after them, before passing their output to the rest of the model. This lets the embeddings move quickly while the LayerNorm keeps their output normalized, so they converge to useful values early. See https://github.com/BlinkDL/SmallInitEmb
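To illustrate why small embeddings followed by a LayerNorm move quickly, here is a plain-Python sketch (not BlinkDL's code). It applies the same fixed-size, Adam-like update (±1e-3 per coordinate, an assumed value) to a tiny embedding and to a unit-scale one, and measures how much the LayerNorm output direction changes as 1 - cosine similarity. The init ranges, dimension, and helper names are illustrative assumptions.

```python
import math
import random


def layer_norm(x, eps=1e-5):
    """LayerNorm over one vector, without learned scale/shift."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def direction_change(init_scale, step=1e-3, dim=64, seed=0):
    """1 - cos between LayerNorm outputs before and after one update of fixed size."""
    rng = random.Random(seed)
    emb = [rng.uniform(-init_scale, init_scale) for _ in range(dim)]
    update = [rng.choice((-step, step)) for _ in range(dim)]  # sign-like, Adam-sized step
    before = layer_norm(emb)
    after = layer_norm([e + u for e, u in zip(emb, update)])
    return 1.0 - cosine(before, after)


print(direction_change(1e-4))  # tiny init: the normalized output direction changes a lot
print(direction_change(1.0))   # unit-scale init: it barely moves
```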

I've been interested to see whether this Adafactor-style RMS-LR scaling trick can help train other models more efficiently, such as GPT-2-style ones. If you like, I can let you know when I have some results to share.

PiotrNawrot commented 1 year ago

This is super interesting, and I would love to know more about it. I was planning to dig deeper into this, but I won't have time for a few weeks. Please keep me in the loop whenever you have some results; I can definitely help with some experiments if you'd like : )

SmerkyG commented 1 year ago

I've been running some experiments this morning, and so far the loss curve is much steeper (better) in early training for my tweaked GPT-2 model with plain AdamW than with AdamW plus the RMS-LR scaling, once I figured out how to adjust for the very different base learning rates each one needs (8e-2 vs 4e-6 for me). For me, AdamWScale's progress flattens out relatively quickly, but maybe that's made up for later in the training cycle...
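Because the scaled optimizer multiplies its base learning rate by max(1e-3, rms(p)), that base learning rate is not directly comparable to plain AdamW's. A rough way to translate between the two for a single tensor, ignoring bias correction and the fact that RMS drifts during training, is sketched below; the function name and example values are hypothetical.

```python
def equivalent_scaled_lr(adamw_lr, param_rms, floor=1e-3):
    """Base LR for the RMS-scaled optimizer that gives the same per-step size
    as plain AdamW with `adamw_lr` on a tensor whose RMS is `param_rms`."""
    return adamw_lr / max(floor, param_rms)


# Hypothetical example: AdamW at 1e-3 on a tensor with RMS 0.02.
print(equivalent_scaled_lr(1e-3, 0.02))  # roughly 0.05
# Tensors with different RMS need different translations, so a single base
# LR for the scaled optimizer cannot match AdamW on every tensor at once.
print([equivalent_scaled_lr(1e-3, r) for r in (1.0, 0.02, 1e-5)])
```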

Do you recall whether you were seeing improvements even early in training with AdamWScale vs AdamW?

PiotrNawrot commented 1 year ago

I think the pre-training task matters a lot. For regular decoder-only GPT-style models I've also been using AdamW, and it works just fine: the loss curve goes down smoothly from the beginning. T5 pre-training is very different. What I don't include in the README is the beginning of the loss curve. In every successful T5 training run there is a double-descent situation (around loss = 5):

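[Image: full T5 pre-training loss curve from the start of training, showing the double descent around loss = 5]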

You can also observe something similar in the blog: only a small subset of runs (Adafactor) converge to the desired values (<3), while the others stay at a very high loss (~5).

The only way I found to get this double-descent behaviour in T5 pre-training is to include this weird RMS-LR scaling, and I tested it with a variety of optimizers (Adam, Sophia, Lion). Without it, none of them converged to a loss lower than 4; after I added RMS-LR scaling, they worked. I think regular Adam is optimal for a huge variety of tasks, but T5 pre-training is not one of them. The question is why :)

PiotrNawrot commented 1 year ago

Hmm, but what if the thing they mention in Section 8 of the Adafactor paper is true for this instance of the T5 model? Have a look here. In HF they in fact initialise the weight matrices in a complex way (e.g. they do not scale the attention scores by sqrt(h_dim), but instead initialise the attention weights so that this scaling isn't needed). Maybe the relative step size is necessary in this scenario.
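To make the point about the HF initialisation concrete, here is a sketch of per-module init standard deviations following the pattern that `T5PreTrainedModel._init_weights` in transformers appears to use; the exact formulas should be checked against the transformers source for a given version, and the layer sizes are assumed t5-v1_1-base-like values. For a matrix drawn from N(0, std²), RMS(p) is roughly std, so max(1e-3, std) approximates the nanoT5 per-tensor multiplier at initialisation: the embedding gets a far larger relative step than the query projection.

```python
def t5_init_stds(d_model, d_kv, num_heads, d_ff, factor=1.0):
    """Normal-init std per module type (sketch of the HF T5 pattern)."""
    return {
        "embedding": factor * 1.0,
        # q absorbs the 1/sqrt(d_kv) factor that T5 omits from attention scores
        "attn_q": factor * (d_model * d_kv) ** -0.5,
        "attn_k": factor * d_model ** -0.5,
        "attn_v": factor * d_model ** -0.5,
        "attn_o": factor * (num_heads * d_kv) ** -0.5,
        "ffn_wi": factor * d_model ** -0.5,
        "ffn_wo": factor * d_ff ** -0.5,
    }


stds = t5_init_stds(d_model=768, d_kv=64, num_heads=12, d_ff=2048)  # assumed sizes
for name, std in stds.items():
    # RMS of an N(0, std^2) matrix is about std, so this approximates max(1e-3, rms(p)).
    print(f"{name:9s} std={std:.4f} multiplier~{max(1e-3, std):.4f}")
```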

SmerkyG commented 1 year ago

Thanks, it's great to see the whole graph; such images almost always have the beginning cut off :) And I didn't know about the two-phase loss curve for T5 models! Very interesting.

It sounds like the relative step size may be helping you get past the part where the loss curve temporarily flattens around 5, like in the graph from the HF blog you mentioned: https://yhavinga-pre-training-dutch-t5-models.hf.space/media/2c8d9281e22bee471f438c8658cd4faca4e1bb905d4fe41ea2c899a4.png

One idea: maybe the two-phase behaviour is somehow a result of the parameters finally 'escaping' the small-parameter regime below the 1e-3 floor, after which the scaling can increase the effective learning rate?

PiotrNawrot commented 1 year ago

Haha, I changed the model's initialisation so that all weights are drawn from the same distribution, and regular Adam worked! Thanks for bringing this up, mystery solved. However, this trick from Adafactor is actually super cool, because it makes the optimiser more initialisation-agnostic!
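One way to see why the trick makes the optimiser more initialisation-agnostic: when every tensor starts with the same RMS, max(1e-3, rms(p)) is identical for all of them at initialisation, so the scaling reduces to a single global rescaling of the learning rate (until the RMS values drift apart during training). With a mixed initialisation, the multipliers differ by orders of magnitude. The sketch below uses hypothetical RMS values and a hypothetical helper name.

```python
def multiplier_spread(param_rms, floor=1e-3):
    """Ratio of the largest to the smallest per-tensor multiplier max(floor, rms).
    A spread of 1.0 means the scaling acts as a single global LR rescaling."""
    mults = [max(floor, r) for r in param_rms.values()]
    return max(mults) / min(mults)


same_init = {"embedding": 0.02, "attn_q": 0.02, "ffn_wo": 0.02}     # hypothetical shared std
mixed_init = {"embedding": 1.0, "attn_q": 0.0045, "ffn_wo": 0.022}  # hypothetical HF-T5-like spread
print(multiplier_spread(same_init))   # 1.0
print(multiplier_spread(mixed_init))  # roughly 222
```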

SmerkyG commented 1 year ago

Wow, that's great! So glad I accidentally helped!!!

SmerkyG commented 1 year ago

Would you mind posting a graph of the new descent? I'd like to see how that hump changed!

PiotrNawrot commented 1 year ago

The hump is exactly the same. I'll try to post the updated version of the repo soon : )

SmerkyG commented 1 year ago

Neat. That's interesting to find out that the hump is a fundamental part of the model and not simply an artifact of the RMS-LR scaling trick.

Taytay commented 10 months ago

@PiotrNawrot : did you have time to incorporate these findings into the report? It sounds really encouraging!

PiotrNawrot commented 9 months ago

@Taytay We're currently investigating it with others in https://github.com/PiotrNawrot/nanoT5/issues/25. You are welcome to join the discussion if you are interested.

PiotrNawrot commented 9 months ago

@SmerkyG also: we're currently investigating it with others in https://github.com/PiotrNawrot/nanoT5/issues/25. You are welcome to join the discussion if you are interested.