Thanks for reaching out, Nicolas!
There are two reasons:

1. SGD under SP is "less wrong" than Adam, so it takes a much larger delta in width to see the benefit of muP;
2. wm is linear in width, so your widest model is only 5x wider than the narrowest model. This delta is quite small. In our Transformer experiments with Adam, width is varied from 128 to 8192.

Thanks for your reply, Edward!
We understand that under SP our width multiplier range may be too small and that the LR alignment in SP may fail for larger widths. However, what we don't understand is the misalignment under muP when comparing the curve for wm=1 with the other wm values. I thought that under muP, the LR alignment is guaranteed?
Adding to Edward: you can see from the plot in our paper that the difference for ResNet on CIFAR10 is not as drastic as for Transformers. The width difference is 16x between the largest and smallest model here.
In addition, if you tune the input-weight and/or output-weight learning rate separately from the global learning rate on the smallest model, you'll often see muP perform much better and the alignment improve.
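For instance, one way to give the input and output weights their own learning rates is to put them in separate PyTorch parameter groups. The sketch below assumes the ResNet18 exposes its first convolution as `conv1` and its readout as `linear`, and that mup's optimizer accepts standard parameter groups; the attribute names are assumptions and may differ from the repository's model.

```python
# Minimal sketch: separate LRs for input/output weights via parameter groups.
# Assumes `model` is the muP-parametrized ResNet18 (i.e. after set_base_shapes),
# with its first conv at `model.conv1` and its readout layer at `model.linear`.
from mup import MuSGD

def make_param_groups(model, global_lr, input_lr, output_lr):
    input_params = list(model.conv1.parameters())    # "input" weights
    output_params = list(model.linear.parameters())  # "output" (readout) weights
    special = {id(p) for p in input_params + output_params}
    rest = [p for p in model.parameters() if id(p) not in special]
    return [
        {'params': rest, 'lr': global_lr},           # global learning rate
        {'params': input_params, 'lr': input_lr},    # tuned separately
        {'params': output_params, 'lr': output_lr},  # tuned separately
    ]

optimizer = MuSGD(
    make_param_groups(model, global_lr=0.1, input_lr=0.5, output_lr=0.05),
    lr=0.1, momentum=0.9,
)
```

In a sweep, `input_lr` and `output_lr` would be tuned once on the smallest model, while the global learning rate is swept as before.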
> Thanks for your reply, Edward!
> We understand that under SP our width multiplier range may be too small and that the LR alignment in SP may fail for larger widths. However, what we don't understand is the misalignment under muP when comparing the curve for wm=1 with the other wm values. I thought that under muP, the LR alignment is guaranteed?
The answer is twofold: 1) the alignment is only approximate, improving with the width of the base model (the analogy is estimating the mean of a population by taking a large sample and computing its average; this average is only approximately equal to the population mean, and only when the sample size is large enough), and 2) it has to do with insufficient tuning of other hyperparameters, such as the input/output LR, as I mentioned. You can check out my reply here.
In particular,
The true hyperparameter space here is the very high dimensional space containing [learning rate, initialization] (we can insert multipliers here as well, but like I said, it is redundant) for every parameter tensor (weights, biases, gains, etc.). If you were to tune all these hyperparameters and obtain the optimal combination, then this combination is guaranteed to be stable in some sense as you vary width (in muP). However, in practice, we may not want to tune that many hyperparameters because of resource constraints. So we combine hyperparameters (e.g., by tying the learning rate of many weights together) until we have only a small number to tune. This essentially means that we are now focusing on a low dimensional slice of the true hyperparameter space, one that we guess should contain all the really good hyperparameters. The choices of hyperparameters we tuned in our paper exemplify the “low dimensional slice” we chose. These choices are based on our empirical experience tuning hyperparameters, but over time people may find better choices.
Thank you Greg for your detailed explanation!
Hi, first of all, thanks for sharing your work.
We tried to reproduce the expected behavior of muP using ResNet18 and CIFAR10, as provided in the main script of your repository. The idea was to launch training runs for multiple learning rates and width_mult values, and record the minimum loss of each run, as you did in your paper, to check that the best learning rate doesn't change with width_mult.
We slightly modified the main.py script to skip the saving/loading of the base shapes file, as follows:
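A rough sketch of this kind of modification, assuming a `ResNet18(width_mult=...)` constructor and that `set_base_shapes` is given base and delta models built in-process; the constructor, module, and argument names are assumptions and may differ from the actual main.py:

```python
# Sketch: build base/delta models in memory instead of saving/loading a .bsh file.
from mup import set_base_shapes
from resnet import ResNet18  # assumed module/class name from the repository's example

def build_mup_resnet(width_mult):
    base_model = ResNet18(width_mult=1)    # base widths
    delta_model = ResNet18(width_mult=2)   # tells mup which dimensions scale with width
    model = ResNet18(width_mult=width_mult)
    # Infers the infinite-width shapes from (base, delta) and rescales the
    # initialization of `model` accordingly.
    set_base_shapes(model, base_model, delta=delta_model)
    return model
```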
Then, for each width multiplier wm from 1 to 5, we launched the following bash scripts, which train the model over a set of learning rates.
In muP mode:
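A sketch of what such a sweep might look like, with hypothetical flag names (`--lr`, `--width_mult`, `--use_mup`) that may not match the actual main.py arguments:

```bash
#!/bin/bash
# Sketch: sweep learning rates for one width multiplier (passed as $1) in muP mode.
# Flag names are hypothetical.
WM=$1
mkdir -p logs
for LR in 0.001 0.003 0.01 0.03 0.1 0.3 1.0; do
    python main.py --use_mup --width_mult "$WM" --lr "$LR" \
        > "logs/mup_wm${WM}_lr${LR}.log"
done
```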
In SP mode:
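And the corresponding SP sweep, again with hypothetical flags:

```bash
#!/bin/bash
# Sketch: same sweep without muP (standard parametrization).
WM=$1
mkdir -p logs
for LR in 0.001 0.003 0.01 0.03 0.1 0.3 1.0; do
    python main.py --width_mult "$WM" --lr "$LR" \
        > "logs/sp_wm${WM}_lr${LR}.log"
done
```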
Then we took the minimum loss of each run and plotted the two sets of curves (minimum loss vs. learning rate): one with muP, one without.
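A minimal plotting sketch, assuming the per-run minimum losses have already been collected into a dict mapping width multiplier to {learning rate: minimum loss}; how those numbers are parsed out of the training logs is not shown:

```python
# Sketch: plot minimum loss vs. learning rate, one curve per width multiplier.
# `results` is assumed to map width multiplier -> {learning rate: minimum loss}.
import matplotlib.pyplot as plt

def plot_min_loss(results, title):
    for wm, by_lr in sorted(results.items()):
        lrs = sorted(by_lr)
        plt.plot(lrs, [by_lr[lr] for lr in lrs], marker='o', label=f'wm={wm}')
    plt.xscale('log')
    plt.xlabel('learning rate')
    plt.ylabel('minimum loss')
    plt.title(title)
    plt.legend()
    plt.show()
```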
With muP: (figure: minimum loss vs. learning rate, one curve per width multiplier)
Without muP: (figure: minimum loss vs. learning rate, one curve per width multiplier)
As you can see in the two figures, there is no visible difference between the two scenarios: in both cases, the minima are aligned except for those with wm=1. Do you have an idea why this is happening? Thanks for your help!