microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

Coord-check for conv1d #14

Closed bob80333 closed 2 years ago

bob80333 commented 2 years ago

I modified the muconv2d branch to get a conv1d variant of the output layer for mup, and I applied it to a shallow variant of a unet model I've been testing.

Repo for the model: https://github.com/bob80333/audio_bandwidth_extension
Fork of mup with Conv1d: https://github.com/bob80333/mup/tree/muconv2d

Here are the coord-check results. They don't look quite as smooth as the ones in the paper, but there's definitely a big difference between mup and no mup.

mup: plot_mup

no mup: plot_nomup

Does this look about right for the coordinate check? The figures I saw in the example looked much smoother than this.
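
(For reference, the plots above were produced roughly as in the coord-check recipe from the mup README, sketched below; `make_unet(width)` and `dataloader` are stand-ins for the actual code in my repo.)

```python
# Rough sketch of how the coord-check plots are generated (per the mup README);
# make_unet(width) and dataloader are placeholders for the real model/data code.
from mup import set_base_shapes
from mup.coord_check import get_coord_data, plot_coord_data

def lazy_model(width):
    # set_base_shapes returns the model; the base/delta widths here are placeholders
    return lambda: set_base_shapes(make_unet(width), make_unet(8), delta=make_unet(16))

models = {w: lazy_model(w) for w in (8, 16, 32, 64, 128)}  # models of increasing width

df = get_coord_data(models, dataloader)      # records activation scales over a few training steps
plot_coord_data(df, save_to='plot_mup.png')  # produces the figure shown above
```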

edwardjhu commented 2 years ago

Hi Bob,

Thanks for sharing these! The mup plots look off; you can see the blow-up with width in the last layer. I looked through your code and couldn't spot anything obvious. Can you try kernel_size=1, which should reduce the readout to just mup.Linear? We can see if it behaves the same as mup.Linear and start from there.
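
(For intuition, a quick plain-PyTorch sanity check, independent of mup, that a Conv1d with kernel_size=1 is exactly a Linear layer applied at every time step:)

```python
import torch
import torch.nn as nn

c_in, c_out, t = 16, 4, 100
conv = nn.Conv1d(c_in, c_out, kernel_size=1)
lin = nn.Linear(c_in, c_out)
# Copy the conv filters into the linear layer: conv.weight has shape (c_out, c_in, 1)
lin.weight.data.copy_(conv.weight.data.squeeze(-1))
lin.bias.data.copy_(conv.bias.data)

x = torch.randn(2, c_in, t)                        # (batch, channels, time)
out_conv = conv(x)                                 # (batch, c_out, time)
out_lin = lin(x.transpose(1, 2)).transpose(1, 2)   # same map via Linear, per time step
print(torch.allclose(out_conv, out_lin, atol=1e-6))  # -> True
```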

bob80333 commented 2 years ago

Just to clarify, only the output (MuReadout style) conv1d layer should have kernel_size=1?

edwardjhu commented 2 years ago

Yup, just the output layer.

bob80333 commented 2 years ago

kernel size 1:

mup: plot_mup_outkernel1

no mup: plot_nomup_outkernel1

edwardjhu commented 2 years ago

Can you also try mup.Linear? It should give you the same result if mup.Conv1D is implemented correctly. If the curves still look off, then either there's a bug in mup.Linear or there's a problem in how mup is used, barring some other factors in your model.

bob80333 commented 2 years ago

I tried a few things.

There is a skip connection in this model between the output and the input, so I tried replacing that Conv1d with mup Conv1d.

Result:
mup: plot_mup_outkernel1_lastskip
no mup: plot_nomup_outkernel1_lastskip

Replacing the out_conv with a linear layer, leaving the skip connection as it was originally:
mup: plot_mup_outlinear
no mup: plot_nomup_outlinear

Replacing the skip connection and out_conv with linear layers (MuReadout):
mup: plot_mup_outlinear_lastskip
no mup: plot_nomup_outlinear_lastskip

thegregyang commented 2 years ago

Hey @bob80333, can you try removing all nn.utils.weight_norm calls, such as the one at https://github.com/bob80333/audio_bandwidth_extension/blob/2089e1901afe5f133c62f42cd7041118e5c00cf0/mup_audio_unet.py#L180?
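
(If it helps, a small plain-PyTorch sketch for stripping weight norm from every layer at once; `model` here stands in for the U-Net instance.)

```python
import torch.nn as nn
from torch.nn.utils import remove_weight_norm

def strip_weight_norm(model: nn.Module) -> nn.Module:
    # Walk every submodule and drop the weight_norm reparametrization where present;
    # remove_weight_norm raises ValueError for modules that were never wrapped.
    for module in model.modules():
        try:
            remove_weight_norm(module)
        except ValueError:
            pass
    return model
```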

bob80333 commented 2 years ago

Removing weight norm only from the output layer (the mup Linear readout layer):

with mup: plot_mup_outlinear_removewnout
without mup: plot_nomup_outlinear_removewnout

Removing weight norm from all layers (mup Linear readout layer):

with mup: plot_mup_outlinear_removewnall
without mup: plot_nomup_outlinear_removewnall

thegregyang commented 2 years ago

@bob80333 There actually seems to be a slight blow-up issue at initialization already. To speed this up, could you recreate this in a colab notebook so I can debug it myself and get back to you?

bob80333 commented 2 years ago

@thegregyang sorry for the delay, here's a colab notebook that should work:

https://colab.research.google.com/drive/1OvQ_My8KOOUowtrY56oQm7PFqNOn2hbS?usp=sharing

bob80333 commented 2 years ago

@thegregyang Did the colab notebook work for you?

thegregyang commented 2 years ago

@bob80333 I have found the issue, which has to do with the modification you made to your fork of mup. I'll reply today with a working notebook.

thegregyang commented 2 years ago

Here's a working colab. Coord check seems to work for both the shallow net and your original deep net.

Shallow: image

Deep: image

At least one of the issues was your modification here, which did not rescale the Conv layer's initialization correctly. After that, I switched back to our original branch, and the coord check essentially passed without problems once the Conv1d layer was added in.
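
(Not the branch's actual code, but for readers of this thread, a hedged sketch of the property a μP Conv1d readout needs: the readout input is scaled down by the width multiplier, mirroring MuReadout. The class name and the width_mult bookkeeping below are illustrative assumptions only.)

```python
# Illustrative sketch only -- not the implementation in the muconv2d branch.
# The key property of a muP readout is that its effective scale shrinks like
# 1/width; here that is done by dividing the input by a width multiplier,
# mirroring what MuReadout does for Linear layers.
import torch.nn as nn

class MuConv1dReadout(nn.Conv1d):
    def __init__(self, in_channels, out_channels, kernel_size=1,
                 base_in_channels=8, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        # In mup the width multiplier comes from the weight's infshape; a fixed
        # base width is assumed here just to keep the sketch self-contained.
        self.base_in_channels = base_in_channels

    def width_mult(self):
        return self.in_channels / self.base_in_channels

    def forward(self, x):
        # Shrink the readout input so the output stays O(1) as width grows.
        return super().forward(x / self.width_mult())
```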

bob80333 commented 2 years ago

Awesome! Thank you!

I was playing around with trying to get weight norm working, and I think I have an idea that might make it work.

The PyTorch documentation on weight norm (from here) says:

This replaces the parameter specified by name (e.g. 'weight') with two parameters: one specifying the magnitude (e.g. 'weight_g') and one specifying the direction (e.g. 'weight_v').

Given that we are rescaling the parameters, I think we could modify this line:

https://github.com/microsoft/mup/blob/16ef49056825daaa7bc7c691024bd66586a9f6a6/mup/layer.py#L123

and add a check: if 'weight_v' (the direction parameter created by weight norm) exists, use its infshape instead of 'weight''s.

If I remember correctly, this was actually the reason behind my modification: 'weight' didn't exist since I used weight norm, so I changed it to get the infshape from the bias, which was incorrect.

Since 'weight_v' is the direction of the weight vector, its shape is the same as the original 'weight' parameter's (I checked), so I assume the infshape should be the same as well.
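
(A hedged sketch of that check, not a patch against the actual mup/layer.py line; it assumes set_base_shapes has attached an infshape to whichever tensor is still a registered parameter after weight_norm is applied.)

```python
# Sketch of the proposed check (illustrative, not the library's code): after
# torch.nn.utils.weight_norm, 'weight' is recomputed from 'weight_g'/'weight_v',
# and 'weight_v' (the direction) has the same shape as the original 'weight'.
def readout_infshape(module):
    if hasattr(module, 'weight_v'):
        return module.weight_v.infshape  # weight-normed layer: use the direction tensor
    return module.weight.infshape        # ordinary layer
```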

bob80333 commented 2 years ago

@thegregyang I tried out a modification to get weight norm working, not sure if it's correct.

WN + MuP, full model depth:

download (1)

Here's a colab with the code:

https://colab.research.google.com/drive/1osnDz7dRu8Y86V6klX02EA60PLULRGLr?usp=sharing

thegregyang commented 2 years ago

The coord check doesn't look good.

The problem with weight norm is not the code but the mathematics. The way weight norm normalizes a hidden weight matrix is not natural: one should normalize it to Frobenius norm $\Theta(\sqrt{\text{width}})$ instead of $\Theta(1)$. A lot of things need to be fixed to get weight norm to scale correctly.
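
For concreteness, here's a back-of-envelope version of that scale, reading the reparametrization at the level of the whole matrix and assuming an $n \times n$ hidden weight matrix with entries of size $\Theta(1/\sqrt{n})$:

$$\|W\|_F = \sqrt{\sum_{i,j} W_{ij}^2} = \sqrt{n^2 \cdot \Theta(1/n)} = \Theta(\sqrt{n}),$$

whereas weight norm reparametrizes the weight as $W = g \, v / \|v\|$, so the norm is controlled by the learned magnitude $g$ and does not grow like $\Theta(\sqrt{\text{width}})$ unless $g$ itself is rescaled with width.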

Would you be OK with not using weight-norm? You can just replace it with layernorm or batchnorm.

bob80333 commented 2 years ago

Thanks for the explanation. I tried several different norm layers, and for my model the only one that didn't hurt performance was weight norm, so I removed all normalization.