microsoft / mup

Maximal Update Parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

Does MuReadout apply to all outputs on which loss is computed? #9

Closed · jaivardhankapoor closed this issue 2 years ago

jaivardhankapoor commented 2 years ago

Hi,

I have an autoencoder-like structure where I also have a loss on the intermediate representation (say z). For a regression-style problem, the loss is computed as L = L_1(x_hat) + L_2(z), where x_hat is the final output. Should I apply MuReadout to the intermediate representation too?

Related question (continuation of issue https://github.com/microsoft/mup/issues/3): How are the initialization and learning rate scales for a convolution operation computed according to this method?

Thanks for your help and the super cool project!

thegregyang commented 2 years ago

Since z's dimension is something you will probably vary (so it's width-like), I'd do something like this: don't use MuReadout on z; just use a typical nn.Linear. But when you calculate L_2(z), normalize by the dimension of z: L_2(z) = z.norm()**2 / z.numel(). As usual, you can put a weight hyperparameter on the L_2 term, like alpha * L_2(z).

x_hat's dimension is "fixed" (a non-width dimension), so you should use MuReadout for it. Nevertheless, you can still normalize L_1(x_hat) as well, L_1(x_hat) = x_hat.norm(p=1) / x_hat.numel(), but here x_hat.numel() would be "finite" while z.numel() would be "infinite".
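
Here's a minimal sketch of that setup, assuming a toy fully-connected encoder/decoder; the dimensions `d_in`, `d_z`, the weight `alpha`, and the module names are made up for illustration:

```python
import torch
import torch.nn as nn
from mup import MuReadout  # muP readout layer for fixed-dimensional outputs

class ToyAE(nn.Module):
    """Toy autoencoder-like model; widths and names are hypothetical."""
    def __init__(self, d_in=784, d_z=256):
        super().__init__()
        self.encoder = nn.Linear(d_in, d_z)   # plain Linear: z's dimension is width-like
        self.decoder = MuReadout(d_z, d_in)   # MuReadout: x_hat's dimension is fixed

    def forward(self, x):
        z = torch.relu(self.encoder(x))
        x_hat = self.decoder(z)
        return x_hat, z

def total_loss(x_hat, z, alpha=1.0):
    # normalize each term by its number of elements, as suggested above
    l1 = x_hat.norm(p=1) / x_hat.numel()   # "finite" normalizer
    l2 = z.norm() ** 2 / z.numel()         # "infinite" (width-like) normalizer
    return l1 + alpha * l2
```

As usual with the package, you'd still call set_base_shapes on the model and use MuAdam/MuSGD for training on top of this.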

Hopefully that makes sense!

For convolutions, just think of them as a kernelsize x kernelsize collection of nn.Linears. So with fanin and fanout being the number of input and output channels, it's the same scaling as for linear layers.
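
As a rough illustration of "same scaling as linear layers" for a hidden (non-readout) conv layer, with hypothetical channel counts; this is only to show the fan-in scaling, and when using the mup package you would set up base shapes and the Mu optimizers as usual rather than hand-scaling:

```python
import math
import torch.nn as nn

# hypothetical channel counts; the kernel size k is a fixed ("finite") constant
in_ch, out_ch, k = 64, 128, 3
conv = nn.Conv2d(in_ch, out_ch, kernel_size=k)

# hidden-layer init with variance ~ 1/fan_in: the width-like part of fan_in is
# the number of input channels, since the k*k factor does not grow with width
fan_in = in_ch * k * k
nn.init.normal_(conv.weight, std=1.0 / math.sqrt(fan_in))
nn.init.zeros_(conv.bias)
```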

thegregyang commented 2 years ago

Feel free to open this back up if more explanation is needed.